Compare commits
123 Commits
user/miche
...
feat/add_r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
374d4351fd | ||
|
|
4f3eaff2bd | ||
|
|
1f7ddc1d76 | ||
|
|
ce63cfdb25 | ||
|
|
974028bd28 | ||
|
|
a36ed39487 | ||
|
|
c37b1d45b6 | ||
|
|
d6f1359e69 | ||
|
|
2357d4aceb | ||
|
|
f994febca4 | ||
|
|
12f52632ed | ||
|
|
8a64d8268b | ||
|
|
d6ccdc222c | ||
|
|
9bd0788131 | ||
|
|
1ae62c28f7 | ||
|
|
baf6e66c3d | ||
|
|
a065bd61ae | ||
|
|
84565c7c2e | ||
|
|
05b54733da | ||
|
|
513b008bcc | ||
|
|
32fffd4bbb | ||
|
|
03c7cf8a63 | ||
|
|
074f0ac8fe | ||
|
|
25c63ccf63 | ||
|
|
5dc3c74e64 | ||
|
|
5e9473806c | ||
|
|
4214b01703 | ||
|
|
b974e5541f | ||
|
|
10706ed753 | ||
|
|
fd64dc84ae | ||
|
|
0b8205a8a0 | ||
|
|
57ae509823 | ||
|
|
5d24ce3160 | ||
|
|
d694ea1d38 | ||
|
|
a00936686f | ||
|
|
2feb5edc65 | ||
|
|
b80e55ca44 | ||
|
|
e8ce388109 | ||
|
|
a4c1da25de | ||
|
|
a003e7c081 | ||
|
|
06988b2135 | ||
|
|
7ed7570b17 | ||
|
|
e2d13ba7e4 | ||
|
|
f6c1049474 | ||
|
|
2b24feb604 | ||
|
|
a27411022d | ||
|
|
3827974b58 | ||
|
|
b299cfea8a | ||
|
|
a13e49073c | ||
|
|
2c7e0f17b6 | ||
|
|
418866007e | ||
|
|
5ab418dbeb | ||
|
|
95f61ee9d4 | ||
|
|
ac89c8d226 | ||
|
|
d75d904e43 | ||
|
|
ea4d8d990c | ||
|
|
c93cbb8311 | ||
|
|
c0137e89b9 | ||
|
|
3111ba78ad | ||
|
|
3d3a176940 | ||
|
|
212c6095a2 | ||
|
|
bf6f89a5b5 | ||
|
|
48469ec674 | ||
|
|
c7dfd32b43 | ||
|
|
8861546ad8 | ||
|
|
9c1a893ee3 | ||
|
|
e81c36cf74 | ||
|
|
ed83cbd4f2 | ||
|
|
2a33b9ad87 | ||
|
|
6e85aa13ec | ||
|
|
af05a1725c | ||
|
|
800c4a847f | ||
|
|
bba8c4c0d4 | ||
|
|
68b369e321 | ||
|
|
8d60ac3ffc | ||
|
|
731fb6ebaf | ||
|
|
13e124302f | ||
|
|
59bdd29106 | ||
|
|
659ec4434d | ||
|
|
124829104b | ||
|
|
21cd2940a9 | ||
|
|
da265ca920 | ||
|
|
a1809ad3de | ||
|
|
8699a28be0 | ||
|
|
65db5afe1c | ||
|
|
75d5fa4604 | ||
|
|
e64fad2224 | ||
|
|
eecf32e77a | ||
|
|
3354d919fc | ||
|
|
aca464ca72 | ||
|
|
fe483b1d0d | ||
|
|
ddeade077e | ||
|
|
c4c2ce04e7 | ||
|
|
2cb0bf5d41 | ||
|
|
b86a2c0b47 | ||
|
|
c574eb4984 | ||
|
|
1e49cc4d60 | ||
|
|
e71095960f | ||
|
|
90e099b39f | ||
|
|
334deb985d | ||
|
|
8548a87bd4 | ||
|
|
638d411cd3 | ||
|
|
dd974529cf | ||
|
|
43e079f73e | ||
|
|
6674e36824 | ||
|
|
ae9605f03c | ||
|
|
3c0a209f9f | ||
|
|
1ee1acf8ad | ||
|
|
c4d912a241 | ||
|
|
4323bdce22 | ||
|
|
5daa45436d | ||
|
|
4def6d6ac2 | ||
|
|
d8560b8d5f | ||
|
|
380b836eee | ||
|
|
eec6796cb8 | ||
|
|
25a8597680 | ||
|
|
b8b368310c | ||
|
|
5097cd900e | ||
|
|
bc16e1b497 | ||
|
|
8f821ecad0 | ||
|
|
4519016e67 | ||
|
|
59e2757434 | ||
|
|
73b64c3089 |
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Misc
|
||||
.git
|
||||
tmp
|
||||
@@ -59,7 +73,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
!tests/data
|
||||
!tests/artifacts
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
14
.gitattributes
vendored
14
.gitattributes
vendored
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
*.memmap filter=lfs diff=lfs merge=lfs -text
|
||||
*.stl filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
14
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
14
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "\U0001F41B Bug Report"
|
||||
description: Submit a bug report to help us improve LeRobot
|
||||
body:
|
||||
|
||||
26
.github/workflows/build-docker-images.yml
vendored
26
.github/workflows/build-docker-images.yml
vendored
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
|
||||
name: Builds
|
||||
@@ -8,6 +22,8 @@ on:
|
||||
schedule:
|
||||
- cron: "0 1 * * *"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -25,11 +41,14 @@ jobs:
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
@@ -60,11 +79,14 @@ jobs:
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
@@ -89,9 +111,13 @@ jobs:
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
16
.github/workflows/nightly-tests.yml
vendored
16
.github/workflows/nightly-tests.yml
vendored
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
|
||||
name: Nightly
|
||||
@@ -7,6 +21,8 @@ on:
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
|
||||
permissions: {}
|
||||
|
||||
# env:
|
||||
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
|
||||
jobs:
|
||||
|
||||
74
.github/workflows/quality.yml
vendored
74
.github/workflows/quality.yml
vendored
@@ -1,15 +1,29 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: Quality
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -19,7 +33,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
@@ -30,55 +46,27 @@ jobs:
|
||||
id: get-ruff-version
|
||||
run: |
|
||||
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
|
||||
echo "RUFF_VERSION=${RUFF_VERSION}" >> $GITHUB_ENV
|
||||
echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Install Ruff
|
||||
run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
|
||||
env:
|
||||
RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
|
||||
run: python -m pip install "ruff==${RUFF_VERSION}"
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check
|
||||
run: ruff check --output-format=github
|
||||
|
||||
- name: Ruff format
|
||||
run: ruff format --diff
|
||||
|
||||
|
||||
poetry_check:
|
||||
name: Poetry check
|
||||
typos:
|
||||
name: Typos
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install "poetry<2.0.0"
|
||||
|
||||
- name: Poetry check
|
||||
run: poetry check
|
||||
|
||||
|
||||
poetry_relax:
|
||||
name: Poetry relax
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install "poetry<2.0.0"
|
||||
|
||||
- name: Install poetry-relax
|
||||
run: poetry self add poetry-relax
|
||||
|
||||
- name: Poetry relax
|
||||
id: poetry_relax
|
||||
run: |
|
||||
output=$(poetry relax --check 2>&1)
|
||||
if echo "$output" | grep -q "Proposing updates"; then
|
||||
echo "$output"
|
||||
echo ""
|
||||
echo "Some dependencies have caret '^' version requirement added by poetry by default."
|
||||
echo "Please replace them with '>='. You can do this by hand or use poetry-relax to do this."
|
||||
exit 1
|
||||
else
|
||||
echo "$output"
|
||||
fi
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.29.10
|
||||
|
||||
31
.github/workflows/test-docker-build.yml
vendored
31
.github/workflows/test-docker-build.yml
vendored
@@ -1,15 +1,29 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
|
||||
name: Test Dockerfiles
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
# Run only when DockerFile files are modified
|
||||
- "docker/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -22,6 +36,8 @@ jobs:
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
@@ -30,21 +46,18 @@ jobs:
|
||||
files: docker/**
|
||||
json: "true"
|
||||
|
||||
- name: Run step if only the files listed above change
|
||||
- name: Run step if only the files listed above change # zizmor: ignore[template-injection]
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: set-matrix
|
||||
env:
|
||||
ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
|
||||
run: |
|
||||
echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
|
||||
|
||||
|
||||
build_modified_dockerfiles:
|
||||
name: Build modified Docker images
|
||||
needs: get_changed_files
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
|
||||
if: needs.get_changed_files.outputs.matrix != ''
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -52,9 +65,13 @@ jobs:
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
|
||||
131
.github/workflows/test.yml
vendored
131
.github/workflows/test.yml
vendored
@@ -1,15 +1,28 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "lerobot/**"
|
||||
- "tests/**"
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
- "pyproject.toml"
|
||||
- ".pre-commit-config.yaml"
|
||||
- "Makefile"
|
||||
- ".cache/**"
|
||||
push:
|
||||
@@ -20,10 +33,16 @@ on:
|
||||
- "tests/**"
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
- "pyproject.toml"
|
||||
- ".pre-commit-config.yaml"
|
||||
- "Makefile"
|
||||
- ".cache/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
UV_VERSION: "0.6.0"
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
name: Pytest
|
||||
@@ -34,6 +53,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
# portaudio19-dev is needed to install pyaudio
|
||||
@@ -41,25 +61,19 @@ jobs:
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
cache: "poetry"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
run: |
|
||||
poetry install --all-extras
|
||||
- name: Install lerobot (all extras)
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest tests -v --cov=./lerobot --durations=0 \
|
||||
uv run pytest tests -v --cov=./lerobot --durations=0 \
|
||||
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
@@ -74,66 +88,63 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ffmpeg
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
run: |
|
||||
poetry install --extras "test"
|
||||
- name: Install lerobot
|
||||
run: uv sync --extra "test"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest tests -v --cov=./lerobot --durations=0 \
|
||||
uv run pytest tests -v --cov=./lerobot --durations=0 \
|
||||
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
&& rm -rf tests/outputs outputs
|
||||
|
||||
# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
|
||||
# end-to-end:
|
||||
# name: End-to-end
|
||||
# runs-on: ubuntu-latest
|
||||
# env:
|
||||
# MUJOCO_GL: egl
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# lfs: true # Ensure LFS files are pulled
|
||||
end-to-end:
|
||||
name: End-to-end
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
# - name: Install apt dependencies
|
||||
# # portaudio19-dev is needed to install pyaudio
|
||||
# run: |
|
||||
# sudo apt-get update && \
|
||||
# sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
||||
- name: Install apt dependencies
|
||||
# portaudio19-dev is needed to install pyaudio
|
||||
run: |
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
||||
|
||||
# - name: Install poetry
|
||||
# run: |
|
||||
# pipx install poetry && poetry config virtualenvs.in-project true
|
||||
# echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
|
||||
# - name: Set up Python 3.10
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: "3.10"
|
||||
# cache: "poetry"
|
||||
- name: Install lerobot (all extras)
|
||||
run: |
|
||||
uv venv
|
||||
uv sync --all-extras
|
||||
|
||||
# - name: Install poetry dependencies
|
||||
# run: |
|
||||
# poetry install --all-extras
|
||||
- name: venv
|
||||
run: |
|
||||
echo "PYTHON_PATH=${{ github.workspace }}/.venv/bin/python" >> $GITHUB_ENV
|
||||
|
||||
# - name: Test end-to-end
|
||||
# run: |
|
||||
# make test-end-to-end \
|
||||
# && rm -rf outputs
|
||||
- name: Test end-to-end
|
||||
run: |
|
||||
make test-end-to-end \
|
||||
&& rm -rf outputs
|
||||
|
||||
19
.github/workflows/trufflehog.yml
vendored
19
.github/workflows/trufflehog.yml
vendored
@@ -1,10 +1,23 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
on:
|
||||
push:
|
||||
|
||||
name: Secret Leaks
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
@@ -14,6 +27,8 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
with:
|
||||
|
||||
20
.gitignore
vendored
20
.gitignore
vendored
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Logging
|
||||
logs
|
||||
tmp
|
||||
@@ -49,6 +63,10 @@ share/python-wheels/
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# uv/poetry lock files
|
||||
poetry.lock
|
||||
uv.lock
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
@@ -60,7 +78,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
!tests/data
|
||||
!tests/artifacts
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
@@ -1,7 +1,29 @@
|
||||
exclude: ^(tests/data)
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
exclude: "tests/artifacts/.*\\.safetensors$"
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
repos:
|
||||
##### Meta #####
|
||||
- repo: meta
|
||||
hooks:
|
||||
- id: check-useless-excludes
|
||||
- id: check-hooks-apply
|
||||
|
||||
|
||||
##### Style / Misc. #####
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
@@ -13,26 +35,40 @@ repos:
|
||||
- id: check-toml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.30.2
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.19.0
|
||||
rev: v3.19.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.8.2
|
||||
rev: v0.9.10
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/python-poetry/poetry
|
||||
rev: 1.8.0
|
||||
hooks:
|
||||
- id: poetry-check
|
||||
- id: poetry-lock
|
||||
args:
|
||||
- "--check"
|
||||
- "--no-update"
|
||||
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.21.2
|
||||
rev: v8.24.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.4.1
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.8.3
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: ["-c", "pyproject.toml"]
|
||||
additional_dependencies: ["bandit[toml]"]
|
||||
|
||||
@@ -129,38 +129,71 @@ Follow these steps to start contributing:
|
||||
|
||||
🚨 **Do not** work on the `main` branch.
|
||||
|
||||
4. for development, we use `poetry` instead of just `pip` to easily track our dependencies.
|
||||
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
|
||||
Set up a development environment with conda or miniconda:
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
If you're using `uv`, it can manage python versions so you can instead do:
|
||||
```bash
|
||||
poetry install --sync --extras "dev test"
|
||||
uv venv --python 3.10 && source .venv/bin/activate
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry sync --extras "dev test"
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv sync --extra dev --extra test
|
||||
```
|
||||
|
||||
You can also install the project with all its dependencies (including environments):
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry install --sync --all-extras
|
||||
poetry sync --all-extras
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
> **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they *will* be tested in the CI. In general, we advise you to install everything and test locally before pushing.
|
||||
|
||||
Whichever command you chose to install the project (e.g. `poetry install --sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
|
||||
The equivalent of `pip install some-package`, would just be:
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
using `uv`
|
||||
```bash
|
||||
poetry lock --no-update
|
||||
uv add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry lock
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv lock
|
||||
```
|
||||
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
@@ -195,7 +228,7 @@ Follow these steps to start contributing:
|
||||
git commit
|
||||
```
|
||||
|
||||
Note, if you already commited some changes that have a wrong formatting, you can use:
|
||||
Note, if you already committed some changes that have a wrong formatting, you can use:
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
@@ -258,7 +291,7 @@ sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
Pull artifacts if they're not in [tests/data](tests/data)
|
||||
Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
|
||||
```bash
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
250
Makefile
250
Makefile
@@ -1,11 +1,25 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
.PHONY: tests
|
||||
|
||||
PYTHON_PATH := $(shell which python)
|
||||
|
||||
# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
|
||||
POETRY_CHECK := $(shell command -v poetry)
|
||||
ifneq ($(POETRY_CHECK),)
|
||||
PYTHON_PATH := $(shell poetry run which python)
|
||||
# If uv is installed and a virtual environment exists, use it
|
||||
UV_CHECK := $(shell command -v uv)
|
||||
ifneq ($(UV_CHECK),)
|
||||
PYTHON_PATH := $(shell .venv/bin/python)
|
||||
endif
|
||||
|
||||
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
|
||||
@@ -20,171 +34,109 @@ build-gpu:
|
||||
|
||||
test-end-to-end:
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train-resume
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
|
||||
|
||||
test-act-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
policy.dim_model=64 \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/act/
|
||||
--policy.type=act \
|
||||
--policy.dim_model=64 \
|
||||
--policy.n_action_steps=20 \
|
||||
--policy.chunk_size=20 \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=aloha \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=4 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_freq=2 \
|
||||
--save_checkpoint=true \
|
||||
--log_freq=1 \
|
||||
--wandb.enable=false \
|
||||
--output_dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-train-resume:
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
|
||||
test-act-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/act/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-act-ete-train-amp:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
policy.dim_model=64 \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/act_amp/ \
|
||||
training.image_transforms.enable=true \
|
||||
use_amp=true
|
||||
|
||||
test-act-ete-eval-amp:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/act_amp/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
use_amp=true
|
||||
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=aloha \
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
test-diffusion-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=diffusion \
|
||||
policy.down_dims=\[64,128,256\] \
|
||||
policy.diffusion_step_embed_dim=32 \
|
||||
policy.num_inference_steps=10 \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/diffusion/
|
||||
--policy.type=diffusion \
|
||||
--policy.down_dims='[64,128,256]' \
|
||||
--policy.diffusion_step_embed_dim=32 \
|
||||
--policy.num_inference_steps=10 \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=pusht \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=2 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2 \
|
||||
--log_freq=1 \
|
||||
--wandb.enable=false \
|
||||
--output_dir=tests/outputs/diffusion/
|
||||
|
||||
test-diffusion-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=pusht \
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
test-tdmpc-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=tdmpc \
|
||||
env=xarm \
|
||||
env.task=XarmLift-v0 \
|
||||
dataset_repo_id=lerobot/xarm_lift_medium \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-train-with-online:
|
||||
python lerobot/scripts/train.py \
|
||||
env=pusht \
|
||||
env.gym.obs_type=environment_state_agent_pos \
|
||||
policy=tdmpc_pusht_keypoints \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=10 \
|
||||
device=$(DEVICE) \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=20 \
|
||||
training.save_checkpoint=false \
|
||||
training.save_freq=10 \
|
||||
training.batch_size=2 \
|
||||
training.online_rollout_n_episodes=2 \
|
||||
training.online_rollout_batch_size=2 \
|
||||
training.online_steps_between_rollouts=10 \
|
||||
training.online_buffer_capacity=15 \
|
||||
eval.use_async_envs=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc_online/
|
||||
|
||||
--policy.type=tdmpc \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=xarm \
|
||||
--env.task=XarmLift-v0 \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/xarm_lift_medium \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=2 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2 \
|
||||
--log_freq=1 \
|
||||
--wandb.enable=false \
|
||||
--output_dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-default-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--config lerobot/configs/default.yaml \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-act-pusht-tutorial:
|
||||
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
|
||||
python lerobot/scripts/train.py \
|
||||
policy=created_by_Makefile.yaml \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
device=$(DEVICE) \
|
||||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/act_pusht/
|
||||
rm lerobot/configs/policy/created_by_Makefile.yaml
|
||||
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=xarm \
|
||||
--env.episode_length=5 \
|
||||
--env.task=XarmLift-v0 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
104
README.md
104
README.md
@@ -23,15 +23,24 @@
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">New robot in town: SO-100</a></p>
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">
|
||||
Build Your Own SO-100 Robot!</a></p>
|
||||
</h2>
|
||||
|
||||
<div align="center">
|
||||
<img src="media/so100/leader_follower.webp?raw=true" alt="SO-100 leader and follower arms" title="SO-100 leader and follower arms" width="50%">
|
||||
<p>We just added a new tutorial on how to build a more affordable robot, at the price of $110 per arm!</p>
|
||||
<p>Teach it new skills by showing it a few moves with just a laptop.</p>
|
||||
<p>Then watch your homemade robot act autonomously 🤯</p>
|
||||
<p>Follow the link to the <a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">full tutorial for SO-100</a>.</p>
|
||||
<img src="media/so100/leader_follower.webp?raw=true" alt="SO-100 leader and follower arms" title="SO-100 leader and follower arms" width="50%">
|
||||
|
||||
<p><strong>Meet the SO-100 – Just $110 per arm!</strong></p>
|
||||
<p>Train it in minutes with a few simple moves on your laptop.</p>
|
||||
<p>Then sit back and watch your creation act autonomously! 🤯</p>
|
||||
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">
|
||||
Get the full SO-100 tutorial here.</a></p>
|
||||
|
||||
<p>Want to take it to the next level? Make your SO-100 mobile by building LeKiwi!</p>
|
||||
<p>Check out the <a href="https://github.com/huggingface/lerobot/blob/main/examples/11_use_lekiwi.md">LeKiwi tutorial</a> and bring your robot to life on wheels.</p>
|
||||
|
||||
<img src="media/lekiwi/kiwi.webp?raw=true" alt="LeKiwi mobile robot" title="LeKiwi mobile robot" width="50%">
|
||||
</div>
|
||||
|
||||
<br/>
|
||||
@@ -122,10 +131,7 @@ wandb login
|
||||
├── examples # contains demonstration examples, start here to learn about LeRobot
|
||||
| └── advanced # contains even more examples for those who have mastered the basics
|
||||
├── lerobot
|
||||
| ├── configs # contains hydra yaml files with all options that you can override in the command line
|
||||
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
|
||||
| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, xarm.yaml
|
||||
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
|
||||
| ├── configs # contains config classes with all options that you can override in the command line
|
||||
| ├── common # contains classes and utilities
|
||||
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
|
||||
| | ├── envs # various sim environments: aloha, pusht, xarm
|
||||
@@ -213,7 +219,7 @@ A `LeRobotDataset` is serialised using several widespread file formats for each
|
||||
- videos are stored in mp4 format to save space
|
||||
- metadata are stored in plain json/jsonl files
|
||||
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
@@ -222,87 +228,48 @@ Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrat
|
||||
We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
-p lerobot/diffusion_pusht \
|
||||
eval.n_episodes=10 \
|
||||
eval.batch_size=10
|
||||
--policy.path=lerobot/diffusion_pusht \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
--eval.n_episodes=10 \
|
||||
--policy.use_amp=false \
|
||||
--policy.device=cuda
|
||||
```
|
||||
|
||||
Note: After training your own policy, you can re-evaluate the checkpoints with:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/eval.py -p {OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
python lerobot/scripts/eval.py --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||
Check out [example 3](./examples/3_train_policy.py) that illustrate how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||
|
||||
In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
|
||||
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
env=aloha \
|
||||
env.task=AlohaInsertion-v0 \
|
||||
dataset_repo_id=lerobot/aloha_sim_insertion_human \
|
||||
```
|
||||
|
||||
The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
|
||||
```bash
|
||||
hydra.run.dir=your/new/experiment/dir
|
||||
```
|
||||
|
||||
In the experiment directory there will be a folder called `checkpoints` which will have the following structure:
|
||||
|
||||
```bash
|
||||
checkpoints
|
||||
├── 000250 # checkpoint_dir for training step 250
|
||||
│ ├── pretrained_model # Hugging Face pretrained model dir
|
||||
│ │ ├── config.json # Hugging Face pretrained model config
|
||||
│ │ ├── config.yaml # consolidated Hydra config
|
||||
│ │ ├── model.safetensors # model weights
|
||||
│ │ └── README.md # Hugging Face model card
|
||||
│ └── training_state.pth # optimizer/scheduler/rng state and training step
|
||||
```
|
||||
|
||||
To resume training from a checkpoint, you can add these to the `train.py` python command:
|
||||
```bash
|
||||
hydra.run.dir=your/original/experiment/dir resume=true
|
||||
```
|
||||
|
||||
It will load the pretrained model, optimizer and scheduler states for training. For more information please see our tutorial on training resumption [here](https://github.com/huggingface/lerobot/blob/main/examples/5_resume_training.md).
|
||||
|
||||
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
|
||||
|
||||
```bash
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](./examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
|
||||
|
||||

|
||||
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
#### Reproduce state-of-the-art (SOTA)
|
||||
|
||||
We have organized our configuration files (found under [`lerobot/configs`](./lerobot/configs)) such that they reproduce SOTA results from a given model variant in their respective original works. Simply running:
|
||||
|
||||
We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances.
|
||||
You can reproduce their training by loading the config from their run. Simply running:
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=diffusion env=pusht
|
||||
python lerobot/scripts/train.py --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||
|
||||
Pretrained policies, along with reproduction details, can be found under the "Models" section of https://huggingface.co/lerobot.
|
||||
|
||||
## Contribute
|
||||
|
||||
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Add a new dataset
|
||||
<!-- ### Add a new dataset
|
||||
|
||||
To add a dataset to the hub, you need to login using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
@@ -320,7 +287,7 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||
|
||||
See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
|
||||
|
||||
If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py).
|
||||
If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py). -->
|
||||
|
||||
|
||||
### Add a pretrained policy
|
||||
@@ -330,7 +297,7 @@ Once you have trained a policy you may upload it to the Hugging Face hub using a
|
||||
You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
|
||||
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
|
||||
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
|
||||
- `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
|
||||
- `train_config.json`: A consolidated configuration containing all parameter userd for training. The policy configuration should match `config.json` exactly. Thisis useful for anyone who wants to evaluate your policy or for reproducibility.
|
||||
|
||||
To upload these to the hub, run the following:
|
||||
```bash
|
||||
@@ -417,3 +384,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
@@ -114,7 +114,7 @@ We tried to measure the most impactful parameters for both encoding and decoding
|
||||
|
||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
|
||||
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
|
||||
@@ -32,11 +32,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import (
|
||||
mean_squared_error,
|
||||
peak_signal_noise_ratio,
|
||||
structural_similarity,
|
||||
)
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
@@ -85,9 +81,7 @@ def get_directory_size(directory: Path) -> int:
|
||||
return total_size
|
||||
|
||||
|
||||
def load_original_frames(
|
||||
imgs_dir: Path, timestamps: list[float], fps: int
|
||||
) -> torch.Tensor:
|
||||
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
@@ -100,11 +94,7 @@ def load_original_frames(
|
||||
|
||||
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path,
|
||||
save_dir: Path,
|
||||
frames: torch.Tensor,
|
||||
timestamps: list[float],
|
||||
fps: int,
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
|
||||
return
|
||||
@@ -114,10 +104,7 @@ def save_decoded_frames(
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
|
||||
shutil.copyfile(
|
||||
imgs_dir / f"frame_{idx:06d}.png",
|
||||
save_dir / f"frame_{idx:06d}_original.png",
|
||||
)
|
||||
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
@@ -129,17 +116,11 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [
|
||||
key for key in hf_dataset.features if key.startswith("observation.image")
|
||||
]
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
tqdm(
|
||||
imgs_dataset,
|
||||
desc=f"saving {dataset.repo_id} first episode images",
|
||||
leave=False,
|
||||
)
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
@@ -148,9 +129,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
break
|
||||
|
||||
|
||||
def sample_timestamps(
|
||||
timestamps_mode: str, ep_num_images: int, fps: int
|
||||
) -> list[float]:
|
||||
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
|
||||
# Start at 5 to allow for 2_frames_4_space and 6_frames
|
||||
idx = random.randint(5, ep_num_images - 1)
|
||||
match timestamps_mode:
|
||||
@@ -175,9 +154,7 @@ def decode_video_frames(
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
|
||||
@@ -204,9 +181,7 @@ def benchmark_decoding(
|
||||
}
|
||||
|
||||
with time_benchmark:
|
||||
frames = decode_video_frames(
|
||||
video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend
|
||||
)
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
@@ -215,18 +190,12 @@ def benchmark_decoding(
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
result["mse_values"].append(
|
||||
mean_squared_error(original_frames_np[i], frames_np[i])
|
||||
)
|
||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
||||
result["psnr_values"].append(
|
||||
peak_signal_noise_ratio(
|
||||
original_frames_np[i], frames_np[i], data_range=1.0
|
||||
)
|
||||
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
|
||||
)
|
||||
result["ssim_values"].append(
|
||||
structural_similarity(
|
||||
original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0
|
||||
)
|
||||
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
|
||||
)
|
||||
|
||||
if save_frames and sample == 0:
|
||||
@@ -246,9 +215,7 @@ def benchmark_decoding(
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
|
||||
for future in tqdm(
|
||||
as_completed(futures), total=num_samples, desc="samples", leave=False
|
||||
):
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
load_times_images_ms.append(result["load_time_images_ms"])
|
||||
@@ -308,13 +275,9 @@ def benchmark_encoding_decoding(
|
||||
random.seed(seed)
|
||||
benchmark_table = []
|
||||
for timestamps_mode in tqdm(
|
||||
decoding_cfg["timestamps_modes"],
|
||||
desc="decodings (timestamps_modes)",
|
||||
leave=False,
|
||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||
):
|
||||
for backend in tqdm(
|
||||
decoding_cfg["backends"], desc="decodings (backends)", leave=False
|
||||
):
|
||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||
benchmark_row = benchmark_decoding(
|
||||
imgs_dir,
|
||||
video_path,
|
||||
@@ -392,23 +355,14 @@ def main(
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for key, values in tqdm(
|
||||
encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False
|
||||
):
|
||||
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
|
||||
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path(
|
||||
"_".join(str(value) for value in encoding_cfg.values())
|
||||
)
|
||||
video_path = (
|
||||
output_dir
|
||||
/ "videos"
|
||||
/ args_path
|
||||
/ f"{repo_id.replace('/', '_')}.mp4"
|
||||
)
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
@@ -434,9 +388,7 @@ def main(
|
||||
# Concatenate all results
|
||||
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
|
||||
concatenated_df = pd.concat(df_list, ignore_index=True)
|
||||
concatenated_path = (
|
||||
output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
)
|
||||
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
concatenated_df.to_csv(concatenated_path, header=True, index=False)
|
||||
|
||||
|
||||
|
||||
18
checkport.py
18
checkport.py
@@ -1,18 +0,0 @@
|
||||
import socket
|
||||
|
||||
|
||||
def check_port(host, port):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
s.connect((host, port))
|
||||
print(f"Connection successful to {host}:{port}!")
|
||||
except Exception as e:
|
||||
print(f"Connection failed to {host}:{port}: {e}")
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
host = "127.0.0.1" # or "localhost"
|
||||
port = 51350
|
||||
check_port(host, port)
|
||||
@@ -1,32 +1,29 @@
|
||||
# Configure image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
FROM python:${PYTHON_VERSION}-slim
|
||||
|
||||
# Configure environment variables
|
||||
ARG PYTHON_VERSION
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV MUJOCO_GL="egl"
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
# Install dependencies and set up Python in a single layer
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher libgeos-dev \
|
||||
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
|
||||
&& python -m venv /opt/venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Clone repository and install LeRobot in a single layer
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Execute in bash shell rather than python
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
sed gawk grep curl wget zip unzip \
|
||||
tcpdump sysstat screen tmux \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
speech-dispatcher \
|
||||
speech-dispatcher portaudio19-dev libgeos-dev \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
FROM huggingface/lerobot-gpu:latest
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libvulkan1 vulkan-tools \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[mani-skill]"
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
@@ -1,30 +1,24 @@
|
||||
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||
|
||||
# Configure image
|
||||
# Configure environment variables
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher \
|
||||
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV MUJOCO_GL="egl"
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
# Install dependencies and set up Python in a single layer
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher libgeos-dev \
|
||||
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
|
||||
&& python -m venv /opt/venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Clone repository and install LeRobot in a single layer
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
|
||||
|
||||
@@ -1,296 +0,0 @@
|
||||
# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot
|
||||
|
||||
|
||||
## A. Source the parts
|
||||
|
||||
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## B. Install LeRobot
|
||||
|
||||
On your computer:
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
mkdir -p ~/miniconda3
|
||||
# Linux:
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
# Mac M-series:
|
||||
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
|
||||
# Mac Intel:
|
||||
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
2. Restart shell or `source ~/.bashrc` (*Mac*: `source ~/.bash_profile`) or `source ~/.zshrc` if you're using zshell
|
||||
|
||||
3. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
```
|
||||
|
||||
4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
*For Linux only (not Mac)*: install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
## C. Configure the motors
|
||||
|
||||
### 1. Find the USB ports associated to each arm
|
||||
|
||||
Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm.
|
||||
|
||||
#### a. Run the script to find ports
|
||||
|
||||
Follow Step 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I), which illustrates the use of our scripts below.
|
||||
|
||||
To find the port for each bus servo adapter, run the utility script:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
#### b. Example outputs
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
#### c. Troubleshooting
|
||||
On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
#### d. Update YAML file
|
||||
|
||||
Now that you have the ports, modify the *port* sections in `so100.yaml`
|
||||
|
||||
### 2. Configure the motors
|
||||
|
||||
#### a. Set IDs for all 12 motors
|
||||
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
*Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).*
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
|
||||
|
||||
#### b. Remove the gears of the 6 leader motors
|
||||
|
||||
Follow step 2 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=248). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
#### c. Add motor horn to all 12 motors
|
||||
Follow step 3 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=569). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
||||
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
||||
|
||||
## D. Assemble the arms
|
||||
|
||||
Follow step 4 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=610). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
|
||||
## E. Calibrate
|
||||
|
||||
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
|
||||
|
||||
#### a. Manual calibration of follower arm
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/follower_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/so100/follower_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/so100/follower_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--robot-overrides '~cameras' --arms main_follower
|
||||
```
|
||||
|
||||
#### b. Manual calibration of leader arm
|
||||
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--robot-overrides '~cameras' --arms main_leader
|
||||
```
|
||||
|
||||
## F. Teleoperate
|
||||
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--robot-overrides '~cameras' \
|
||||
--display-cameras 0
|
||||
```
|
||||
|
||||
|
||||
#### a. Teleop with displaying cameras
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml
|
||||
```
|
||||
|
||||
## G. Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/so100_test \
|
||||
--tags so100 tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 2 \
|
||||
--push-to-hub 1
|
||||
```
|
||||
|
||||
## H. Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/so100_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/so100_test
|
||||
```
|
||||
|
||||
## I. Replay an episode
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/so100_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
## J. Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/so100_test \
|
||||
policy=act_so100_real \
|
||||
env=so100_real \
|
||||
hydra.run.dir=outputs/train/act_so100_test \
|
||||
hydra.job.name=act_so100_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/so100_test`.
|
||||
2. We provided the policy with `policy=act_so100_real`. This loads configurations from [`lerobot/configs/policy/act_so100_real.yaml`](../lerobot/configs/policy/act_so100_real.yaml). Importantly, this policy uses 2 cameras as input `laptop`, `phone`.
|
||||
3. We provided an environment as argument with `env=so100_real`. This loads configurations from [`lerobot/configs/env/so100_real.yaml`](../lerobot/configs/env/so100_real.yaml).
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
||||
|
||||
## K. Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/eval_act_so100_test \
|
||||
--tags so100 tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 10 \
|
||||
-p outputs/train/act_so100_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_so100_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_so100_test`).
|
||||
|
||||
## L. More Information
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).
|
||||
@@ -1,94 +0,0 @@
|
||||
# Training a HIL-SERL Reward Classifier with LeRobot
|
||||
|
||||
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
|
||||
|
||||
---
|
||||
|
||||
## Training Script Overview
|
||||
|
||||
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
|
||||
|
||||
1. **Configuration Loading**
|
||||
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
|
||||
|
||||
2. **Dataset Initialization**
|
||||
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
|
||||
|
||||
3. **Classifier Initialization**
|
||||
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
|
||||
- A single probability (binary classification), or
|
||||
- Logits (multi-class classification).
|
||||
|
||||
4. **Training Loop Execution**
|
||||
The script performs:
|
||||
- Forward and backward passes,
|
||||
- Optimization steps,
|
||||
- Periodic logging, evaluation, and checkpoint saving.
|
||||
|
||||
---
|
||||
|
||||
## Configuring with Hydra
|
||||
|
||||
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
|
||||
|
||||
### Config File Setup
|
||||
|
||||
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
|
||||
|
||||
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
|
||||
|
||||
```yaml
|
||||
# Example: lerobot/configs/policy/reward_classifier.yaml
|
||||
dataset_repo_id: "my_dataset_repo_id"
|
||||
## Typical logs and metrics
|
||||
```
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
|
||||
```
|
||||
[2024-11-29 18:26:36,999][root][INFO] -
|
||||
Epoch 5/5
|
||||
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
|
||||
```
|
||||
|
||||
or evaluation log like:
|
||||
|
||||
```
|
||||
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
|
||||
```
|
||||
|
||||
### Metrics Tracking with Weights & Biases (WandB)
|
||||
|
||||
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
|
||||
|
||||
- **Training Metrics**:
|
||||
- `train/accuracy`
|
||||
- `train/loss`
|
||||
- `train/dataloading_s`
|
||||
- **Evaluation Metrics**:
|
||||
- `eval/accuracy`
|
||||
- `eval/loss`
|
||||
- `eval/eval_s`
|
||||
|
||||
#### Additional Features
|
||||
|
||||
You can also log sample predictions during evaluation. Each logged sample will include:
|
||||
|
||||
- The **input image**.
|
||||
- The **predicted label**.
|
||||
- The **true label**.
|
||||
- The **classifier's "confidence" (logits/probability)**.
|
||||
|
||||
These logs can be useful for diagnosing and debugging performance issues.
|
||||
|
||||
|
||||
#### Generate protobuf files
|
||||
|
||||
```bash
|
||||
python -m grpc_tools.protoc \
|
||||
-I lerobot/scripts/server \
|
||||
--python_out=lerobot/scripts/server \
|
||||
--grpc_python_out=lerobot/scripts/server \
|
||||
lerobot/scripts/server/hilserl.proto
|
||||
```
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||
@@ -18,10 +32,7 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
@@ -29,10 +40,7 @@ pprint(lerobot.available_datasets)
|
||||
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [
|
||||
info.id
|
||||
for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])
|
||||
]
|
||||
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
|
||||
pprint(repo_ids)
|
||||
|
||||
# Or simply explore them in your web browser directly at:
|
||||
@@ -47,9 +55,7 @@ ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(
|
||||
f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}"
|
||||
)
|
||||
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
|
||||
@@ -1,6 +1,25 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||
|
||||
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
|
||||
```bash
|
||||
pip install -e ".[pusht]"`
|
||||
```
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -10,7 +29,6 @@ import gymnasium as gym
|
||||
import imageio
|
||||
import numpy
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
@@ -18,27 +36,15 @@ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
output_directory = Path("outputs/eval/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# Select your device
|
||||
device = "cuda"
|
||||
|
||||
# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
|
||||
pretrained_policy_path = "lerobot/diffusion_pusht"
|
||||
# OR a path to a local outputs/train folder.
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
|
||||
# Check if GPU is available
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
print("GPU is available. Device set to:", device)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print(
|
||||
f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU."
|
||||
)
|
||||
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
|
||||
policy.diffusion.num_inference_steps = 10
|
||||
|
||||
policy.to(device)
|
||||
|
||||
# Initialize evaluation environment to render two observation types:
|
||||
# an image of the scene and state/position of the agent. The environment
|
||||
@@ -49,7 +55,17 @@ env = gym.make(
|
||||
max_episode_steps=300,
|
||||
)
|
||||
|
||||
# Reset the policy and environmens to prepare for rollout
|
||||
# We can verify that the shapes of the features expected by the policy match the ones from the observations
|
||||
# produced by the environment
|
||||
print(policy.config.input_features)
|
||||
print(env.observation_space)
|
||||
|
||||
# Similarly, we can check that the actions produced by the policy will match the actions expected by the
|
||||
# environment
|
||||
print(policy.config.output_features)
|
||||
print(env.action_space)
|
||||
|
||||
# Reset the policy and environments to prepare for rollout
|
||||
policy.reset()
|
||||
numpy_observation, info = env.reset(seed=42)
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||
|
||||
Once you have trained a model with this script, you can try to evaluate it on
|
||||
@@ -8,89 +22,99 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
# Create a directory to store the training checkpoint.
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Number of offline training steps (we'll only do offline training for this example.)
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
training_steps = 5000
|
||||
device = torch.device("cuda")
|
||||
log_freq = 250
|
||||
def main():
|
||||
# Create a directory to store the training checkpoint.
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Set up the dataset.
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to supervise the policy.
|
||||
"action": [
|
||||
-0.1,
|
||||
0.0,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.8,
|
||||
0.9,
|
||||
1.0,
|
||||
1.1,
|
||||
1.2,
|
||||
1.3,
|
||||
1.4,
|
||||
],
|
||||
}
|
||||
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
# # Select your device
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Set up the the policy.
|
||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
|
||||
# For this example, no arguments need to be passed because the defaults are set up for PushT.
|
||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
# Number of offline training steps (we'll only do offline training for this example.)
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
training_steps = 5000
|
||||
log_freq = 1
|
||||
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
# When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
|
||||
# creating the policy:
|
||||
# - input/output shapes: to properly size the policy
|
||||
# - dataset stats: for normalization and denormalization of input/outputs
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
# Create dataloader for offline training.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=True,
|
||||
)
|
||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
|
||||
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
|
||||
# Run training loop.
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
loss = output_dict["loss"]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
# We can now instantiate our policy with this config and the dataset stats.
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
|
||||
# which can differ for inputs, outputs and rewards (if there are some).
|
||||
delta_timestamps = {
|
||||
"observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
||||
"observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
||||
"action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
|
||||
}
|
||||
|
||||
# Save a policy checkpoint.
|
||||
policy.save_pretrained(output_directory)
|
||||
# In this case with the standard configuration for Diffusion Policy, it is equivalent to this:
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to supervise the policy.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# We can then instantiate the dataset with these delta_timestamps configuration.
|
||||
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
|
||||
# Then we create our optimizer and dataloader for offline training.
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Run training loop.
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save a policy checkpoint.
|
||||
policy.save_pretrained(output_directory)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,193 +1,223 @@
|
||||
This tutorial will explain the training script, how to use it, and particularly the use of Hydra to configure everything needed for the training run.
|
||||
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
|
||||
> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
||||
|
||||
|
||||
## The training script
|
||||
|
||||
LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
|
||||
|
||||
- Loads a Hydra configuration file for the following steps (more on Hydra in a moment).
|
||||
- Makes a simulation environment.
|
||||
- Makes a dataset corresponding to that simulation environment.
|
||||
- Makes a policy.
|
||||
- Initialize/load a configuration for the following steps using.
|
||||
- Instantiates a dataset.
|
||||
- (Optional) Instantiates a simulation environment corresponding to that dataset.
|
||||
- Instantiates a policy.
|
||||
- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing.
|
||||
|
||||
## Basics of how we use Hydra
|
||||
|
||||
Explaining the ins and outs of [Hydra](https://hydra.cc/docs/intro/) is beyond the scope of this document, but here we'll share the main points you need to know.
|
||||
|
||||
First, `lerobot/configs` has a directory structure like this:
|
||||
|
||||
```
|
||||
.
|
||||
├── default.yaml
|
||||
├── env
|
||||
│ ├── aloha.yaml
|
||||
│ ├── pusht.yaml
|
||||
│ └── xarm.yaml
|
||||
└── policy
|
||||
├── act.yaml
|
||||
├── diffusion.yaml
|
||||
└── tdmpc.yaml
|
||||
```
|
||||
|
||||
**_For brevity, in the rest of this document we'll drop the leading `lerobot/configs` path. So `default.yaml` really refers to `lerobot/configs/default.yaml`._**
|
||||
|
||||
When you run the training script with
|
||||
## Overview of the configuration system
|
||||
|
||||
In the training script, the main function `train` expects a `TrainPipelineConfig` object:
|
||||
```python
|
||||
python lerobot/scripts/train.py
|
||||
# train.py
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
```
|
||||
|
||||
Hydra is set up to read `default.yaml` (via the `@hydra.main` decorator). If you take a look at the `@hydra.main`'s arguments you will see `config_path="../configs", config_name="default"`. At the top of `default.yaml`, is a `defaults` section which looks likes this:
|
||||
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
|
||||
|
||||
```yaml
|
||||
defaults:
|
||||
- _self_
|
||||
- env: pusht
|
||||
- policy: diffusion
|
||||
When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
|
||||
|
||||
Let's have a look at a simplified example. Amongst other attributes, the training config has the following attributes:
|
||||
```python
|
||||
@dataclass
|
||||
class TrainPipelineConfig:
|
||||
dataset: DatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
```
|
||||
in which `DatasetConfig` for example is defined as such:
|
||||
```python
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
repo_id: str
|
||||
episodes: list[int] | None = None
|
||||
video_backend: str = "pyav"
|
||||
```
|
||||
|
||||
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overridden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.
|
||||
This creates a hierarchical relationship where, for example assuming we have a `cfg` instance of `TrainPipelineConfig`, we can access the `repo_id` value with `cfg.dataset.repo_id`.
|
||||
From the command line, we can specify this value with using a very similar syntax `--dataset.repo_id=repo/id`.
|
||||
|
||||
Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.
|
||||
By default, every field takes its default value specified in the dataclass. If a field doesn't have a default value, it needs to be specified either from the command line or from a config file – which path is also given in the command line (more in this below). In the example above, the `dataset` field doesn't have a default value which means it must be specified.
|
||||
|
||||
Thanks to this `defaults` section in `default.yaml`, if you want to train Diffusion Policy with PushT, you really only need to run:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py
|
||||
```
|
||||
|
||||
However, you can be more explicit and launch the exact same Diffusion Policy training on PushT with:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=diffusion env=pusht
|
||||
```
|
||||
|
||||
This way of overriding defaults via the CLI is especially useful when you want to change the policy and/or environment. For instance, you can train ACT on the default Aloha environment with:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=aloha
|
||||
```
|
||||
|
||||
There are two things to note here:
|
||||
- Config overrides are passed as `param_name=param_value`.
|
||||
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.
|
||||
|
||||
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._
|
||||
|
||||
## Overriding configuration parameters in the CLI
|
||||
|
||||
Now let's say that we want to train on a different task in the Aloha environment. If you look in `env/aloha.yaml` you will see something like:
|
||||
|
||||
```yaml
|
||||
# lerobot/configs/env/aloha.yaml
|
||||
env:
|
||||
task: AlohaInsertion-v0
|
||||
```
|
||||
|
||||
And if you look in `policy/act.yaml` you will see something like:
|
||||
|
||||
```yaml
|
||||
# lerobot/configs/policy/act.yaml
|
||||
dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||
```
|
||||
|
||||
But our Aloha environment actually supports a cube transfer task as well. To train for this task, you could manually modify the two yaml configuration files respectively.
|
||||
|
||||
First, we'd need to switch to using the cube transfer task for the ALOHA environment.
|
||||
|
||||
```diff
|
||||
# lerobot/configs/env/aloha.yaml
|
||||
env:
|
||||
- task: AlohaInsertion-v0
|
||||
+ task: AlohaTransferCube-v0
|
||||
```
|
||||
|
||||
Then, we'd also need to switch to using the cube transfer dataset.
|
||||
|
||||
```diff
|
||||
# lerobot/configs/policy/act.yaml
|
||||
-dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||
+dataset_repo_id: lerobot/aloha_sim_transfer_cube_human
|
||||
```
|
||||
|
||||
Then, you'd be able to run:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=aloha
|
||||
```
|
||||
|
||||
and you'd be training and evaluating on the cube transfer task.
|
||||
|
||||
An alternative approach to editing the yaml configuration files, would be to override the defaults via the command line:
|
||||
## Specifying values from the CLI
|
||||
|
||||
Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
env=aloha \
|
||||
env.task=AlohaTransferCube-v0
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy.type=diffusion \
|
||||
--env.type=pusht
|
||||
```
|
||||
|
||||
There's something new here. Notice the `.` delimiter used to traverse the configuration hierarchy. _But be aware that the `defaults` section is an exception. As you saw above, we didn't need to write `defaults.policy=act` in the CLI. `policy=act` was enough._
|
||||
|
||||
Putting all that knowledge together, here's the command that was used to train https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human.
|
||||
Let's break this down:
|
||||
- To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`.
|
||||
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies)
|
||||
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py)
|
||||
|
||||
Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human \
|
||||
device=cuda
|
||||
env=aloha \
|
||||
env.task=AlohaTransferCube-v0 \
|
||||
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
policy=act \
|
||||
training.eval_freq=10000 \
|
||||
training.log_freq=250 \
|
||||
training.offline_steps=100000 \
|
||||
training.save_model=true \
|
||||
training.save_freq=25000 \
|
||||
eval.n_episodes=50 \
|
||||
eval.batch_size=50 \
|
||||
wandb.enable=false \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
--output_dir=outputs/train/act_aloha_insertion
|
||||
```
|
||||
> Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.
|
||||
|
||||
There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human`, which specifies where to save the training output.
|
||||
|
||||
## Using a configuration file not in `lerobot/configs`
|
||||
|
||||
Above we discusses the our training script is set up such that Hydra looks for `default.yaml` in `lerobot/configs`. But, if you have a configuration file elsewhere in your filesystem you may use:
|
||||
|
||||
We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task.
|
||||
Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config-dir PARENT/PATH --config-name FILE_NAME_WITHOUT_EXTENSION
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaTransferCube-v0 \
|
||||
--output_dir=outputs/train/act_aloha_transfer
|
||||
```
|
||||
|
||||
Note: here we use regular syntax for providing CLI arguments to a Python script, not Hydra's `param_name=param_value` syntax.
|
||||
## Loading from a config file
|
||||
|
||||
As a concrete example, this becomes particularly handy when you have a folder with training outputs, and would like to re-run the training. For example, say you previously ran the training script with one of the earlier commands and have `outputs/train/my_experiment/checkpoints/pretrained_model/config.yaml`. This `config.yaml` file will have the full set of configuration parameters within it. To run the training with the same configuration again, do:
|
||||
Now, let's assume that we want to reproduce the run just above. That run has produced a `train_config.json` file in its checkpoints, which serializes the `TrainPipelineConfig` instance it used:
|
||||
```json
|
||||
{
|
||||
"dataset": {
|
||||
"repo_id": "lerobot/aloha_sim_transfer_cube_human",
|
||||
"episodes": null,
|
||||
...
|
||||
},
|
||||
"env": {
|
||||
"type": "aloha",
|
||||
"task": "AlohaTransferCube-v0",
|
||||
"fps": 50,
|
||||
...
|
||||
},
|
||||
"policy": {
|
||||
"type": "act",
|
||||
"n_obs_steps": 1,
|
||||
...
|
||||
},
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
We can then simply load the config values from this file using:
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpoints/last/pretrained_model --config-name config
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
```
|
||||
`--config_path` is also a special argument which allows to initialize the config from a local config file. It can point to a directory that contains `train_config.json` or to the config file itself directly.
|
||||
|
||||
Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
--policy.n_action_steps=80
|
||||
```
|
||||
> Note: While `--output_dir` is not required in general, in this case we need to specify it since it will otherwise take the value from the `train_config.json` (which is `outputs/train/act_aloha_transfer`). In order to prevent accidental deletion of previous run checkpoints, we raise an error if you're trying to write in an existing directory. This is not the case when resuming a run, which is what you'll learn next.
|
||||
|
||||
`--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running:
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)
|
||||
|
||||
|
||||
## Resume training
|
||||
|
||||
Being able to resume a training run is important in case it crashed or aborted for any reason. We'll demonstrate how to that here.
|
||||
|
||||
Let's reuse the command from the previous run and add a few more options:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaTransferCube-v0 \
|
||||
--log_freq=25 \
|
||||
--save_freq=100 \
|
||||
--output_dir=outputs/train/run_resumption
|
||||
```
|
||||
|
||||
Note that you may still use the regular syntax for config parameter overrides (eg: by adding `training.offline_steps=200000`).
|
||||
Here we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can showcase resumption. You should be able to see some logging and have a first checkpoint within 1 minute (depending on hardware). Wait for the first checkpoint to happen, you should see a line that looks like this in your terminal:
|
||||
```
|
||||
INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100
|
||||
```
|
||||
Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true
|
||||
```
|
||||
You should see from the logging that your training picks up from where it left off.
|
||||
|
||||
Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default.
|
||||
You could double the number of steps of the previous run with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
## Outputs of a run
|
||||
In the output directory, there will be a folder called `checkpoints` with the following structure:
|
||||
```bash
|
||||
outputs/train/run_resumption/checkpoints
|
||||
├── 000100 # checkpoint_dir for training step 100
|
||||
│ ├── pretrained_model/
|
||||
│ │ ├── config.json # policy config
|
||||
│ │ ├── model.safetensors # policy weights
|
||||
│ │ └── train_config.json # train config
|
||||
│ └── training_state/
|
||||
│ ├── optimizer_param_groups.json # optimizer param groups
|
||||
│ ├── optimizer_state.safetensors # optimizer state
|
||||
│ ├── rng_state.safetensors # rng states
|
||||
│ ├── scheduler_state.json # scheduler state
|
||||
│ └── training_step.json # training step
|
||||
├── 000200
|
||||
└── last -> 000200 # symlink to the last available checkpoint
|
||||
```
|
||||
|
||||
## Fine-tuning a pre-trained policy
|
||||
|
||||
In addition to the features currently in Draccus, we've added a special `.path` argument for the policy, which allows to load a policy as you would with `PreTrainedPolicy.from_pretrained()`. In that case, `path` can be a local directory that contains a checkpoint or a repo_id pointing to a pretrained policy on the hub.
|
||||
|
||||
For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaInsertion-v0
|
||||
```
|
||||
|
||||
When doing so, keep in mind that the features of the fine-tuning dataset would have to match the input/output features of the pretrained policy.
|
||||
|
||||
## Typical logs and metrics
|
||||
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you configured your run correctly. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
|
||||
```
|
||||
INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774
|
||||
```
|
||||
|
||||
or evaluation log like:
|
||||
|
||||
or evaluation log:
|
||||
```
|
||||
INFO 2024-08-14 13:38:45 ts/train.py:226 step:100 smpl:6K ep:52 epch:0.25 ∑rwrd:20.693 success:0.0% eval_s:120.266
|
||||
```
|
||||
|
||||
These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here are the meaning of some abbreviations:
|
||||
|
||||
- `smpl`: number of samples seen during training.
|
||||
- `ep`: number of episodes seen during training. An episode contains multiple samples in a complete manipulation task.
|
||||
- `epch`: number of time all unique samples are seen (epoch).
|
||||
@@ -200,14 +230,45 @@ These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here
|
||||
|
||||
Some metrics are useful for initial performance profiling. For example, if you find the current GPU utilization is low via the `nvidia-smi` command and `data_s` sometimes is too high, you may need to modify batch size or number of dataloading workers to accelerate dataloading. We also recommend [pytorch profiler](https://github.com/huggingface/lerobot?tab=readme-ov-file#improve-your-code-with-profiling) for detailed performance probing.
|
||||
|
||||
---
|
||||
## In short
|
||||
|
||||
So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
|
||||
We'll summarize here the main use cases to remember from this tutorial.
|
||||
|
||||
#### Train a policy from scratch – CLI
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \ # <- select 'act' policy
|
||||
--env.type=pusht \ # <- select 'pusht' environment
|
||||
--dataset.repo_id=lerobot/pusht # <- train on this dataset
|
||||
```
|
||||
|
||||
Please, head on over to our [advanced tutorial on adapting policy configuration to various environments](./advanced/train_act_pusht/train_act_pusht.md) to learn more.
|
||||
#### Train a policy from scratch - config file + CLI
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=path/to/pretrained_model \ # <- can also be a repo_id
|
||||
--policy.n_action_steps=80 # <- you may still override values
|
||||
```
|
||||
|
||||
Or in the meantime, happy coding! 🤗
|
||||
#### Resume/continue a training run
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=checkpoint/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000 # <- you can change some training parameters
|
||||
```
|
||||
|
||||
#### Fine-tuning
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaInsertion-v0
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Now that you know the basics of how to train a policy, you might want to know how to apply this knowledge to actual robots, or how to record your own datasets and train policies on your specific task?
|
||||
If that's the case, head over to the next tutorial [`7_get_started_with_real_robot.md`](./7_get_started_with_real_robot.md).
|
||||
|
||||
Or in the meantime, happy training! 🤗
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
This tutorial explains how to resume a training run that you've started with the training script. If you don't know how our training script and configuration system works, please read [4_train_policy_with_script.md](./4_train_policy_with_script.md) first.
|
||||
|
||||
## Basic training resumption
|
||||
|
||||
Let's consider the example of training ACT for one of the ALOHA tasks. Here's a command that can achieve that:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=outputs/train/run_resumption \
|
||||
policy=act \
|
||||
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
env=aloha \
|
||||
env.task=AlohaTransferCube-v0 \
|
||||
training.log_freq=25 \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=100
|
||||
```
|
||||
|
||||
Here we're using the default dataset and environment for ACT, and we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can test resumption. You should be able to see some logging and have a first checkpoint within 1 minute. Please interrupt the training after the first checkpoint.
|
||||
|
||||
To resume, all that we have to do is run the training script, providing the run directory, and the resume option:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=outputs/train/run_resumption \
|
||||
resume=true
|
||||
```
|
||||
|
||||
You should see from the logging that your training picks up from where it left off.
|
||||
|
||||
Note that with `resume=true`, the configuration file from the last checkpoint in the training output directory is loaded. So it doesn't matter that we haven't provided all the other configuration parameters from our previous command (although there may be warnings to notify you that your command has a different configuration than than the checkpoint).
|
||||
|
||||
---
|
||||
|
||||
Now you should know how to resume your training run in case it gets interrupted or you want to extend a finished training run.
|
||||
|
||||
Happy coding! 🤗
|
||||
@@ -36,9 +36,14 @@ Using `pip`:
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
Or using `poetry`:
|
||||
Using `poetry`:
|
||||
```bash
|
||||
poetry install --sync --extras "dynamixel"
|
||||
poetry sync --extras "dynamixel"
|
||||
```
|
||||
|
||||
Using `uv`:
|
||||
```bash
|
||||
uv sync --extra "dynamixel"
|
||||
```
|
||||
|
||||
/!\ For Linux only, ffmpeg and opencv requires conda install for now. Run this exact sequence of commands:
|
||||
@@ -54,24 +59,53 @@ Then plug the 12V power supply to the motor bus of the follower arm. It has two
|
||||
|
||||
Finally, connect both arms to your computer via USB. Note that the USB doesn't provide any power, and both arms need to be plugged in with their associated power supply to be detected by your computer.
|
||||
|
||||
*Copy pasting python code*
|
||||
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.
|
||||
|
||||
In the upcoming sections, you'll learn about our classes and functions by running some python code, in an interactive session, or by copy-pasting it in a python file. If this is your first time using the tutorial., we highly recommend going through these steps to get deeper intuition about how things work. Once you're more familiar, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
||||
If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--robot-overrides '~cameras' # do not instantiate the cameras
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
It will automatically:
|
||||
1. Detect and help you correct any motor configuration issues.
|
||||
2. Identify any missing calibrations and initiate the calibration procedure.
|
||||
3. Connect the robot and start teleoperation.
|
||||
1. Identify any missing calibrations and initiate the calibration procedure.
|
||||
2. Connect the robot and start teleoperation.
|
||||
|
||||
### a. Control your motors with DynamixelMotorsBus
|
||||
|
||||
You can use the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) to communicate with the motors connected as a chain to the corresponding USB bus. This class leverages the Python [Dynamixel SDK](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20) to facilitate reading from and writing to the motors.
|
||||
|
||||
**First Configuration of your motors**
|
||||
|
||||
You will need to unplug each motor in turn and run a command the identify the motor. The motor will save its own identification, so you only need to do this once. Start by unplugging all of the motors.
|
||||
|
||||
Do the Leader arm first, as all of its motors are of the same type. Plug in your first motor on your leader arm and run this script to set its ID to 1.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand dynamixel \
|
||||
--model xl330-m288 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Then unplug your first motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand dynamixel \
|
||||
--model xl330-m288 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6.
|
||||
|
||||
The process for the follower arm is almost the same, but the follower arm has two types of motors. For the first two motors, make sure you set the model to `xl430-w250`. _Important: configuring follower motors requires plugging and unplugging power. Make sure you use the 5V power for the XL330s and the 12V power for the XL430s!_
|
||||
|
||||
After all of your motors are configured properly, you're ready to plug them all together in a daisy-chain as shown in the original video.
|
||||
|
||||
**Instantiate the DynamixelMotorsBus**
|
||||
|
||||
To begin, create two instances of the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py), one for each arm, using their corresponding USB ports (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`).
|
||||
@@ -105,10 +139,10 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running this command with your ports:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
sudo chmod 666 /dev/tty.usbmodem575E0032081
|
||||
sudo chmod 666 /dev/tty.usbmodem575E0031751
|
||||
```
|
||||
|
||||
*Listing and Configuring Motors*
|
||||
@@ -117,13 +151,11 @@ Next, you'll need to list the motors for each arm, including their name, index,
|
||||
|
||||
To assign indices to the motors, run this code in an interactive Python session. Replace the `port` values with the ones you identified earlier:
|
||||
```python
|
||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
leader_port = "/dev/tty.usbmodem575E0031751"
|
||||
follower_port = "/dev/tty.usbmodem575E0032081"
|
||||
|
||||
leader_arm = DynamixelMotorsBus(
|
||||
port=leader_port,
|
||||
leader_config = DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl330-m077"),
|
||||
@@ -135,8 +167,8 @@ leader_arm = DynamixelMotorsBus(
|
||||
},
|
||||
)
|
||||
|
||||
follower_arm = DynamixelMotorsBus(
|
||||
port=follower_port,
|
||||
follower_config = DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
@@ -147,45 +179,57 @@ follower_arm = DynamixelMotorsBus(
|
||||
"gripper": (6, "xl330-m288"),
|
||||
},
|
||||
)
|
||||
|
||||
leader_arm = DynamixelMotorsBus(leader_config)
|
||||
follower_arm = DynamixelMotorsBus(follower_config)
|
||||
```
|
||||
|
||||
*Updating the YAML Configuration File*
|
||||
IMPORTANTLY: Now that you have your ports, update [`KochRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("koch")
|
||||
@dataclass
|
||||
class KochRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/koch"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
Next, update the port values in the YAML configuration file for the Koch robot at [`lerobot/configs/robot/koch.yaml`](../lerobot/configs/robot/koch.yaml) with the ports you've identified:
|
||||
```yaml
|
||||
[...]
|
||||
robot_type: koch
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751 # <- Update
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
shoulder_lift: [2, "xl330-m077"]
|
||||
elbow_flex: [3, "xl330-m077"]
|
||||
wrist_flex: [4, "xl330-m077"]
|
||||
wrist_roll: [5, "xl330-m077"]
|
||||
gripper: [6, "xl330-m077"]
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081 # <- Update
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
shoulder_lift: [2, "xl430-w250"]
|
||||
elbow_flex: [3, "xl330-m288"]
|
||||
wrist_flex: [4, "xl330-m288"]
|
||||
wrist_roll: [5, "xl330-m288"]
|
||||
gripper: [6, "xl330-m288"]
|
||||
[...]
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0085511", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl330-m077"],
|
||||
"shoulder_lift": [2, "xl330-m077"],
|
||||
"elbow_flex": [3, "xl330-m077"],
|
||||
"wrist_flex": [4, "xl330-m077"],
|
||||
"wrist_roll": [5, "xl330-m077"],
|
||||
"gripper": [6, "xl330-m077"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
"elbow_flex": [3, "xl330-m288"],
|
||||
"wrist_flex": [4, "xl330-m288"],
|
||||
"wrist_roll": [5, "xl330-m288"],
|
||||
"gripper": [6, "xl330-m288"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
Don't forget to set `robot_type: aloha` if you follow this tutorial with [Aloha bimanual robot](aloha-2.github.io) instead of Koch v1.1
|
||||
|
||||
This configuration file is used to instantiate your robot across all scripts. We'll cover how this works later on.
|
||||
|
||||
**Connect and Configure your Motors**
|
||||
|
||||
Before you can start using your motors, you'll need to configure them to ensure proper communication. When you first connect the motors, the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) automatically detects any mismatch between the current motor indices (factory set to `1`) and the specified indices (e.g., `1, 2, 3, 4, 5, 6`). This triggers a configuration procedure that requires you to unplug the power cord and motors, then reconnect each motor sequentially, starting from the one closest to the bus.
|
||||
@@ -312,27 +356,27 @@ Alternatively, you can unplug the power cord, which will automatically disable t
|
||||
|
||||
**Instantiate the ManipulatorRobot**
|
||||
|
||||
Before you can teleoperate your robot, you need to instantiate the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) using the previously defined `leader_arm` and `follower_arm`.
|
||||
Before you can teleoperate your robot, you need to instantiate the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) using the previously defined `leader_config` and `follower_config`.
|
||||
|
||||
For the Koch v1.1 robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_arm}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_arm, "right": right_leader_arm},`. Same thing for the follower arms.
|
||||
For the Koch v1.1 robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_config}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_config, "right": right_leader_config},`. Same thing for the follower arms.
|
||||
|
||||
You also need to provide a path to a calibration directory, such as `calibration_dir=".cache/calibration/koch"`. More on this in the next section.
|
||||
|
||||
Run the following code to instantiate your manipulator robot:
|
||||
```python
|
||||
from lerobot.common.robot_devices.robots.configs import KochRobotConfig
|
||||
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||
|
||||
robot = ManipulatorRobot(
|
||||
robot_type="koch",
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
robot_config = KochRobotConfig(
|
||||
leader_arms={"main": leader_config},
|
||||
follower_arms={"main": follower_config},
|
||||
cameras={}, # We don't use any camera for now
|
||||
)
|
||||
robot = ManipulatorRobot(robot_config)
|
||||
```
|
||||
|
||||
The `robot_type="koch"` is used to set the associated settings and calibration process. For instance, we activate the torque of the gripper of the leader Koch v1.1 arm and position it at a 40 degree angle to use it as a trigger.
|
||||
The `KochRobotConfig` is used to set the associated settings and calibration process. For instance, we activate the torque of the gripper of the leader Koch v1.1 arm and position it at a 40 degree angle to use it as a trigger.
|
||||
|
||||
For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `robot_type="aloha"` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected. If you need to run manual calibration, simply update `calibration_dir` to `.cache/calibration/aloha`.
|
||||
For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `AlohaRobotConfig` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected for Aloha.
|
||||
|
||||
**Calibrate and Connect the ManipulatorRobot**
|
||||
|
||||
@@ -342,19 +386,19 @@ When you connect your robot for the first time, the [`ManipulatorRobot`](../lero
|
||||
|
||||
Here are the positions you'll move the follower arm to:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| <img src="../media/koch/follower_zero.webp?raw=true" alt="Koch v1.1 follower arm zero position" title="Koch v1.1 follower arm zero position" style="width:100%;"> | <img src="../media/koch/follower_rotated.webp?raw=true" alt="Koch v1.1 follower arm rotated position" title="Koch v1.1 follower arm rotated position" style="width:100%;"> | <img src="../media/koch/follower_rest.webp?raw=true" alt="Koch v1.1 follower arm rest position" title="Koch v1.1 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
And here are the corresponding positions for the leader arm:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| <img src="../media/koch/leader_zero.webp?raw=true" alt="Koch v1.1 leader arm zero position" title="Koch v1.1 leader arm zero position" style="width:100%;"> | <img src="../media/koch/leader_rotated.webp?raw=true" alt="Koch v1.1 leader arm rotated position" title="Koch v1.1 leader arm rotated position" style="width:100%;"> | <img src="../media/koch/leader_rest.webp?raw=true" alt="Koch v1.1 leader arm rest position" title="Koch v1.1 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
||||
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to mesure if the values changed negatively or positively.
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
|
||||
|
||||
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
|
||||
|
||||
@@ -579,9 +623,11 @@ Note: Some cameras may take a few seconds to warm up, and the first frame might
|
||||
|
||||
Finally, run this code to instantiate and connectyour camera:
|
||||
```python
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
camera = OpenCVCamera(camera_index=0)
|
||||
config = OpenCVCameraConfig(camera_index=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
|
||||
@@ -603,7 +649,7 @@ uint8
|
||||
|
||||
With certain camera, you can also specify additional parameters like frame rate, resolution, and color mode during instantiation. For instance:
|
||||
```python
|
||||
camera = OpenCVCamera(camera_index=0, fps=30, width=640, height=480)
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=30, width=640, height=480)
|
||||
```
|
||||
|
||||
If the provided arguments are not compatible with the camera, an exception will be raised.
|
||||
@@ -617,18 +663,20 @@ camera.disconnect()
|
||||
|
||||
**Instantiate your robot with cameras**
|
||||
|
||||
Additionaly, you can set up your robot to work with your cameras.
|
||||
Additionally, you can set up your robot to work with your cameras.
|
||||
|
||||
Modify the following Python code with the appropriate camera names and configurations:
|
||||
```python
|
||||
robot = ManipulatorRobot(
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
cameras={
|
||||
"laptop": OpenCVCamera(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCamera(1, fps=30, width=640, height=480),
|
||||
},
|
||||
KochRobotConfig(
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
cameras={
|
||||
"laptop": OpenCVCameraConfig(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
|
||||
},
|
||||
)
|
||||
)
|
||||
robot.connect()
|
||||
```
|
||||
@@ -652,39 +700,20 @@ torch.Size([3, 480, 640])
|
||||
255
|
||||
```
|
||||
|
||||
Also, update the following lines of the yaml file for Koch robot [`lerobot/configs/robot/koch.yaml`](../lerobot/configs/robot/koch.yaml) with the names and configurations of your cameras:
|
||||
```yaml
|
||||
[...]
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
```
|
||||
### d. Use `control_robot.py` and our `teleoperate` function
|
||||
|
||||
This file is used to instantiate your robot in all our scripts. We will explain how this works in the next section.
|
||||
|
||||
### d. Use `koch.yaml` and our `teleoperate` function
|
||||
|
||||
Instead of manually running the python code in a terminal window, you can use [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to instantiate your robot by providing the path to the robot yaml file (e.g. [`lerobot/configs/robot/koch.yaml`](../lerobot/configs/robot/koch.yaml)) and control your robot with various modes as explained next.
|
||||
Instead of manually running the python code in a terminal window, you can use [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to instantiate your robot by providing the robot configurations via command line and control your robot with various modes as explained next.
|
||||
|
||||
Try running this code to teleoperate your robot (if you dont have a camera, keep reading):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtRfoll: 0.19 (5239.0hz)
|
||||
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtWfoll: 0.19 (5239.0hz)
|
||||
```
|
||||
|
||||
It contains
|
||||
@@ -694,21 +723,12 @@ It contains
|
||||
- `dtRlead: 4.93 (203.0hz)` which is the number of milliseconds it took to read the position of the leader arm using `leader_arm.read("Present_Position")`.
|
||||
- `dtWfoll: 0.22 (4446.9hz)` which is the number of milliseconds it took to set a new goal position for the follower arm using `follower_arm.write("Goal_position", leader_pos)` ; note that writing is done asynchronously so it takes less time than reading.
|
||||
|
||||
Note: you can override any entry in the yaml file using `--robot-overrides` and the [hydra.cc](https://hydra.cc/docs/advanced/override_grammar/basic) syntax. If needed, you can override the ports like this:
|
||||
Importantly: If you don't have any camera, you can remove them dynamically with this [draccus](https://github.com/dlwh/draccus) syntax `--robot.cameras='{}'`:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--robot-overrides \
|
||||
leader_arms.main.port=/dev/tty.usbmodem575E0031751 \
|
||||
follower_arms.main.port=/dev/tty.usbmodem575E0032081
|
||||
```
|
||||
|
||||
Importantly: If you don't have any camera, you can remove them dynamically with this [hydra.cc](https://hydra.cc/docs/advanced/override_grammar/basic) syntax `'~cameras'`:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--robot-overrides \
|
||||
'~cameras'
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
We advise to create a new yaml file when the command becomes too long.
|
||||
@@ -744,23 +764,23 @@ for _ in range(record_time_s * fps):
|
||||
|
||||
Importantly, many utilities are still missing. For instance, if you have cameras, you will need to save the images on disk to not go out of RAM, and to do so in threads to not slow down communication with your robot. Also, you will need to store your data in a format optimized for training and web sharing like [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py). More on this in the next section.
|
||||
|
||||
### a. Use `koch.yaml` and the `record` function
|
||||
### a. Use the `record` function
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to achieve efficient data recording. It encompasses many recording utilities:
|
||||
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of recording.
|
||||
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording.
|
||||
2. Video streams from cameras are displayed in window so that you can verify them.
|
||||
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--push-to-hub 0` is provided).
|
||||
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again. You can also use `--force-override 1` to start recording from scratch.
|
||||
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided).
|
||||
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch.
|
||||
5. Set the flow of data recording using command line arguments:
|
||||
- `--warmup-time-s` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
||||
- `--episode-time-s` defines the number of seconds for data recording for each episode (60 seconds by default).
|
||||
- `--reset-time-s` defines the number of seconds for resetting the environment after each episode (60 seconds by default).
|
||||
- `--num-episodes` defines the number of episodes to record (50 by default).
|
||||
- `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
||||
- `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
|
||||
- `--control.reset_time_s=60` defines the number of seconds for resetting the environment after each episode (60 seconds by default).
|
||||
- `--control.num_episodes=50` defines the number of episodes to record (50 by default).
|
||||
6. Control the flow during data recording using keyboard keys:
|
||||
- Press right arrow `->` at any time during episode recording to early stop and go to resetting. Same during resetting, to early stop and to go to the next episode recording.
|
||||
- Press left arrow `<-` at any time during episode recording or resetting to early stop, cancel the current episode, and re-record it.
|
||||
- Press escape `ESC` at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.
|
||||
7. Similarly to `teleoperate`, you can also use `--robot-path` and `--robot-overrides` to specify your robots.
|
||||
7. Similarly to `teleoperate`, you can also use the command line to override anything.
|
||||
|
||||
Before trying `record`, if you want to push your dataset to the hub, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
@@ -771,27 +791,29 @@ Also, store your Hugging Face repository name in a variable (e.g. `cadene` or `l
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
If you don't want to push to hub, use `--push-to-hub 0`.
|
||||
If you don't want to push to hub, use `--control.push_to_hub=false`.
|
||||
|
||||
Now run this to record 2 episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/koch_test \
|
||||
--tags tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 30 \
|
||||
--num-episodes 2
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=record \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/koch_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
|
||||
This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||
|
||||
You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
Remember to add `--robot-overrides '~cameras'` if you don't have any cameras and you still use the default `koch.yaml` configuration.
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 15:02:58 ol_robot.py:219 dt:33.34 (30.0hz) dtRlead: 5.06 (197.5hz) dtWfoll: 0.25 (3963.7hz) dtRfoll: 6.22 (160.7hz) dtRlaptop: 32.57 (30.7hz) dtRphone: 33.84 (29.5hz)
|
||||
@@ -803,8 +825,8 @@ It contains:
|
||||
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
|
||||
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm ; writing is asynchronous so it takes less time than reading.
|
||||
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
|
||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchrously.
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchrously.
|
||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously.
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
|
||||
|
||||
Troubleshooting:
|
||||
- On Linux, if you encounter a hanging issue when using cameras, uninstall opencv and re-install it with conda:
|
||||
@@ -824,7 +846,7 @@ At the end of data recording, your dataset will be uploaded on your Hugging Face
|
||||
echo https://huggingface.co/datasets/${HF_USER}/koch_test
|
||||
```
|
||||
|
||||
### b. Advices for recording dataset
|
||||
### b. Advice for recording dataset
|
||||
|
||||
Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings.
|
||||
|
||||
@@ -842,6 +864,8 @@ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/koch_test
|
||||
```
|
||||
|
||||
Note: You might need to add `--local-files-only 1` if your dataset was not uploaded to hugging face hub.
|
||||
|
||||
This will launch a local web server that looks like this:
|
||||
<div style="text-align:center;">
|
||||
<img src="../media/tutorial/visualize_dataset_html.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="100%">
|
||||
@@ -853,11 +877,12 @@ A useful feature of [`lerobot/scripts/control_robot.py`](../lerobot/scripts/cont
|
||||
|
||||
To replay the first episode of the dataset you just recorded, run the following command:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/koch_test \
|
||||
--episode 0
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/koch_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
|
||||
@@ -869,51 +894,18 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/koch_test \
|
||||
policy=act_koch_real \
|
||||
env=koch_real \
|
||||
hydra.run.dir=outputs/train/act_koch_test \
|
||||
hydra.job.name=act_koch_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
--dataset.repo_id=${HF_USER}/koch_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_koch_test \
|
||||
--job_name=act_koch_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/koch_test`.
|
||||
2. We provided the policy with `policy=act_koch_real`. This loads configurations from [`lerobot/configs/policy/act_koch_real.yaml`](../lerobot/configs/policy/act_koch_real.yaml). Importantly, this policy uses 2 cameras as input `laptop` and `phone`. If your dataset has different cameras, update the yaml file to account for it in the following parts:
|
||||
```yaml
|
||||
...
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
...
|
||||
input_shapes:
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
...
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
...
|
||||
```
|
||||
3. We provided an environment as argument with `env=koch_real`. This loads configurations from [`lerobot/configs/env/koch_real.yaml`](../lerobot/configs/env/koch_real.yaml). It looks like
|
||||
```yaml
|
||||
fps: 30
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
```
|
||||
It should match your dataset (e.g. `fps: 30`) and your robot (e.g. `state_dim: 6` and `action_dim: 6`). We are still working on simplifying this in future versions of `lerobot`.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||
@@ -978,34 +970,36 @@ for _ in range(inference_time_s * fps):
|
||||
busy_wait(1 / fps - dt_s)
|
||||
```
|
||||
|
||||
### a. Use `koch.yaml` and our `record` function
|
||||
### a. Use our `record` function
|
||||
|
||||
Ideally, when controlling your robot with your neural network, you would want to record evaluation episodes and to be able to visualize them later on, or even train on them like in Reinforcement Learning. This pretty much corresponds to recording a new dataset but with a neural network providing the actions instead of teleoperation.
|
||||
|
||||
To this end, you can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/eval_koch_test \
|
||||
--tags tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 30 \
|
||||
--num-episodes 10 \
|
||||
-p outputs/train/act_koch_test/checkpoints/last/pretrained_model
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/eval_act_koch_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_koch_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_koch_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_koch_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_koch_test`).
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_koch_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_koch_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_koch_test`).
|
||||
|
||||
### b. Visualize evaluation afterwards
|
||||
|
||||
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id ${HF_USER}/eval_koch_test
|
||||
--repo-id ${HF_USER}/eval_act_koch_test
|
||||
```
|
||||
|
||||
## 6. Next step
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
||||
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
||||
@@ -34,14 +48,10 @@ transforms = v2.Compose(
|
||||
)
|
||||
|
||||
# Create another LeRobotDataset with the defined transformations
|
||||
transformed_dataset = LeRobotDataset(
|
||||
dataset_repo_id, episodes=[0], image_transforms=transforms
|
||||
)
|
||||
transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)
|
||||
|
||||
# Get a frame from the transformed dataset
|
||||
transformed_frame = transformed_dataset[first_idx][
|
||||
transformed_dataset.meta.camera_keys[0]
|
||||
]
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
|
||||
|
||||
# Create a directory to store output images
|
||||
output_dir = Path("outputs/image_transforms")
|
||||
@@ -1,87 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Change the seed to match what PushT eval uses
|
||||
# (to avoid evaluating on seeds used for generating the training data).
|
||||
seed: 100000
|
||||
# Change the dataset repository to the PushT one.
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
override_dataset_stats:
|
||||
observation.image:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
observation.image: [3, 96, 96]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.image: mean_std
|
||||
# Use min_max normalization just because it's more standard.
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
# Use min_max normalization just because it's more standard.
|
||||
action: min_max
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -1,70 +0,0 @@
|
||||
In this tutorial we will learn how to adapt a policy configuration to be compatible with a new environment and dataset. As a concrete example, we will adapt the default configuration for ACT to be compatible with the PushT environment and dataset.
|
||||
|
||||
If you haven't already read our tutorial on the [training script and configuration tooling](../4_train_policy_with_script.md) please do so prior to tackling this tutorial.
|
||||
|
||||
Let's get started!
|
||||
|
||||
Suppose we want to train ACT for PushT. Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
|
||||
```
|
||||
|
||||
We need to adapt the parameters of the ACT policy configuration to the PushT environment. The most important ones are the image keys.
|
||||
|
||||
ALOHA's datasets and environments typically use a variable number of cameras. In `lerobot/configs/policy/act.yaml` you may notice two relevant sections. Here we show you the minimal diff needed to adjust to PushT:
|
||||
|
||||
```diff
|
||||
override_dataset_stats:
|
||||
- observation.images.top:
|
||||
+ observation.image:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
policy:
|
||||
input_shapes:
|
||||
- observation.images.top: [3, 480, 640]
|
||||
+ observation.image: [3, 96, 96]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
input_normalization_modes:
|
||||
- observation.images.top: mean_std
|
||||
+ observation.image: mean_std
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
```
|
||||
|
||||
Here we've accounted for the following:
|
||||
- PushT uses "observation.image" for its image key.
|
||||
- PushT provides smaller images.
|
||||
|
||||
_Side note: technically we could override these via the CLI, but with many changes it gets a bit messy, and we also have a bit of a challenge in that we're using `.` in our observation keys which is treated by Hydra as a hierarchical separator_.
|
||||
|
||||
For your convenience, we provide [`act_pusht.yaml`](./act_pusht.yaml) in this directory. It contains the diff above, plus some other (optional) ones that are explained within. Please copy it into `lerobot/configs/policy` with:
|
||||
|
||||
```bash
|
||||
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/act_pusht.yaml
|
||||
```
|
||||
|
||||
(remember from a [previous tutorial](../4_train_policy_with_script.md) that Hydra will look in the `lerobot/configs` directory). Now try running the following.
|
||||
|
||||
<!-- Note to contributor: are you changing this command? Note that it's tested in `Makefile`, so change it there too! -->
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act_pusht env=pusht
|
||||
```
|
||||
|
||||
Notice that this is much the same as the command that failed at the start of the tutorial, only:
|
||||
- Now we are using `policy=act_pusht` to point to our new configuration file.
|
||||
- We can drop `dataset_repo_id=lerobot/pusht` as the change is incorporated in our new configuration file.
|
||||
|
||||
Hurrah! You're now training ACT for the PushT environment.
|
||||
|
||||
---
|
||||
|
||||
The bottom line of this tutorial is that when training policies for different environments and datasets you will need to understand what parts of the policy configuration are specific to those and make changes accordingly.
|
||||
|
||||
Happy coding! 🤗
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
|
||||
|
||||
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
|
||||
@@ -9,100 +23,82 @@ on the target environment, whether that be in simulation or the real world.
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
def main():
|
||||
device = torch.device("cuda")
|
||||
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = "lerobot/diffusion_pusht"
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
# Set up the dataset.
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to calculate the loss.
|
||||
"action": [
|
||||
-0.1,
|
||||
0.0,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.8,
|
||||
0.9,
|
||||
1.0,
|
||||
1.1,
|
||||
1.2,
|
||||
1.3,
|
||||
1.4,
|
||||
],
|
||||
}
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load dataset metadata
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
# - Calculate train and val episodes
|
||||
total_episodes = dataset_metadata.total_episodes
|
||||
episodes = list(range(dataset_metadata.total_episodes))
|
||||
num_train_episodes = math.floor(total_episodes * 90 / 100)
|
||||
train_episodes = episodes[:num_train_episodes]
|
||||
val_episodes = episodes[num_train_episodes:]
|
||||
print(f"Number of episodes in full dataset: {total_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
|
||||
# - Load train an val datasets
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
|
||||
)
|
||||
val_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps
|
||||
)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
# Set up the dataset.
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to calculate the loss.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# Create dataloader for evaluation.
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=False,
|
||||
)
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load dataset metadata
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
# - Calculate train and val episodes
|
||||
total_episodes = dataset_metadata.total_episodes
|
||||
episodes = list(range(dataset_metadata.total_episodes))
|
||||
num_train_episodes = math.floor(total_episodes * 90 / 100)
|
||||
train_episodes = episodes[:num_train_episodes]
|
||||
val_episodes = episodes[num_train_episodes:]
|
||||
print(f"Number of episodes in full dataset: {total_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
|
||||
# - Load train an val datasets
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
|
||||
)
|
||||
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
|
||||
# Run validation loop.
|
||||
loss_cumsum = 0
|
||||
n_examples_evaluated = 0
|
||||
for batch in val_dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
# Create dataloader for evaluation.
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
loss_cumsum += output_dict["loss"].item()
|
||||
n_examples_evaluated += batch["index"].shape[0]
|
||||
# Run validation loop.
|
||||
loss_cumsum = 0
|
||||
n_examples_evaluated = 0
|
||||
for batch in val_dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
loss, _ = policy.forward(batch)
|
||||
|
||||
# Calculate the average loss over the validation set.
|
||||
average_loss = loss_cumsum / n_examples_evaluated
|
||||
loss_cumsum += loss.item()
|
||||
n_examples_evaluated += batch["index"].shape[0]
|
||||
|
||||
print(f"Average loss on validation set: {average_loss:.4f}")
|
||||
# Calculate the average loss over the validation set.
|
||||
average_loss = loss_cumsum / n_examples_evaluated
|
||||
|
||||
print(f"Average loss on validation set: {average_loss:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,10 +1,25 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
|
||||
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
||||
@@ -44,7 +59,7 @@ PUSHT_FEATURES = {
|
||||
"dtype": None,
|
||||
"shape": (3, 96, 96),
|
||||
"names": [
|
||||
"channel",
|
||||
"channels",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
@@ -69,9 +84,7 @@ def load_raw_dataset(zarr_path: Path):
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
|
||||
)
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
|
||||
@@ -83,9 +96,7 @@ def calculate_coverage(zarr_data):
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
|
||||
)
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
block_pos = zarr_data["state"][:, 2:4]
|
||||
@@ -93,9 +104,9 @@ def calculate_coverage(zarr_data):
|
||||
|
||||
num_frames = len(block_pos)
|
||||
|
||||
coverage = np.zeros((num_frames,))
|
||||
coverage = np.zeros((num_frames,), dtype=np.float32)
|
||||
# 8 keypoints with 2 coords each
|
||||
keypoints = np.zeros((num_frames, 16))
|
||||
keypoints = np.zeros((num_frames, 16), dtype=np.float32)
|
||||
|
||||
# Set x, y, theta (in radians)
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4])
|
||||
@@ -115,15 +126,13 @@ def calculate_coverage(zarr_data):
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body, block_shapes = PushTEnv.add_tee(
|
||||
space, block_pos[i].tolist(), block_angle[i].item()
|
||||
)
|
||||
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage[i] = intersection_area / goal_area
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
keypoints[i] = PushTEnv.get_keypoints(block_shapes).flatten()
|
||||
|
||||
return coverage, keypoints
|
||||
|
||||
@@ -140,8 +149,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
if mode not in ["video", "image", "keypoints"]:
|
||||
raise ValueError(mode)
|
||||
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
if (HF_LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
|
||||
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
||||
@@ -154,6 +163,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
action = zarr_data["action"][:]
|
||||
image = zarr_data["img"] # (b, h, w, c)
|
||||
|
||||
if image.dtype == np.float32 and image.max() == np.float32(255):
|
||||
# HACK: images are loaded as float32 but they actually encode uint8 data
|
||||
image = image.astype(np.uint8)
|
||||
|
||||
episode_data_index = {
|
||||
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
|
||||
"to": zarr_data.meta["episode_ends"],
|
||||
@@ -181,28 +194,30 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
|
||||
for frame_idx in range(num_frames):
|
||||
i = from_idx + frame_idx
|
||||
idx = i + (frame_idx < num_frames - 1)
|
||||
frame = {
|
||||
"action": torch.from_numpy(action[i]),
|
||||
"action": action[i],
|
||||
# Shift reward and success by +1 until the last item of the episode
|
||||
"next.reward": reward[i + (frame_idx < num_frames - 1)],
|
||||
"next.success": success[i + (frame_idx < num_frames - 1)],
|
||||
"next.reward": reward[idx : idx + 1],
|
||||
"next.success": success[idx : idx + 1],
|
||||
"task": PUSHT_TASK,
|
||||
}
|
||||
|
||||
frame["observation.state"] = torch.from_numpy(agent_pos[i])
|
||||
frame["observation.state"] = agent_pos[i]
|
||||
|
||||
if mode == "keypoints":
|
||||
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
|
||||
frame["observation.environment_state"] = keypoints[i]
|
||||
else:
|
||||
frame["observation.image"] = torch.from_numpy(image[i])
|
||||
frame["observation.image"] = image[i]
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode(task=PUSHT_TASK)
|
||||
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -224,5 +239,5 @@ if __name__ == "__main__":
|
||||
main(raw_dir, repo_id=repo_id, mode=mode)
|
||||
|
||||
# Uncomment if you want to load the local dataset and explore it
|
||||
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
|
||||
# dataset = LeRobotDataset(repo_id=repo_id)
|
||||
# breakpoint()
|
||||
|
||||
@@ -58,7 +58,6 @@ available_tasks_per_env = {
|
||||
],
|
||||
"pusht": ["PushT-v0"],
|
||||
"xarm": ["XarmLift-v0"],
|
||||
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||
}
|
||||
available_envs = list(available_tasks_per_env.keys())
|
||||
|
||||
@@ -86,23 +85,6 @@ available_datasets_per_env = {
|
||||
"lerobot/xarm_push_medium_image",
|
||||
"lerobot/xarm_push_medium_replay_image",
|
||||
],
|
||||
"dora_aloha_real": [
|
||||
"lerobot/aloha_static_battery",
|
||||
"lerobot/aloha_static_candy",
|
||||
"lerobot/aloha_static_coffee",
|
||||
"lerobot/aloha_static_coffee_new",
|
||||
"lerobot/aloha_static_cups_open",
|
||||
"lerobot/aloha_static_fork_pick_up",
|
||||
"lerobot/aloha_static_pingpong_test",
|
||||
"lerobot/aloha_static_pro_pencil",
|
||||
"lerobot/aloha_static_screw_driver",
|
||||
"lerobot/aloha_static_tape",
|
||||
"lerobot/aloha_static_thread_velcro",
|
||||
"lerobot/aloha_static_towel",
|
||||
"lerobot/aloha_static_vinh_cup",
|
||||
"lerobot/aloha_static_vinh_cup_left",
|
||||
"lerobot/aloha_static_ziploc_slide",
|
||||
],
|
||||
}
|
||||
|
||||
available_real_world_datasets = [
|
||||
@@ -182,11 +164,7 @@ available_real_world_datasets = [
|
||||
]
|
||||
|
||||
available_datasets = sorted(
|
||||
set(
|
||||
itertools.chain(
|
||||
*available_datasets_per_env.values(), available_real_world_datasets
|
||||
)
|
||||
)
|
||||
set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
|
||||
)
|
||||
|
||||
# lists all available policies from `lerobot/common/policies`
|
||||
@@ -225,16 +203,11 @@ available_policies_per_env = {
|
||||
"xarm": ["tdmpc"],
|
||||
"koch_real": ["act_koch_real"],
|
||||
"aloha_real": ["act_aloha_real"],
|
||||
"dora_aloha_real": ["act_aloha_real"],
|
||||
}
|
||||
|
||||
env_task_pairs = [
|
||||
(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks
|
||||
]
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
env_dataset_pairs = [
|
||||
(env, dataset)
|
||||
for env, datasets in available_datasets_per_env.items()
|
||||
for dataset in datasets
|
||||
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
|
||||
]
|
||||
env_dataset_policy_triplets = [
|
||||
(env, dataset, policy)
|
||||
|
||||
4
lerobot/common/cameras/__init__.py
Normal file
4
lerobot/common/cameras/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig
|
||||
|
||||
__all__ = ["Camera", "CameraConfig"]
|
||||
25
lerobot/common/cameras/camera.py
Normal file
25
lerobot/common/cameras/camera.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import abc
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Camera(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def connect(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, temporary_color: str | None = None) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_read(self) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self):
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
11
lerobot/common/cameras/configs.py
Normal file
11
lerobot/common/cameras/configs.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
4
lerobot/common/cameras/intel/__init__.py
Normal file
4
lerobot/common/cameras/intel/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .camera_realsense import RealSenseCamera
|
||||
from .configuration_realsense import RealSenseCameraConfig
|
||||
|
||||
__all__ = ["RealSenseCamera", "RealSenseCameraConfig"]
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file contains utilities for recording frames from Intel Realsense cameras.
|
||||
"""
|
||||
@@ -11,20 +25,21 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.utils.robot_utils import (
|
||||
busy_wait,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
from ..camera import Camera
|
||||
from .configuration_realsense import RealSenseCameraConfig
|
||||
|
||||
SERIAL_NUMBER_INDEX = 1
|
||||
|
||||
|
||||
@@ -34,7 +49,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
|
||||
connected to the computer.
|
||||
"""
|
||||
if mock:
|
||||
import tests.mock_pyrealsense2 as rs
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
else:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
@@ -65,9 +80,7 @@ def save_image(img_array, serial_number, frame_index, images_dir):
|
||||
img.save(str(path), quality=100)
|
||||
logging.info(f"Saved image: {path}")
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
|
||||
)
|
||||
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
|
||||
|
||||
|
||||
def save_images_from_cameras(
|
||||
@@ -88,7 +101,7 @@ def save_images_from_cameras(
|
||||
serial_numbers = [cam["serial_number"] for cam in camera_infos]
|
||||
|
||||
if mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -96,12 +109,11 @@ def save_images_from_cameras(
|
||||
cameras = []
|
||||
for cam_sn in serial_numbers:
|
||||
print(f"{cam_sn=}")
|
||||
camera = IntelRealSenseCamera(
|
||||
cam_sn, fps=fps, width=width, height=height, mock=mock
|
||||
)
|
||||
config = RealSenseCameraConfig(serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
|
||||
f"RealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||
)
|
||||
cameras.append(camera)
|
||||
|
||||
@@ -144,9 +156,7 @@ def save_images_from_cameras(
|
||||
if time.perf_counter() - start_time > record_time_s:
|
||||
break
|
||||
|
||||
print(
|
||||
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
|
||||
)
|
||||
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
|
||||
|
||||
frame_index += 1
|
||||
finally:
|
||||
@@ -155,62 +165,11 @@ def save_images_from_cameras(
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntelRealSenseCameraConfig:
|
||||
class RealSenseCamera(Camera):
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
IntelRealSenseCameraConfig(30, 640, 480)
|
||||
IntelRealSenseCameraConfig(60, 640, 480)
|
||||
IntelRealSenseCameraConfig(90, 640, 480)
|
||||
IntelRealSenseCameraConfig(30, 1280, 720)
|
||||
IntelRealSenseCameraConfig(30, 640, 480, use_depth=True)
|
||||
IntelRealSenseCameraConfig(30, 640, 480, rotation=90)
|
||||
```
|
||||
"""
|
||||
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
use_depth: bool = False
|
||||
force_hardware_reset: bool = True
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
at_least_one_is_not_none = (
|
||||
self.fps is not None or self.width is not None or self.height is not None
|
||||
)
|
||||
at_least_one_is_none = (
|
||||
self.fps is None or self.width is None or self.height is None
|
||||
)
|
||||
if at_least_one_is_not_none and at_least_one_is_none:
|
||||
raise ValueError(
|
||||
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
|
||||
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
|
||||
)
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(
|
||||
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
|
||||
)
|
||||
|
||||
|
||||
class IntelRealSenseCamera:
|
||||
"""
|
||||
The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
|
||||
The RealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
|
||||
- is instantiated with the serial number of the camera - won't randomly change as it can be the case of OpenCVCamera for Linux,
|
||||
- can also be instantiated with the camera's name — if it's unique — using IntelRealSenseCamera.init_from_name(),
|
||||
- can also be instantiated with the camera's name — if it's unique — using RealSenseCamera.init_from_name(),
|
||||
- depth map can be returned.
|
||||
|
||||
To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
|
||||
@@ -218,36 +177,38 @@ class IntelRealSenseCamera:
|
||||
python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras
|
||||
```
|
||||
|
||||
When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
When an RealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
of the given camera will be used.
|
||||
|
||||
Example of usage:
|
||||
Example of instantiating with a serial number:
|
||||
```python
|
||||
# Instantiate with its serial number
|
||||
camera = IntelRealSenseCamera(128422271347)
|
||||
# Or by its name if it's unique
|
||||
camera = IntelRealSenseCamera.init_from_name("Intel RealSense D405")
|
||||
from lerobot.common.robot_devices.cameras.configs import RealSenseCameraConfig
|
||||
|
||||
config = RealSenseCameraConfig(serial_number=128422271347)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
# when done using the camera, consider disconnecting
|
||||
camera.disconnect()
|
||||
```
|
||||
|
||||
Example of instantiating with a name if it's unique:
|
||||
```
|
||||
config = RealSenseCameraConfig(name="Intel RealSense D405")
|
||||
```
|
||||
|
||||
Example of changing default fps, width, height and color_mode:
|
||||
```python
|
||||
camera = IntelRealSenseCamera(serial_number, fps=30, width=1280, height=720)
|
||||
camera = connect() # applies the settings, might error out if these settings are not compatible with the camera
|
||||
|
||||
camera = IntelRealSenseCamera(serial_number, fps=90, width=640, height=480)
|
||||
camera = connect()
|
||||
|
||||
camera = IntelRealSenseCamera(serial_number, fps=90, width=640, height=480, color_mode="bgr")
|
||||
camera = connect()
|
||||
config = RealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
|
||||
config = RealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
|
||||
config = RealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
|
||||
# Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
|
||||
```
|
||||
|
||||
Example of returning depth:
|
||||
```python
|
||||
camera = IntelRealSenseCamera(serial_number, use_depth=True)
|
||||
config = RealSenseCameraConfig(serial_number=128422271347, use_depth=True)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect()
|
||||
color_image, depth_map = camera.read()
|
||||
```
|
||||
@@ -255,20 +216,27 @@ class IntelRealSenseCamera:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
serial_number: int,
|
||||
config: IntelRealSenseCameraConfig | None = None,
|
||||
**kwargs,
|
||||
config: RealSenseCameraConfig,
|
||||
):
|
||||
if config is None:
|
||||
config = IntelRealSenseCameraConfig()
|
||||
self.config = config
|
||||
if config.name is not None:
|
||||
self.serial_number = self.find_serial_number_from_name(config.name)
|
||||
else:
|
||||
self.serial_number = config.serial_number
|
||||
|
||||
# Overwrite the config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
# Store the raw (capture) resolution from the config.
|
||||
self.capture_width = config.width
|
||||
self.capture_height = config.height
|
||||
|
||||
# If rotated by ±90, swap width and height.
|
||||
if config.rotation in [-90, 90]:
|
||||
self.width = config.height
|
||||
self.height = config.width
|
||||
else:
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
|
||||
self.serial_number = serial_number
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.use_depth = config.use_depth
|
||||
@@ -284,11 +252,10 @@ class IntelRealSenseCamera:
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# TODO(alibets): Do we keep original width/height or do we define them after rotation?
|
||||
self.rotation = None
|
||||
if config.rotation == -90:
|
||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
@@ -297,10 +264,7 @@ class IntelRealSenseCamera:
|
||||
elif config.rotation == 180:
|
||||
self.rotation = cv2.ROTATE_180
|
||||
|
||||
@classmethod
|
||||
def init_from_name(
|
||||
cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs
|
||||
):
|
||||
def find_serial_number_from_name(self, name):
|
||||
camera_infos = find_cameras()
|
||||
camera_names = [cam["name"] for cam in camera_infos]
|
||||
this_name_count = Counter(camera_names)[name]
|
||||
@@ -310,45 +274,35 @@ class IntelRealSenseCamera:
|
||||
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
|
||||
)
|
||||
|
||||
name_to_serial_dict = {
|
||||
cam["name"]: cam["serial_number"] for cam in camera_infos
|
||||
}
|
||||
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
|
||||
cam_sn = name_to_serial_dict[name]
|
||||
|
||||
if config is None:
|
||||
config = IntelRealSenseCameraConfig()
|
||||
|
||||
# Overwrite the config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
|
||||
return cls(serial_number=cam_sn, config=config, **kwargs)
|
||||
return cam_sn
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"IntelRealSenseCamera({self.serial_number}) is already connected."
|
||||
)
|
||||
raise DeviceAlreadyConnectedError(f"RealSenseCamera({self.serial_number}) is already connected.")
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_pyrealsense2 as rs
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
else:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
config = rs.config()
|
||||
config.enable_device(str(self.serial_number))
|
||||
|
||||
if self.fps and self.width and self.height:
|
||||
if self.fps and self.capture_width and self.capture_height:
|
||||
# TODO(rcadene): can we set rgb8 directly?
|
||||
config.enable_stream(
|
||||
rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps
|
||||
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
|
||||
)
|
||||
else:
|
||||
config.enable_stream(rs.stream.color)
|
||||
|
||||
if self.use_depth:
|
||||
if self.fps and self.width and self.height:
|
||||
if self.fps and self.capture_width and self.capture_height:
|
||||
config.enable_stream(
|
||||
rs.stream.depth, self.width, self.height, rs.format.z16, self.fps
|
||||
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
|
||||
)
|
||||
else:
|
||||
config.enable_stream(rs.stream.depth)
|
||||
@@ -373,7 +327,7 @@ class IntelRealSenseCamera:
|
||||
"To find the serial number you should use, run `python lerobot/common/robot_devices/cameras/intelrealsense.py`."
|
||||
)
|
||||
|
||||
raise OSError(f"Can't access IntelRealSenseCamera({self.serial_number}).")
|
||||
raise OSError(f"Can't access RealSenseCamera({self.serial_number}).")
|
||||
|
||||
color_stream = profile.get_stream(rs.stream.color)
|
||||
color_profile = color_stream.as_video_stream_profile()
|
||||
@@ -382,31 +336,27 @@ class IntelRealSenseCamera:
|
||||
actual_height = color_profile.height()
|
||||
|
||||
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
|
||||
if self.fps is not None and not math.isclose(
|
||||
self.fps, actual_fps, rel_tol=1e-3
|
||||
):
|
||||
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
# Using `OSError` since it's a broad that encompasses issues related to device communication
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||
f"Can't set {self.fps=} for RealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||
)
|
||||
if self.width is not None and self.width != actual_width:
|
||||
if self.capture_width is not None and self.capture_width != actual_width:
|
||||
raise OSError(
|
||||
f"Can't set {self.width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
||||
f"Can't set {self.capture_width=} for RealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
||||
)
|
||||
if self.height is not None and self.height != actual_height:
|
||||
if self.capture_height is not None and self.capture_height != actual_height:
|
||||
raise OSError(
|
||||
f"Can't set {self.height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
||||
f"Can't set {self.capture_height=} for RealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
||||
)
|
||||
|
||||
self.fps = round(actual_fps)
|
||||
self.width = round(actual_width)
|
||||
self.height = round(actual_height)
|
||||
self.capture_width = round(actual_width)
|
||||
self.capture_height = round(actual_height)
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def read(
|
||||
self, temporary_color: str | None = None
|
||||
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
|
||||
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
|
||||
"""Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
|
||||
of type `np.uint8`, contrarily to the pytorch format which is float channel first.
|
||||
|
||||
@@ -417,12 +367,12 @@ class IntelRealSenseCamera:
|
||||
If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
raise DeviceNotConnectedError(
|
||||
f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -433,15 +383,11 @@ class IntelRealSenseCamera:
|
||||
color_frame = frame.get_color_frame()
|
||||
|
||||
if not color_frame:
|
||||
raise OSError(
|
||||
f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
|
||||
)
|
||||
raise OSError(f"Can't capture color image from RealSenseCamera({self.serial_number}).")
|
||||
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
|
||||
requested_color_mode = (
|
||||
self.color_mode if temporary_color is None else temporary_color
|
||||
)
|
||||
requested_color_mode = self.color_mode if temporary_color is None else temporary_color
|
||||
if requested_color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
|
||||
@@ -452,7 +398,7 @@ class IntelRealSenseCamera:
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||
|
||||
h, w, _ = color_image.shape
|
||||
if h != self.height or w != self.width:
|
||||
if h != self.capture_height or w != self.capture_width:
|
||||
raise OSError(
|
||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||
)
|
||||
@@ -469,14 +415,12 @@ class IntelRealSenseCamera:
|
||||
if self.use_depth:
|
||||
depth_frame = frame.get_depth_frame()
|
||||
if not depth_frame:
|
||||
raise OSError(
|
||||
f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
|
||||
)
|
||||
raise OSError(f"Can't capture depth image from RealSenseCamera({self.serial_number}).")
|
||||
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
h, w = depth_map.shape
|
||||
if h != self.height or w != self.width:
|
||||
if h != self.capture_height or w != self.capture_width:
|
||||
raise OSError(
|
||||
f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||
)
|
||||
@@ -498,8 +442,8 @@ class IntelRealSenseCamera:
|
||||
def async_read(self):
|
||||
"""Access the latest color image"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
raise DeviceNotConnectedError(
|
||||
f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
if self.thread is None:
|
||||
@@ -513,9 +457,7 @@ class IntelRealSenseCamera:
|
||||
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
|
||||
num_tries += 1
|
||||
time.sleep(1 / self.fps)
|
||||
if num_tries > self.fps and (
|
||||
self.thread.ident is None or not self.thread.is_alive()
|
||||
):
|
||||
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
|
||||
raise Exception(
|
||||
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
|
||||
)
|
||||
@@ -527,8 +469,8 @@ class IntelRealSenseCamera:
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
raise DeviceNotConnectedError(
|
||||
f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
@@ -550,14 +492,14 @@ class IntelRealSenseCamera:
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Save a few frames using `IntelRealSenseCamera` for all cameras connected to the computer, or a selected subset."
|
||||
description="Save a few frames using `RealSenseCamera` for all cameras connected to the computer, or a selected subset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--serial-numbers",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="List of serial numbers used to instantiate the `IntelRealSenseCamera`. If not provided, find and use all available camera indices.",
|
||||
help="List of serial numbers used to instantiate the `RealSenseCamera`. If not provided, find and use all available camera indices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
71
lerobot/common/cameras/intel/configuration_realsense.py
Normal file
71
lerobot/common/cameras/intel/configuration_realsense.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("intelrealsense")
|
||||
@dataclass
|
||||
class RealSenseCameraConfig(CameraConfig):
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
RealSenseCameraConfig(128422271347, 30, 640, 480)
|
||||
RealSenseCameraConfig(128422271347, 60, 640, 480)
|
||||
RealSenseCameraConfig(128422271347, 90, 640, 480)
|
||||
RealSenseCameraConfig(128422271347, 30, 1280, 720)
|
||||
RealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
|
||||
RealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
|
||||
```
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
serial_number: int | None = None
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
use_depth: bool = False
|
||||
force_hardware_reset: bool = True
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# bool is stronger than is None, since it works with empty strings
|
||||
if bool(self.name) and bool(self.serial_number):
|
||||
raise ValueError(
|
||||
f"One of them must be set: name or serial_number, but {self.name=} and {self.serial_number=} provided."
|
||||
)
|
||||
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
|
||||
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
|
||||
if at_least_one_is_not_none and at_least_one_is_none:
|
||||
raise ValueError(
|
||||
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
|
||||
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
|
||||
)
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
4
lerobot/common/cameras/opencv/__init__.py
Normal file
4
lerobot/common/cameras/opencv/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .camera_opencv import OpenCVCamera
|
||||
from .configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
||||
"""
|
||||
@@ -9,20 +23,21 @@ import platform
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.utils.robot_utils import (
|
||||
busy_wait,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
from ..camera import Camera
|
||||
from .configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
# The maximum opencv device index depends on your operating system. For instance,
|
||||
# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case
|
||||
# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23.
|
||||
@@ -31,14 +46,10 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
MAX_OPENCV_INDEX = 60
|
||||
|
||||
|
||||
def find_cameras(
|
||||
raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
|
||||
) -> list[dict]:
|
||||
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
||||
cameras = []
|
||||
if platform.system() == "Linux":
|
||||
print(
|
||||
"Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
|
||||
)
|
||||
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
||||
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
|
||||
ports = _find_cameras(possible_ports, mock=mock)
|
||||
for port in ports:
|
||||
@@ -70,7 +81,7 @@ def _find_cameras(
|
||||
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
|
||||
) -> list[int | str]:
|
||||
if mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -130,11 +141,12 @@ def save_images_from_cameras(
|
||||
print("Connecting cameras")
|
||||
cameras = []
|
||||
for cam_idx in camera_ids:
|
||||
camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height, mock=mock)
|
||||
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
|
||||
f"height={camera.height}, color_mode={camera.color_mode})"
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
|
||||
f"height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||
)
|
||||
cameras.append(camera)
|
||||
|
||||
@@ -169,9 +181,7 @@ def save_images_from_cameras(
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
print(
|
||||
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
|
||||
)
|
||||
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
|
||||
|
||||
if time.perf_counter() - start_time > record_time_s:
|
||||
break
|
||||
@@ -181,42 +191,7 @@ def save_images_from_cameras(
|
||||
print(f"Images have been saved to {images_dir}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenCVCameraConfig:
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
OpenCVCameraConfig(30, 640, 480)
|
||||
OpenCVCameraConfig(60, 640, 480)
|
||||
OpenCVCameraConfig(90, 640, 480)
|
||||
OpenCVCameraConfig(30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(
|
||||
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
|
||||
)
|
||||
|
||||
|
||||
class OpenCVCamera:
|
||||
class OpenCVCamera(Camera):
|
||||
"""
|
||||
The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
|
||||
with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
@@ -235,7 +210,10 @@ class OpenCVCamera:
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
camera = OpenCVCamera(camera_index=0)
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
|
||||
config = OpenCVCameraConfig(camera_index=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
# when done using the camera, consider disconnecting
|
||||
@@ -244,50 +222,42 @@ class OpenCVCamera:
|
||||
|
||||
Example of changing default fps, width, height and color_mode:
|
||||
```python
|
||||
camera = OpenCVCamera(0, fps=30, width=1280, height=720)
|
||||
camera = connect() # applies the settings, might error out if these settings are not compatible with the camera
|
||||
|
||||
camera = OpenCVCamera(0, fps=90, width=640, height=480)
|
||||
camera = connect()
|
||||
|
||||
camera = OpenCVCamera(0, fps=90, width=640, height=480, color_mode="bgr")
|
||||
camera = connect()
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720)
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480)
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr")
|
||||
# Note: might error out open `camera.connect()` if these settings are not compatible with the camera
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
camera_index: int | str,
|
||||
config: OpenCVCameraConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if config is None:
|
||||
config = OpenCVCameraConfig()
|
||||
|
||||
# Overwrite config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
|
||||
self.camera_index = camera_index
|
||||
def __init__(self, config: OpenCVCameraConfig):
|
||||
self.config = config
|
||||
self.camera_index = config.camera_index
|
||||
self.port = None
|
||||
|
||||
# Linux uses ports for connecting to cameras
|
||||
if platform.system() == "Linux":
|
||||
if isinstance(self.camera_index, int):
|
||||
self.port = Path(f"/dev/video{self.camera_index}")
|
||||
elif isinstance(self.camera_index, str) and is_valid_unix_path(
|
||||
self.camera_index
|
||||
):
|
||||
elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
|
||||
self.port = Path(self.camera_index)
|
||||
# Retrieve the camera index from a potentially symlinked path
|
||||
self.camera_index = get_camera_index_from_unix_port(self.port)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Please check the provided camera_index: {camera_index}"
|
||||
)
|
||||
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
|
||||
|
||||
# Store the raw (capture) resolution from the config.
|
||||
self.capture_width = config.width
|
||||
self.capture_height = config.height
|
||||
|
||||
# If rotated by ±90, swap width and height.
|
||||
if config.rotation in [-90, 90]:
|
||||
self.width = config.height
|
||||
self.height = config.width
|
||||
else:
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.mock = config.mock
|
||||
@@ -300,11 +270,10 @@ class OpenCVCamera:
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# TODO(aliberts): Do we keep original width/height or do we define them after rotation?
|
||||
self.rotation = None
|
||||
if config.rotation == -90:
|
||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
@@ -315,12 +284,10 @@ class OpenCVCamera:
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is already connected."
|
||||
)
|
||||
raise DeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -328,14 +295,20 @@ class OpenCVCamera:
|
||||
# when other threads are used to save the images.
|
||||
cv2.setNumThreads(1)
|
||||
|
||||
camera_idx = (
|
||||
f"/dev/video{self.camera_index}"
|
||||
backend = (
|
||||
cv2.CAP_V4L2
|
||||
if platform.system() == "Linux"
|
||||
else self.camera_index
|
||||
else cv2.CAP_DSHOW
|
||||
if platform.system() == "Windows"
|
||||
else cv2.CAP_AVFOUNDATION
|
||||
if platform.system() == "Darwin"
|
||||
else cv2.CAP_ANY
|
||||
)
|
||||
|
||||
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
||||
# First create a temporary camera trying to access `camera_index`,
|
||||
# and verify it is a valid camera by calling `isOpened`.
|
||||
tmp_camera = cv2.VideoCapture(camera_idx)
|
||||
tmp_camera = cv2.VideoCapture(camera_idx, backend)
|
||||
is_camera_open = tmp_camera.isOpened()
|
||||
# Release camera to make it accessible for `find_camera_indices`
|
||||
tmp_camera.release()
|
||||
@@ -358,44 +331,41 @@ class OpenCVCamera:
|
||||
# Secondly, create the camera that will be used downstream.
|
||||
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
||||
# needs to be re-created.
|
||||
self.camera = cv2.VideoCapture(camera_idx)
|
||||
self.camera = cv2.VideoCapture(camera_idx, backend)
|
||||
|
||||
if self.fps is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
||||
if self.width is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
|
||||
if self.height is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
|
||||
if self.capture_width is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width)
|
||||
if self.capture_height is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height)
|
||||
|
||||
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
|
||||
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
|
||||
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
|
||||
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
|
||||
if self.fps is not None and not math.isclose(
|
||||
self.fps, actual_fps, rel_tol=1e-3
|
||||
):
|
||||
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
# Using `OSError` since it's a broad that encompasses issues related to device communication
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
|
||||
)
|
||||
if self.width is not None and not math.isclose(
|
||||
self.width, actual_width, rel_tol=1e-3
|
||||
if self.capture_width is not None and not math.isclose(
|
||||
self.capture_width, actual_width, rel_tol=1e-3
|
||||
):
|
||||
raise OSError(
|
||||
f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
||||
f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
||||
)
|
||||
if self.height is not None and not math.isclose(
|
||||
self.height, actual_height, rel_tol=1e-3
|
||||
if self.capture_height is not None and not math.isclose(
|
||||
self.capture_height, actual_height, rel_tol=1e-3
|
||||
):
|
||||
raise OSError(
|
||||
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
||||
f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
||||
)
|
||||
|
||||
self.fps = round(actual_fps)
|
||||
self.width = round(actual_width)
|
||||
self.height = round(actual_height)
|
||||
|
||||
self.capture_width = round(actual_width)
|
||||
self.capture_height = round(actual_height)
|
||||
self.is_connected = True
|
||||
|
||||
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
|
||||
@@ -406,7 +376,7 @@ class OpenCVCamera:
|
||||
If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
@@ -417,9 +387,7 @@ class OpenCVCamera:
|
||||
if not ret:
|
||||
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
|
||||
|
||||
requested_color_mode = (
|
||||
self.color_mode if temporary_color_mode is None else temporary_color_mode
|
||||
)
|
||||
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
|
||||
|
||||
if requested_color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
@@ -431,14 +399,14 @@ class OpenCVCamera:
|
||||
# so we convert the image color from BGR to RGB.
|
||||
if requested_color_mode == "rgb":
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
h, w, _ = color_image.shape
|
||||
if h != self.height or w != self.width:
|
||||
if h != self.capture_height or w != self.capture_width:
|
||||
raise OSError(
|
||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||
)
|
||||
@@ -465,7 +433,7 @@ class OpenCVCamera:
|
||||
|
||||
def async_read(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
@@ -487,7 +455,7 @@ class OpenCVCamera:
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
38
lerobot/common/cameras/opencv/configuration_opencv.py
Normal file
38
lerobot/common/cameras/opencv/configuration_opencv.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@dataclass
|
||||
class OpenCVCameraConfig(CameraConfig):
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
OpenCVCameraConfig(0, 30, 640, 480)
|
||||
OpenCVCameraConfig(0, 60, 640, 480)
|
||||
OpenCVCameraConfig(0, 90, 640, 480)
|
||||
OpenCVCameraConfig(0, 30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
camera_index: int
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
21
lerobot/common/cameras/utils.py
Normal file
21
lerobot/common/cameras/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig
|
||||
|
||||
|
||||
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
|
||||
cameras = {}
|
||||
|
||||
for key, cfg in camera_configs.items():
|
||||
if cfg.type == "opencv":
|
||||
from .opencv import OpenCVCamera
|
||||
|
||||
cameras[key] = OpenCVCamera(cfg)
|
||||
|
||||
elif cfg.type == "intelrealsense":
|
||||
from .intel.camera_realsense import RealSenseCamera
|
||||
|
||||
cameras[key] = RealSenseCamera(cfg)
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
|
||||
return cameras
|
||||
52
lerobot/common/constants.py
Normal file
52
lerobot/common/constants.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# keys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import HF_HOME
|
||||
|
||||
OBS_ENV_STATE = "observation.environment_state"
|
||||
OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
|
||||
ROBOTS = "robots"
|
||||
TELEOPERATORS = "teleoperators"
|
||||
|
||||
# files & directories
|
||||
CHECKPOINTS_DIR = "checkpoints"
|
||||
LAST_CHECKPOINT_LINK = "last"
|
||||
PRETRAINED_MODEL_DIR = "pretrained_model"
|
||||
TRAINING_STATE_DIR = "training_state"
|
||||
RNG_STATE = "rng_state.safetensors"
|
||||
TRAINING_STEP = "training_step.json"
|
||||
OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
|
||||
)
|
||||
|
||||
# cache dir
|
||||
default_cache_path = Path(HF_HOME) / "lerobot"
|
||||
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
||||
|
||||
# calibration dir
|
||||
default_calibration_path = HF_LEROBOT_HOME / ".calibration"
|
||||
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
|
||||
68
lerobot/common/datasets/backward_compatibility.py
Normal file
68
lerobot/common/datasets/backward_compatibility.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import packaging.version
|
||||
|
||||
V2_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
|
||||
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
|
||||
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
|
||||
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
|
||||
sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
V21_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
|
||||
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
||||
```
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
FUTURE_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is only available in {version} format.
|
||||
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
|
||||
"""
|
||||
|
||||
|
||||
class CompatibilityError(Exception): ...
|
||||
|
||||
|
||||
class BackwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ForwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
@@ -13,231 +13,164 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
import numpy as np
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from lerobot.common.datasets.utils import load_image_as_numpy
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
||||
def estimate_num_samples(
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
) -> int:
|
||||
"""Heuristic to estimate the number of samples based on dataset size.
|
||||
The power controls the sample growth relative to dataset size.
|
||||
Lower the power for less number of samples.
|
||||
|
||||
Note: We assume the images are in channel first format
|
||||
For default arguments, we have:
|
||||
- from 1 to ~500, num_samples=100
|
||||
- at 1000, num_samples=177
|
||||
- at 2000, num_samples=299
|
||||
- at 5000, num_samples=594
|
||||
- at 10000, num_samples=1000
|
||||
- at 20000, num_samples=1681
|
||||
"""
|
||||
if dataset_len < min_num_samples:
|
||||
min_num_samples = dataset_len
|
||||
return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
stats_patterns = {}
|
||||
def sample_indices(data_len: int) -> list[int]:
|
||||
num_samples = estimate_num_samples(data_len)
|
||||
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
|
||||
|
||||
for key in dataset.features:
|
||||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
|
||||
# if isinstance(feats_type, (VideoFrame, Image)):
|
||||
if key in dataset.meta.camera_keys:
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"expect channel first images, but instead {batch[key].shape}"
|
||||
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
|
||||
_, height, width = img.shape
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert (
|
||||
batch[key].dtype == torch.float32
|
||||
), f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert (
|
||||
batch[key].max() <= 1
|
||||
), f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert (
|
||||
batch[key].min() >= 0
|
||||
), f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
if max(width, height) < max_size_threshold:
|
||||
# no downsampling needed
|
||||
return img
|
||||
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
elif batch[key].ndim == 2:
|
||||
stats_patterns[key] = "b c -> c "
|
||||
elif batch[key].ndim == 1:
|
||||
stats_patterns[key] = "b -> 1"
|
||||
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
|
||||
return img[:, ::downsample_factor, ::downsample_factor]
|
||||
|
||||
|
||||
def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
sampled_indices = sample_indices(len(image_paths))
|
||||
|
||||
images = None
|
||||
for i, idx in enumerate(sampled_indices):
|
||||
path = image_paths[idx]
|
||||
# we load as uint8 to reduce memory usage
|
||||
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
|
||||
img = auto_downsample_height_width(img)
|
||||
|
||||
if images is None:
|
||||
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
||||
|
||||
images[i] = img
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||
return {
|
||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||
"mean": np.mean(array, axis=axis, keepdims=keepdims),
|
||||
"std": np.std(array, axis=axis, keepdims=keepdims),
|
||||
"count": np.array([len(array)]),
|
||||
}
|
||||
|
||||
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue # HACK: we should receive np.arrays of strings
|
||||
elif features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data) # data is a list of image paths
|
||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||
keepdims = True
|
||||
else:
|
||||
raise ValueError(f"{key}, {batch[key].shape}")
|
||||
ep_ft_array = data # data is already a np.ndarray
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
|
||||
return stats_patterns
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
# finally, we normalize and remove batch dim for images
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
|
||||
|
||||
def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
|
||||
# for more info on why we need to set the same number of workers, see `load_from_videos`
|
||||
stats_patterns = get_stats_einops_patterns(dataset, num_workers)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
for key in stats_patterns:
|
||||
mean[key] = torch.tensor(0.0).float()
|
||||
std[key] = torch.tensor(0.0).float()
|
||||
max[key] = torch.tensor(-float("inf")).float()
|
||||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
def create_seeded_dataloader(dataset, batch_size, seed):
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
generator=generator,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||
# surprises when rerunning the sampler.
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(
|
||||
dataloader,
|
||||
total=ceil(max_num_samples / batch_size),
|
||||
desc="Compute mean, min, max",
|
||||
)
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
if first_batch is None:
|
||||
first_batch = deepcopy(batch)
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation.
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||
mean[key] = (
|
||||
mean[key]
|
||||
+ this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
)
|
||||
max[key] = torch.maximum(
|
||||
max[key], einops.reduce(batch[key], pattern, "max")
|
||||
)
|
||||
min[key] = torch.minimum(
|
||||
min[key], einops.reduce(batch[key], pattern, "min")
|
||||
)
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(
|
||||
dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std"
|
||||
)
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
# Sanity check to make sure the batches are still in the same order as before.
|
||||
if first_batch_ is None:
|
||||
first_batch_ = deepcopy(batch)
|
||||
for key in stats_patterns:
|
||||
assert torch.equal(first_batch_[key], first_batch[key])
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = (
|
||||
std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
)
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
for key in stats_patterns:
|
||||
std[key] = torch.sqrt(std[key])
|
||||
|
||||
stats = {}
|
||||
for key in stats_patterns:
|
||||
stats[key] = {
|
||||
"mean": mean[key],
|
||||
"std": std[key],
|
||||
"max": max[key],
|
||||
"min": min[key],
|
||||
}
|
||||
return stats
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
for i in range(len(stats_list)):
|
||||
for fkey in stats_list[i]:
|
||||
for k, v in stats_list[i][fkey].items():
|
||||
if not isinstance(v, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||
)
|
||||
if v.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
if k == "count" and v.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
|
||||
|
||||
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
||||
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregates stats for a single feature."""
|
||||
means = np.stack([s["mean"] for s in stats_ft_list])
|
||||
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
||||
counts = np.stack([s["count"] for s in stats_ft_list])
|
||||
total_count = counts.sum(axis=0)
|
||||
|
||||
The final stats will have the union of all data keys from each of the datasets.
|
||||
# Prepare weighted mean by matching number of dimensions
|
||||
while counts.ndim < means.ndim:
|
||||
counts = np.expand_dims(counts, axis=-1)
|
||||
|
||||
The final stats will have the union of all data keys from each of the datasets. For instance:
|
||||
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||
# Compute the weighted mean
|
||||
weighted_means = means * counts
|
||||
total_mean = weighted_means.sum(axis=0) / total_count
|
||||
|
||||
# Compute the variance using the parallel algorithm
|
||||
delta_means = means - total_mean
|
||||
weighted_variances = (variances + delta_means**2) * counts
|
||||
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||
|
||||
return {
|
||||
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||
"mean": total_mean,
|
||||
"std": np.sqrt(total_variance),
|
||||
"count": total_count,
|
||||
}
|
||||
|
||||
|
||||
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
The final stats will have the union of all data keys from each of the stats dicts.
|
||||
|
||||
For instance:
|
||||
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
||||
- new_mean = (mean of all data)
|
||||
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||
- new_mean = (mean of all data, weighted by counts)
|
||||
- new_std = (std of all data)
|
||||
"""
|
||||
data_keys = set()
|
||||
for dataset in ls_datasets:
|
||||
data_keys.update(dataset.meta.stats.keys())
|
||||
stats = {k: {} for k in data_keys}
|
||||
for data_key in data_keys:
|
||||
for stat_key in ["min", "max"]:
|
||||
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
||||
stats[data_key][stat_key] = einops.reduce(
|
||||
torch.stack(
|
||||
[
|
||||
ds.meta.stats[data_key][stat_key]
|
||||
for ds in ls_datasets
|
||||
if data_key in ds.meta.stats
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"n ... -> ...",
|
||||
stat_key,
|
||||
)
|
||||
total_samples = sum(
|
||||
d.num_frames for d in ls_datasets if data_key in d.meta.stats
|
||||
)
|
||||
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
||||
# dataset, then divide by total_samples to get the overall "mean".
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["mean"] = sum(
|
||||
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.meta.stats
|
||||
)
|
||||
# The derivation for standard deviation is a little more involved but is much in the same spirit as
|
||||
# the computation of the mean.
|
||||
# Given two sets of data where the statistics are known:
|
||||
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
|
||||
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["std"] = torch.sqrt(
|
||||
sum(
|
||||
(
|
||||
d.meta.stats[data_key]["std"] ** 2
|
||||
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
|
||||
)
|
||||
* (d.num_frames / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.meta.stats
|
||||
)
|
||||
)
|
||||
return stats
|
||||
|
||||
_assert_type_and_shape(stats_list)
|
||||
|
||||
data_keys = {key for stats in stats_list for key in stats}
|
||||
aggregated_stats = {key: {} for key in data_keys}
|
||||
|
||||
for key in data_keys:
|
||||
stats_with_key = [stats[key] for stats in stats_list if key in stats]
|
||||
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
|
||||
|
||||
return aggregated_stats
|
||||
|
||||
@@ -14,129 +14,105 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.transforms import ImageTransforms
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
|
||||
}
|
||||
|
||||
|
||||
def resolve_delta_timestamps(cfg):
|
||||
"""Resolves delta_timestamps config key (in-place) by using `eval`.
|
||||
def resolve_delta_timestamps(
|
||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
||||
) -> dict[str, list] | None:
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
|
||||
|
||||
Doesn't do anything if delta_timestamps is not specified or has already been resolve (as evidenced by
|
||||
the data type of its values).
|
||||
"""
|
||||
delta_timestamps = cfg.training.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
# TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
|
||||
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
|
||||
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""
|
||||
Args:
|
||||
cfg: A Hydra config as per the LeRobot config scheme.
|
||||
split: Select the data subset used to create an instance of LeRobotDataset.
|
||||
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
|
||||
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
|
||||
slicer in the hugging face datasets:
|
||||
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
|
||||
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
|
||||
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
|
||||
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
||||
delta_timestamps against.
|
||||
|
||||
Returns:
|
||||
The LeRobotDataset.
|
||||
"""
|
||||
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
|
||||
raise ValueError(
|
||||
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
|
||||
"strings to load multiple datasets."
|
||||
)
|
||||
|
||||
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
|
||||
if cfg.env.name != "dora":
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
|
||||
else:
|
||||
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
|
||||
|
||||
for dataset_repo_id in dataset_repo_ids:
|
||||
if cfg.env.name not in dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
|
||||
f"environment ({cfg.env.name=})."
|
||||
)
|
||||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
image_transforms = None
|
||||
if cfg.training.image_transforms.enable:
|
||||
default_tf = OmegaConf.create(
|
||||
dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
|
||||
{
|
||||
"brightness": {"weight": 0.0, "min_max": None},
|
||||
"contrast": {"weight": 0.0, "min_max": None},
|
||||
"saturation": {"weight": 0.0, "min_max": None},
|
||||
"hue": {"weight": 0.0, "min_max": None},
|
||||
"sharpness": {"weight": 0.0, "min_max": None},
|
||||
"max_num_transforms": None,
|
||||
"random_order": False,
|
||||
"image_size": None,
|
||||
"interpolation": None,
|
||||
"image_mean": None,
|
||||
"image_std": None,
|
||||
"observation.state": [-0.04, -0.02, 0]
|
||||
"observation.action": [-0.02, 0, 0.02]
|
||||
}
|
||||
)
|
||||
cfg_tf = OmegaConf.merge(
|
||||
OmegaConf.create(default_tf), cfg.training.image_transforms
|
||||
)
|
||||
returns `None` if the the resulting dict is empty.
|
||||
"""
|
||||
delta_timestamps = {}
|
||||
for key in ds_meta.features:
|
||||
if key == "next.reward" and cfg.reward_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == "action" and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
image_transforms = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width)
|
||||
if cfg_tf.image_size
|
||||
else None,
|
||||
interpolation=cfg_tf.interpolation,
|
||||
image_mean=cfg_tf.image_mean,
|
||||
image_std=cfg_tf.image_std,
|
||||
)
|
||||
if len(delta_timestamps) == 0:
|
||||
delta_timestamps = None
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
# TODO (aliberts): add 'episodes' arg from config after removing hydra
|
||||
return delta_timestamps
|
||||
|
||||
|
||||
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset | MultiLeRobotDataset
|
||||
"""
|
||||
image_transforms = (
|
||||
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
cfg.dataset.repo_id,
|
||||
# TODO(aliberts): add proper support for multi dataset
|
||||
# delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
)
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(
|
||||
stats, dtype=torch.float32
|
||||
)
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -38,22 +38,40 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
elif image_array.shape[-1] != 3:
|
||||
raise NotImplementedError(
|
||||
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
|
||||
)
|
||||
|
||||
if image_array.dtype != np.uint8:
|
||||
# Assume the image is in [0, 1] range for floating-point data
|
||||
image_array = np.clip(image_array, 0, 1)
|
||||
if range_check:
|
||||
max_ = image_array.max().item()
|
||||
min_ = image_array.min().item()
|
||||
if max_ > 1.0 or min_ < 0.0:
|
||||
raise ValueError(
|
||||
"The image data type is float, which requires values in the range [0.0, 1.0]. "
|
||||
f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
|
||||
"provide a uint8 image with values in the range [0, 255]."
|
||||
)
|
||||
|
||||
image_array = (image_array * 255).astype(np.uint8)
|
||||
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
img = image_array_to_image(image)
|
||||
img = image_array_to_pil_image(image)
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
img = image
|
||||
else:
|
||||
@@ -109,9 +127,7 @@ class AsyncImageWriter:
|
||||
self._stopped = False
|
||||
|
||||
if num_threads <= 0 and num_processes <= 0:
|
||||
raise ValueError(
|
||||
"Number of threads and processes must be greater than zero."
|
||||
)
|
||||
raise ValueError("Number of threads and processes must be greater than zero.")
|
||||
|
||||
if self.num_processes == 0:
|
||||
# Use threading
|
||||
@@ -125,16 +141,12 @@ class AsyncImageWriter:
|
||||
# Use multiprocessing
|
||||
self.queue = multiprocessing.JoinableQueue()
|
||||
for _ in range(self.num_processes):
|
||||
p = multiprocessing.Process(
|
||||
target=worker_process, args=(self.queue, self.num_threads)
|
||||
)
|
||||
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
|
||||
):
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
image = image.cpu().numpy()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -131,9 +131,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
else:
|
||||
self._delta_timestamps = None
|
||||
|
||||
def _make_data_spec(
|
||||
self, data_spec: dict[str, Any], buffer_capacity: int
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
|
||||
"""Makes the data spec for np.memmap."""
|
||||
if any(k.startswith("_") for k in data_spec):
|
||||
raise ValueError(
|
||||
@@ -156,32 +154,14 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
||||
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
||||
# with real data rather than the dummy initialization.
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {
|
||||
"dtype": np.dtype("?"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {
|
||||
"dtype": np.dtype("float64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
||||
}
|
||||
for k, v in data_spec.items():
|
||||
complete_data_spec[k] = {
|
||||
"dtype": v["dtype"],
|
||||
"shape": (buffer_capacity, *v["shape"]),
|
||||
}
|
||||
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
||||
return complete_data_spec
|
||||
|
||||
def add_data(self, data: dict[str, np.ndarray]):
|
||||
@@ -208,9 +188,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_frames > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
|
||||
next_index - 1
|
||||
]
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
|
||||
@@ -245,11 +223,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(
|
||||
np.unique(
|
||||
self._data[OnlineBuffer.EPISODE_INDEX_KEY][
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
|
||||
]
|
||||
)
|
||||
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -287,9 +261,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
|
||||
)
|
||||
)[0]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
|
||||
episode_data_indices
|
||||
]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
|
||||
|
||||
for data_key in self.delta_timestamps:
|
||||
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
|
||||
@@ -306,8 +278,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
# Check violated query timestamps are all outside the episode range.
|
||||
assert (
|
||||
(query_ts[is_pad] < episode_timestamps[0])
|
||||
| (episode_timestamps[-1] < query_ts[is_pad])
|
||||
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
|
||||
).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
|
||||
") inside the episode range."
|
||||
@@ -322,9 +293,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
def get_data_by_key(self, key: str) -> torch.Tensor:
|
||||
"""Returns all data for a given data key as a Tensor."""
|
||||
return torch.from_numpy(
|
||||
self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
|
||||
)
|
||||
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
|
||||
|
||||
def compute_sampler_weights(
|
||||
@@ -355,19 +324,13 @@ def compute_sampler_weights(
|
||||
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
|
||||
included here to avoid adding complexity.
|
||||
"""
|
||||
if len(offline_dataset) == 0 and (
|
||||
online_dataset is None or len(online_dataset) == 0
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of `offline_dataset` or `online_dataset` should be contain data."
|
||||
)
|
||||
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
|
||||
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
|
||||
if (online_dataset is None) ^ (online_sampling_ratio is None):
|
||||
raise ValueError(
|
||||
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
|
||||
)
|
||||
offline_sampling_ratio = (
|
||||
0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
)
|
||||
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
|
||||
weights = []
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
## Using / Updating `CODEBASE_VERSION` (for maintainers)
|
||||
|
||||
Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
|
||||
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
|
||||
lerobot/common/datasets/lerobot_dataset.py) variable.
|
||||
|
||||
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
|
||||
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
|
||||
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
|
||||
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
|
||||
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
|
||||
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
|
||||
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
|
||||
- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
|
||||
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
|
||||
|
||||
Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
|
||||
`info.json` metadata.
|
||||
|
||||
### Uploading a new dataset
|
||||
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
|
||||
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
|
||||
dataset with the current `CODEBASE_VERSION`.
|
||||
|
||||
### Updating an existing dataset
|
||||
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
|
||||
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
|
||||
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
|
||||
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
|
||||
codebase won't be affected by your change and backward compatibility is maintained.
|
||||
|
||||
However, you will need to update the version of ALL the other datasets so that they have the new
|
||||
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
|
||||
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
|
||||
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
|
||||
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
|
||||
api = HfApi()
|
||||
|
||||
for repo_id in available_datasets:
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if CODEBASE_VERSION in branches:
|
||||
print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
|
||||
continue
|
||||
else:
|
||||
# Now create a branch named after the new version by branching out from "main"
|
||||
# which is expected to be the preceding version
|
||||
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
|
||||
print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
|
||||
```
|
||||
@@ -37,16 +37,10 @@ def check_chunks_compatible(chunks: tuple, shape: tuple):
|
||||
assert c > 0
|
||||
|
||||
|
||||
def rechunk_recompress_array(
|
||||
group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"
|
||||
):
|
||||
def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
|
||||
old_arr = group[name]
|
||||
if chunks is None:
|
||||
chunks = (
|
||||
(chunk_length,) + old_arr.chunks[1:]
|
||||
if chunk_length is not None
|
||||
else old_arr.chunks
|
||||
)
|
||||
chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
|
||||
check_chunks_compatible(chunks, old_arr.shape)
|
||||
|
||||
if compressor is None:
|
||||
@@ -88,18 +82,13 @@ def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=No
|
||||
for i in range(len(shape) - 1):
|
||||
this_chunk_bytes = itemsize * np.prod(rshape[:i])
|
||||
next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
|
||||
if (
|
||||
this_chunk_bytes <= target_chunk_bytes
|
||||
and next_chunk_bytes > target_chunk_bytes
|
||||
):
|
||||
if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
|
||||
split_idx = i
|
||||
|
||||
rchunks = rshape[:split_idx]
|
||||
item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
|
||||
this_max_chunk_length = rshape[split_idx]
|
||||
next_chunk_length = min(
|
||||
this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes)
|
||||
)
|
||||
next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
|
||||
rchunks.append(next_chunk_length)
|
||||
len_diff = len(shape) - len(rchunks)
|
||||
rchunks.extend([1] * len_diff)
|
||||
@@ -135,13 +124,7 @@ class ReplayBuffer:
|
||||
root.require_group("data", overwrite=False)
|
||||
meta = root.require_group("meta", overwrite=False)
|
||||
if "episode_ends" not in meta:
|
||||
meta.zeros(
|
||||
"episode_ends",
|
||||
shape=(0,),
|
||||
dtype=np.int64,
|
||||
compressor=None,
|
||||
overwrite=False,
|
||||
)
|
||||
meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
|
||||
return cls(root=root)
|
||||
|
||||
@classmethod
|
||||
@@ -210,11 +193,7 @@ class ReplayBuffer:
|
||||
root = zarr.group(store=store)
|
||||
# copy without recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=src_store,
|
||||
dest=store,
|
||||
source_path="/meta",
|
||||
dest_path="/meta",
|
||||
if_exists=if_exists,
|
||||
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
||||
)
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
if keys is None:
|
||||
@@ -222,9 +201,7 @@ class ReplayBuffer:
|
||||
for key in keys:
|
||||
value = src_root["data"][key]
|
||||
cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = cls._resolve_array_compressor(
|
||||
compressors=compressors, key=key, array=value
|
||||
)
|
||||
cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
@@ -309,17 +286,13 @@ class ReplayBuffer:
|
||||
meta_group = root.create_group("meta", overwrite=True)
|
||||
# save meta, no chunking
|
||||
for key, value in self.root["meta"].items():
|
||||
_ = meta_group.array(
|
||||
name=key, data=value, shape=value.shape, chunks=value.shape
|
||||
)
|
||||
_ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)
|
||||
|
||||
# save data, chunk
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
for key, value in self.root["data"].items():
|
||||
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(
|
||||
compressors=compressors, key=key, array=value
|
||||
)
|
||||
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
if isinstance(value, zarr.Array):
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
@@ -366,19 +339,13 @@ class ReplayBuffer:
|
||||
@staticmethod
|
||||
def resolve_compressor(compressor="default"):
|
||||
if compressor == "default":
|
||||
compressor = numcodecs.Blosc(
|
||||
cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE
|
||||
)
|
||||
compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
|
||||
elif compressor == "disk":
|
||||
compressor = numcodecs.Blosc(
|
||||
"zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE
|
||||
)
|
||||
compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
|
||||
return compressor
|
||||
|
||||
@classmethod
|
||||
def _resolve_array_compressor(
|
||||
cls, compressors: dict | str | numcodecs.abc.Codec, key, array
|
||||
):
|
||||
def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
|
||||
# allows compressor to be explicitly set to None
|
||||
cpr = "nil"
|
||||
if isinstance(compressors, dict):
|
||||
@@ -437,11 +404,7 @@ class ReplayBuffer:
|
||||
if self.backend == "zarr":
|
||||
for key, value in np_data.items():
|
||||
_ = meta_group.array(
|
||||
name=key,
|
||||
data=value,
|
||||
shape=value.shape,
|
||||
chunks=value.shape,
|
||||
overwrite=True,
|
||||
name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
|
||||
)
|
||||
else:
|
||||
meta_group.update(np_data)
|
||||
@@ -551,18 +514,10 @@ class ReplayBuffer:
|
||||
# create array
|
||||
if key not in self.data:
|
||||
if is_zarr:
|
||||
cks = self._resolve_array_chunks(
|
||||
chunks=chunks, key=key, array=value
|
||||
)
|
||||
cpr = self._resolve_array_compressor(
|
||||
compressors=compressors, key=key, array=value
|
||||
)
|
||||
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
arr = self.data.zeros(
|
||||
name=key,
|
||||
shape=new_shape,
|
||||
chunks=cks,
|
||||
dtype=value.dtype,
|
||||
compressor=cpr,
|
||||
name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
|
||||
)
|
||||
else:
|
||||
# copy data to prevent modify
|
||||
@@ -589,9 +544,7 @@ class ReplayBuffer:
|
||||
|
||||
# rechunk
|
||||
if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
|
||||
rechunk_recompress_array(
|
||||
self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5)
|
||||
)
|
||||
rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
|
||||
|
||||
def drop_episode(self):
|
||||
is_zarr = self.backend == "zarr"
|
||||
|
||||
@@ -152,7 +152,7 @@ def download_raw(raw_dir: Path, repo_id: str):
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
# Send warning if raw_dir isn't well formated
|
||||
# Send warning if raw_dir isn't well formatted
|
||||
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that
|
||||
|
||||
@@ -38,9 +38,7 @@ import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import (
|
||||
AVAILABLE_RAW_REPO_IDS,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
|
||||
|
||||
@@ -75,9 +73,7 @@ def encode_datasets(
|
||||
check_repo_id(raw_repo_id)
|
||||
dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
|
||||
dataset_raw_dir = raw_dir / raw_repo_id
|
||||
dataset_dir = (
|
||||
local_dir / dataset_repo_id_push if local_dir is not None else None
|
||||
)
|
||||
dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
|
||||
encoding = {
|
||||
"vcodec": vcodec,
|
||||
"pix_fmt": pix_fmt,
|
||||
|
||||
@@ -133,9 +133,7 @@ class Jpeg2k(Codec):
|
||||
)
|
||||
|
||||
def decode(self, buf, out=None):
|
||||
return imagecodecs.jpeg2k_decode(
|
||||
buf, verbose=self.verbose, numthreads=self.numthreads, out=out
|
||||
)
|
||||
return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
|
||||
|
||||
|
||||
class JpegXl(Codec):
|
||||
|
||||
@@ -44,9 +44,7 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
def get_cameras(hdf5_data):
|
||||
# ignore depth channel, not currently handled
|
||||
# TODO(rcadene): add depth
|
||||
rgb_cameras = [
|
||||
key for key in hdf5_data["/observations/images"].keys() if "depth" not in key
|
||||
] # noqa: SIM118
|
||||
rgb_cameras = [key for key in hdf5_data["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
|
||||
return rgb_cameras
|
||||
|
||||
|
||||
@@ -75,9 +73,7 @@ def check_format(raw_dir) -> bool:
|
||||
else:
|
||||
assert data[f"/observations/images/{camera}"].ndim == 4
|
||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -138,17 +134,14 @@ def load_from_raw(
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(
|
||||
tmp_imgs_dir, video_path, fps, **(encoding or {})
|
||||
)
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
@@ -188,18 +181,15 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
|
||||
@@ -26,9 +26,7 @@ import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
@@ -44,19 +42,11 @@ def check_format(raw_dir) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
):
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
raise ValueError(
|
||||
f"Missing reference files for camera, starting with in '{raw_dir}'"
|
||||
)
|
||||
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
|
||||
# select first camera in alphanumeric order
|
||||
reference_key = sorted(reference_files)[0].stem
|
||||
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
|
||||
@@ -78,11 +68,11 @@ def load_from_raw(
|
||||
modality_df,
|
||||
on="timestamp_utc",
|
||||
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
|
||||
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||
# matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
|
||||
# are too far appart.
|
||||
# are too far apart.
|
||||
direction="nearest",
|
||||
tolerance=pd.Timedelta(f"{1/fps} seconds"),
|
||||
tolerance=pd.Timedelta(f"{1 / fps} seconds"),
|
||||
)
|
||||
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
|
||||
df = df[df["episode_index"] != -1]
|
||||
@@ -117,9 +107,7 @@ def load_from_raw(
|
||||
|
||||
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
|
||||
# each episode starts with timestamp 0 to match the ones from the video
|
||||
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(
|
||||
lambda x: x - x.iloc[0]
|
||||
)
|
||||
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
|
||||
|
||||
del df["timestamp_utc"]
|
||||
|
||||
@@ -132,15 +120,13 @@ def load_from_raw(
|
||||
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
||||
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
||||
if ep_ids != expected_ep_ids:
|
||||
raise ValueError(
|
||||
f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
|
||||
)
|
||||
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
|
||||
|
||||
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
||||
videos_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir.symlink_to((raw_dir / "videos").absolute())
|
||||
|
||||
# sanity check the video paths are well formated
|
||||
# sanity check the video paths are well formatted
|
||||
for key in df:
|
||||
if "observation.images." not in key:
|
||||
continue
|
||||
@@ -157,7 +143,7 @@ def load_from_raw(
|
||||
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
|
||||
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
|
||||
|
||||
# sanity check the video path is well formated
|
||||
# sanity check the video path is well formatted
|
||||
video_path = videos_dir.parent / data_dict[key][0]["path"]
|
||||
if not video_path.exists():
|
||||
raise ValueError(f"Video file not found in {video_path}")
|
||||
@@ -166,9 +152,7 @@ def load_from_raw(
|
||||
data_dict[key] = torch.from_numpy(df[key].values)
|
||||
# is vector
|
||||
elif df[key].iloc[0].shape[0] > 1:
|
||||
data_dict[key] = torch.stack(
|
||||
[torch.from_numpy(x.copy()) for x in df[key].values]
|
||||
)
|
||||
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
|
||||
else:
|
||||
raise ValueError(key)
|
||||
|
||||
@@ -186,18 +170,15 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
For all datasets in the RLDS format.
|
||||
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
|
||||
|
||||
NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
|
||||
NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
|
||||
|
||||
Example:
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
@@ -143,11 +143,7 @@ def load_from_raw(
|
||||
else:
|
||||
state_keys.append(key)
|
||||
|
||||
lang_key = (
|
||||
"language_instruction"
|
||||
if "language_instruction" in dataset.element_spec
|
||||
else None
|
||||
)
|
||||
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
|
||||
|
||||
print(" - image_keys: ", image_keys)
|
||||
print(" - lang_key: ", lang_key)
|
||||
@@ -206,9 +202,7 @@ def load_from_raw(
|
||||
|
||||
# If lang_key is present, convert the entire tensor at once
|
||||
if lang_key is not None:
|
||||
ep_dict["language_instruction"] = [
|
||||
x.numpy().decode("utf-8") for x in episode[lang_key]
|
||||
]
|
||||
ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]
|
||||
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
@@ -240,8 +234,7 @@ def load_from_raw(
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
@@ -266,9 +259,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
for key in data_dict:
|
||||
# check if vector state obs
|
||||
if key.startswith("observation.") and "observation.images." not in key:
|
||||
features[key] = Sequence(
|
||||
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
|
||||
# check if image obs
|
||||
elif "observation.images." in key:
|
||||
if video:
|
||||
|
||||
@@ -56,9 +56,7 @@ def check_format(raw_dir):
|
||||
|
||||
required_datasets.remove("meta/episode_ends")
|
||||
|
||||
assert all(
|
||||
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
|
||||
)
|
||||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -78,9 +76,7 @@ def load_from_raw(
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
|
||||
)
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
# as define in gmy-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
|
||||
success_threshold = 0.95 # 95% coverage,
|
||||
@@ -154,9 +150,7 @@ def load_from_raw(
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body, block_shapes = PushTEnv.add_tee(
|
||||
space, block_pos[i].tolist(), block_angle[i].item()
|
||||
)
|
||||
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
@@ -165,9 +159,7 @@ def load_from_raw(
|
||||
reward[i] = np.clip(coverage / success_threshold, 0, 1)
|
||||
success[i] = coverage > success_threshold
|
||||
if keypoints_instead_of_image:
|
||||
keypoints[i] = torch.from_numpy(
|
||||
PushTEnv.get_keypoints(block_shapes).flatten()
|
||||
)
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
@@ -192,8 +184,7 @@ def load_from_raw(
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
@@ -202,9 +193,7 @@ def load_from_raw(
|
||||
if keypoints_instead_of_image:
|
||||
ep_dict["observation.environment_state"] = keypoints
|
||||
ep_dict["action"] = actions[from_idx:to_idx]
|
||||
ep_dict["episode_index"] = torch.tensor(
|
||||
[ep_idx] * num_frames, dtype=torch.int64
|
||||
)
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
# ep_dict["next.observation.image"] = image[1:],
|
||||
@@ -231,8 +220,7 @@ def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if keypoints_instead_of_image:
|
||||
features["observation.environment_state"] = Sequence(
|
||||
@@ -273,9 +261,7 @@ def from_raw_to_lerobot_format(
|
||||
if fps is None:
|
||||
fps = 10
|
||||
|
||||
data_dict = load_from_raw(
|
||||
raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding
|
||||
)
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
|
||||
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
|
||||
@@ -26,9 +26,7 @@ from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import (
|
||||
register_codecs,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
@@ -63,9 +61,7 @@ def check_format(raw_dir) -> bool:
|
||||
nb_frames = zarr_data["data/camera0_rgb"].shape[0]
|
||||
|
||||
required_datasets.remove("meta/episode_ends")
|
||||
assert all(
|
||||
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
|
||||
)
|
||||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -83,9 +79,7 @@ def load_from_raw(
|
||||
end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
|
||||
start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
|
||||
eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
|
||||
eff_rot_axis_angle = torch.from_numpy(
|
||||
zarr_data["data/robot0_eef_rot_axis_angle"][:]
|
||||
)
|
||||
eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
|
||||
gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
|
||||
|
||||
states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
|
||||
@@ -135,31 +129,24 @@ def load_from_raw(
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
encode_video_frames(
|
||||
tmp_imgs_dir, video_path, fps, **(encoding or {})
|
||||
)
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["episode_index"] = torch.tensor(
|
||||
[ep_idx] * num_frames, dtype=torch.int64
|
||||
)
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor(
|
||||
[from_idx + num_frames] * num_frames
|
||||
)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
|
||||
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
|
||||
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
|
||||
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
|
||||
@@ -185,8 +172,7 @@ def to_hf_dataset(data_dict, video):
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
@@ -206,8 +192,7 @@ def to_hf_dataset(data_dict, video):
|
||||
length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["gripper_width"] = Sequence(
|
||||
length=data_dict["gripper_width"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
|
||||
@@ -45,9 +45,7 @@ def concatenate_episodes(ep_dicts):
|
||||
return data_dict
|
||||
|
||||
|
||||
def save_images_concurrently(
|
||||
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
|
||||
):
|
||||
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -57,10 +55,7 @@ def save_images_concurrently(
|
||||
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[
|
||||
executor.submit(save_image, imgs_array[i], i, out_dir)
|
||||
for i in range(num_images)
|
||||
]
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
|
||||
@@ -69,8 +64,7 @@ def get_default_encoding() -> dict:
|
||||
return {
|
||||
k: v.default
|
||||
for k, v in signature.parameters.items()
|
||||
if v.default is not inspect.Parameter.empty
|
||||
and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
}
|
||||
|
||||
|
||||
@@ -83,9 +77,7 @@ def check_repo_id(repo_id: str) -> None:
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(
|
||||
hf_dataset: datasets.Dataset,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
|
||||
@@ -40,10 +40,7 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
def check_format(raw_dir):
|
||||
keys = {"actions", "rewards", "dones"}
|
||||
nested_keys = {
|
||||
"observations": {"rgb", "state"},
|
||||
"next_observations": {"rgb", "state"},
|
||||
}
|
||||
nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}
|
||||
|
||||
xarm_files = list(raw_dir.glob("*.pkl"))
|
||||
assert len(xarm_files) > 0
|
||||
@@ -56,17 +53,11 @@ def check_format(raw_dir):
|
||||
|
||||
# Check for consistent lengths in nested keys
|
||||
expected_len = len(dataset_dict["actions"])
|
||||
assert all(
|
||||
len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict
|
||||
)
|
||||
assert all(len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict)
|
||||
|
||||
for key, subkeys in nested_keys.items():
|
||||
nested_dict = dataset_dict.get(key, {})
|
||||
assert all(
|
||||
len(nested_dict[subkey]) == expected_len
|
||||
for subkey in subkeys
|
||||
if subkey in nested_dict
|
||||
)
|
||||
assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -131,18 +122,13 @@ def load_from_raw(
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
]
|
||||
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["action"] = action
|
||||
ep_dict["episode_index"] = torch.tensor(
|
||||
[ep_idx] * num_frames, dtype=torch.int64
|
||||
)
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
# ep_dict["next.observation.image"] = next_image
|
||||
@@ -167,8 +153,7 @@ def to_hf_dataset(data_dict, video):
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
|
||||
@@ -43,10 +43,7 @@ class EpisodeAwareSampler:
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
indices.extend(
|
||||
range(
|
||||
start_index.item() + drop_n_first_frames,
|
||||
end_index.item() - drop_n_last_frames,
|
||||
)
|
||||
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
||||
)
|
||||
|
||||
self.indices = indices
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
from typing import Any, Callable, Dict, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
@@ -57,9 +58,7 @@ class RandomSubsetApply(Transform):
|
||||
elif not isinstance(n_subset, int):
|
||||
raise TypeError("n_subset should be an int or None")
|
||||
elif not (1 <= n_subset <= len(transforms)):
|
||||
raise ValueError(
|
||||
f"n_subset should be in the interval [1, {len(transforms)}]"
|
||||
)
|
||||
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
|
||||
|
||||
self.transforms = transforms
|
||||
total = sum(p)
|
||||
@@ -67,6 +66,8 @@ class RandomSubsetApply(Transform):
|
||||
self.n_subset = n_subset
|
||||
self.random_order = random_order
|
||||
|
||||
self.selected_transforms = None
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
needs_unpacking = len(inputs) > 1
|
||||
|
||||
@@ -74,9 +75,9 @@ class RandomSubsetApply(Transform):
|
||||
if not self.random_order:
|
||||
selected_indices = selected_indices.sort().values
|
||||
|
||||
selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
self.selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
|
||||
for transform in selected_transforms:
|
||||
for transform in self.selected_transforms:
|
||||
outputs = transform(*inputs)
|
||||
inputs = outputs if needs_unpacking else (outputs,)
|
||||
|
||||
@@ -118,121 +119,131 @@ class SharpnessJitter(Transform):
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError(
|
||||
"If sharpness is a single number, it must be non negative."
|
||||
)
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
sharpness[0] = max(sharpness[0], 0.0)
|
||||
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
|
||||
sharpness = [float(v) for v in sharpness]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"{sharpness=} should be a single number or a sequence with length 2."
|
||||
)
|
||||
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
|
||||
|
||||
if not 0.0 <= sharpness[0] <= sharpness[1]:
|
||||
raise ValueError(
|
||||
f"sharpnesss values should be between (0., inf), but got {sharpness}."
|
||||
)
|
||||
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
def _generate_value(self, left: float, right: float) -> float:
|
||||
return torch.empty(1).uniform_(left, right).item()
|
||||
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
|
||||
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
|
||||
return {"sharpness_factor": sharpness_factor}
|
||||
|
||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
|
||||
return self._call_kernel(
|
||||
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
|
||||
)
|
||||
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
|
||||
sharpness_factor = params["sharpness_factor"]
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def get_image_transforms(
|
||||
brightness_weight: float = 1.0,
|
||||
brightness_min_max: tuple[float, float] | None = None,
|
||||
contrast_weight: float = 1.0,
|
||||
contrast_min_max: tuple[float, float] | None = None,
|
||||
saturation_weight: float = 1.0,
|
||||
saturation_min_max: tuple[float, float] | None = None,
|
||||
hue_weight: float = 1.0,
|
||||
hue_min_max: tuple[float, float] | None = None,
|
||||
sharpness_weight: float = 1.0,
|
||||
sharpness_min_max: tuple[float, float] | None = None,
|
||||
max_num_transforms: int | None = None,
|
||||
random_order: bool = False,
|
||||
interpolation: str | None = None,
|
||||
image_size: tuple[int, int] | None = None,
|
||||
image_mean: list[float] | None = None,
|
||||
image_std: list[float] | None = None,
|
||||
):
|
||||
def check_value(name, weight, min_max):
|
||||
if min_max is not None:
|
||||
if len(min_max) != 2:
|
||||
raise ValueError(
|
||||
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
|
||||
)
|
||||
if weight < 0.0:
|
||||
raise ValueError(
|
||||
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
|
||||
)
|
||||
@dataclass
|
||||
class ImageTransformConfig:
|
||||
"""
|
||||
For each transform, the following parameters are available:
|
||||
weight: This represents the multinomial probability (with no replacement)
|
||||
used for sampling the transform. If the sum of the weights is not 1,
|
||||
they will be normalized.
|
||||
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
|
||||
custom transform defined here.
|
||||
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
|
||||
(following uniform distribution) when it's applied.
|
||||
"""
|
||||
|
||||
check_value("brightness", brightness_weight, brightness_min_max)
|
||||
check_value("contrast", contrast_weight, contrast_min_max)
|
||||
check_value("saturation", saturation_weight, saturation_min_max)
|
||||
check_value("hue", hue_weight, hue_min_max)
|
||||
check_value("sharpness", sharpness_weight, sharpness_min_max)
|
||||
weight: float = 1.0
|
||||
type: str = "Identity"
|
||||
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
weights = []
|
||||
transforms = []
|
||||
if image_size is not None:
|
||||
interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
|
||||
if interpolation is None:
|
||||
# Use BICUBIC as default interpolation
|
||||
interpolation_mode = v2.InterpolationMode.BICUBIC
|
||||
elif interpolation in interpolations:
|
||||
interpolation_mode = v2.InterpolationMode(interpolation)
|
||||
else:
|
||||
raise ValueError("The interpolation passed is not supported")
|
||||
# Weight for resizing is always 1
|
||||
weights.append(1.0)
|
||||
transforms.append(
|
||||
v2.Resize(
|
||||
size=(image_size[0], image_size[1]), interpolation=interpolation_mode
|
||||
)
|
||||
)
|
||||
if brightness_min_max is not None and brightness_weight > 0.0:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
if contrast_min_max is not None and contrast_weight > 0.0:
|
||||
weights.append(contrast_weight)
|
||||
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
|
||||
if saturation_min_max is not None and saturation_weight > 0.0:
|
||||
weights.append(saturation_weight)
|
||||
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
|
||||
if hue_min_max is not None and hue_weight > 0.0:
|
||||
weights.append(hue_weight)
|
||||
transforms.append(v2.ColorJitter(hue=hue_min_max))
|
||||
if sharpness_min_max is not None and sharpness_weight > 0.0:
|
||||
weights.append(sharpness_weight)
|
||||
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
|
||||
if image_mean is not None and image_std is not None:
|
||||
# Weight for normalization is always 1
|
||||
weights.append(1.0)
|
||||
transforms.append(
|
||||
v2.Normalize(
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
)
|
||||
)
|
||||
|
||||
n_subset = len(transforms)
|
||||
if max_num_transforms is not None:
|
||||
n_subset = min(n_subset, max_num_transforms)
|
||||
@dataclass
|
||||
class ImageTransformsConfig:
|
||||
"""
|
||||
These transforms are all using standard torchvision.transforms.v2
|
||||
You can find out how these transformations affect images here:
|
||||
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
|
||||
We use a custom RandomSubsetApply container to sample them.
|
||||
"""
|
||||
|
||||
if n_subset == 0:
|
||||
return v2.Identity()
|
||||
# Set this flag to `true` to enable transforms during training
|
||||
enable: bool = False
|
||||
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
|
||||
# It's an integer in the interval [1, number_of_available_transforms].
|
||||
max_num_transforms: int = 3
|
||||
# By default, transforms are applied in Torchvision's suggested order (shown below).
|
||||
# Set this to True to apply them in a random order.
|
||||
random_order: bool = False
|
||||
tfs: dict[str, ImageTransformConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"brightness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"brightness": (0.8, 1.2)},
|
||||
),
|
||||
"contrast": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"contrast": (0.8, 1.2)},
|
||||
),
|
||||
"saturation": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"saturation": (0.5, 1.5)},
|
||||
),
|
||||
"hue": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"hue": (-0.05, 0.05)},
|
||||
),
|
||||
"sharpness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 1.5)},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
else:
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return RandomSubsetApply(
|
||||
transforms, p=weights, n_subset=n_subset, random_order=random_order
|
||||
)
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
class ImageTransforms(Transform):
|
||||
"""A class to compose image transforms based on configuration."""
|
||||
|
||||
def __init__(self, cfg: ImageTransformsConfig) -> None:
|
||||
super().__init__()
|
||||
self._cfg = cfg
|
||||
|
||||
self.weights = []
|
||||
self.transforms = {}
|
||||
for tf_name, tf_cfg in cfg.tfs.items():
|
||||
if tf_cfg.weight <= 0.0:
|
||||
continue
|
||||
|
||||
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
|
||||
self.weights.append(tf_cfg.weight)
|
||||
|
||||
n_subset = min(len(self.transforms), cfg.max_num_transforms)
|
||||
if n_subset == 0 or not cfg.enable:
|
||||
self.tf = v2.Identity()
|
||||
else:
|
||||
self.tf = RandomSubsetApply(
|
||||
transforms=list(self.transforms.values()),
|
||||
p=self.weights,
|
||||
n_subset=n_subset,
|
||||
random_order=cfg.random_order,
|
||||
)
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
return self.tf(*inputs)
|
||||
|
||||
@@ -13,10 +13,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from collections.abc import Iterator
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
@@ -27,31 +27,34 @@ from typing import Any
|
||||
import datasets
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import pyarrow.compute as pc
|
||||
import packaging.version
|
||||
import torch
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.datasets.backward_compatibility import (
|
||||
V21_MESSAGE,
|
||||
BackwardCompatibilityError,
|
||||
ForwardCompatibilityError,
|
||||
)
|
||||
from lerobot.common.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
EPISODES_PATH = "meta/episodes.jsonl"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = (
|
||||
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
)
|
||||
DEFAULT_PARQUET_PATH = (
|
||||
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
)
|
||||
DEFAULT_IMAGE_PATH = (
|
||||
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
)
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
@@ -104,20 +107,39 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
return outdict
|
||||
|
||||
|
||||
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
||||
split_keys = flattened_key.split(sep)
|
||||
getter = obj[split_keys[0]]
|
||||
if len(split_keys) == 1:
|
||||
return getter
|
||||
|
||||
for key in split_keys[1:]:
|
||||
getter = getter[key]
|
||||
|
||||
return getter
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
serialized_dict = {
|
||||
key: value.tolist() for key, value in flatten_dict(stats).items()
|
||||
}
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
serialized_dict[key] = value.tolist()
|
||||
elif isinstance(value, np.generic):
|
||||
serialized_dict[key] = value.item()
|
||||
elif isinstance(value, (int, float)):
|
||||
serialized_dict[key] = value
|
||||
else:
|
||||
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
||||
return unflatten_dict(serialized_dict)
|
||||
|
||||
|
||||
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
|
||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
dataset = dataset.map(embed_table_storage, batched=False)
|
||||
dataset = dataset.with_format(**format)
|
||||
dataset.to_parquet(fpath)
|
||||
return dataset
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
@@ -148,6 +170,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path):
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
@@ -155,34 +181,76 @@ def load_info(local_dir: Path) -> dict:
|
||||
return info
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
def write_stats(stats: dict, local_dir: Path):
|
||||
serialized_stats = serialize_dict(stats)
|
||||
write_json(serialized_stats, local_dir / STATS_PATH)
|
||||
|
||||
|
||||
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
return {
|
||||
item["task_index"]: item["task"]
|
||||
for item in sorted(tasks, key=lambda x: x["task_index"])
|
||||
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
return cast_stats_to_numpy(stats)
|
||||
|
||||
|
||||
def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, local_dir / TASKS_PATH)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
return tasks, task_to_task_index
|
||||
|
||||
|
||||
def write_episode(episode: dict, local_dir: Path):
|
||||
append_jsonlines(episode, local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
||||
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
|
||||
|
||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
||||
# is a dictionary of stats and not an integer.
|
||||
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
||||
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
||||
|
||||
|
||||
def load_episodes_stats(local_dir: Path) -> dict:
|
||||
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
|
||||
return {
|
||||
item["episode_index"]: cast_stats_to_numpy(item["stats"])
|
||||
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
|
||||
}
|
||||
|
||||
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
return {ep_idx: stats for ep_idx in episodes}
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype="float32", channel_first: bool = True
|
||||
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
img_array = np.transpose(img_array, (2, 0, 1))
|
||||
if "float" in dtype:
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
img_array /= 255.0
|
||||
return img_array
|
||||
|
||||
@@ -201,80 +269,95 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
def _get_major_minor(version: str) -> tuple[int]:
|
||||
split = version.strip("v").split(".")
|
||||
return int(split[0]), int(split[1])
|
||||
|
||||
|
||||
class BackwardCompatibilityError(Exception):
|
||||
def __init__(self, repo_id, version):
|
||||
message = textwrap.dedent(f"""
|
||||
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
|
||||
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
|
||||
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
|
||||
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
""")
|
||||
super().__init__(message)
|
||||
def is_valid_version(version: str) -> bool:
|
||||
try:
|
||||
packaging.version.parse(version)
|
||||
return True
|
||||
except packaging.version.InvalidVersion:
|
||||
return False
|
||||
|
||||
|
||||
def check_version_compatibility(
|
||||
repo_id: str,
|
||||
version_to_check: str,
|
||||
current_version: str,
|
||||
version_to_check: str | packaging.version.Version,
|
||||
current_version: str | packaging.version.Version,
|
||||
enforce_breaking_major: bool = True,
|
||||
) -> None:
|
||||
current_major, _ = _get_major_minor(current_version)
|
||||
major_to_check, _ = _get_major_minor(version_to_check)
|
||||
if major_to_check < current_major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, version_to_check)
|
||||
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
|
||||
logging.warning(
|
||||
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
|
||||
codebase. The current codebase version is {current_version}. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
)
|
||||
v_check = (
|
||||
packaging.version.parse(version_to_check)
|
||||
if not isinstance(version_to_check, packaging.version.Version)
|
||||
else version_to_check
|
||||
)
|
||||
v_current = (
|
||||
packaging.version.parse(current_version)
|
||||
if not isinstance(current_version, packaging.version.Version)
|
||||
else current_version
|
||||
)
|
||||
if v_check.major < v_current.major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, v_check)
|
||||
elif v_check.minor < v_current.minor:
|
||||
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
||||
|
||||
|
||||
def get_hub_safe_version(repo_id: str, version: str) -> str:
|
||||
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||
"""Returns available valid versions (branches and tags) on given repo."""
|
||||
api = HfApi()
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if version not in branches:
|
||||
num_version = float(version.strip("v"))
|
||||
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
|
||||
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
|
||||
raise BackwardCompatibilityError(repo_id, version)
|
||||
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
|
||||
repo_versions = []
|
||||
for ref in repo_refs:
|
||||
with contextlib.suppress(packaging.version.InvalidVersion):
|
||||
repo_versions.append(packaging.version.parse(ref))
|
||||
|
||||
logging.warning(
|
||||
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
|
||||
codebase. The following versions are available: {branches}.
|
||||
The requested version ('{version}') is not found. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
return repo_versions
|
||||
|
||||
|
||||
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
|
||||
"""
|
||||
Returns the version if available on repo or the latest compatible one.
|
||||
Otherwise, will throw a `CompatibilityError`.
|
||||
"""
|
||||
target_version = (
|
||||
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
|
||||
)
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
if not hub_versions:
|
||||
raise RevisionNotFoundError(
|
||||
f"""Your dataset must be tagged with a codebase version.
|
||||
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||
```
|
||||
"""
|
||||
)
|
||||
if "main" not in branches:
|
||||
raise ValueError(f"Version 'main' not found on {repo_id}")
|
||||
return "main"
|
||||
else:
|
||||
return version
|
||||
|
||||
if target_version in hub_versions:
|
||||
return f"v{target_version}"
|
||||
|
||||
compatibles = [
|
||||
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
|
||||
]
|
||||
if compatibles:
|
||||
return_version = max(compatibles)
|
||||
if return_version < target_version:
|
||||
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
||||
return f"v{return_version}"
|
||||
|
||||
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
||||
if lower_major:
|
||||
raise BackwardCompatibilityError(repo_id, max(lower_major))
|
||||
|
||||
upper_versions = [v for v in hub_versions if v > target_version]
|
||||
assert len(upper_versions) > 0
|
||||
raise ForwardCompatibilityError(repo_id, min(upper_versions))
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
@@ -286,12 +369,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
else:
|
||||
assert len(ft["shape"]) == 1
|
||||
elif len(ft["shape"]) == 1:
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
# TODO: (alibers, azouitine) Add support for ft["shap"] == 0 as Value
|
||||
elif len(ft["shape"]) == 2:
|
||||
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 3:
|
||||
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 4:
|
||||
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 5:
|
||||
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
else:
|
||||
raise ValueError(f"Corresponding feature is not valid: {ft}")
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
@@ -306,6 +397,37 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
|
||||
|
||||
|
||||
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
||||
policy_features = {}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
if ft["dtype"] in ["image", "video"]:
|
||||
type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == "observation.environment_state":
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith("observation"):
|
||||
type = FeatureType.STATE
|
||||
elif key == "action":
|
||||
type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
|
||||
policy_features[key] = PolicyFeature(
|
||||
type=type,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
return policy_features
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
@@ -331,94 +453,85 @@ def create_empty_dataset_info(
|
||||
|
||||
|
||||
def get_episode_data_index(
|
||||
episode_dicts: list[dict], episodes: list[int] | None = None
|
||||
episode_dicts: dict[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {
|
||||
ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)
|
||||
}
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths.values()))
|
||||
cumulative_lengths = list(accumulate(episode_lengths.values()))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
|
||||
|
||||
def calculate_total_episode(
|
||||
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_indices = sorted(hf_dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
|
||||
raise ValueError("episode_index values are not sorted and contiguous.")
|
||||
return total_episodes
|
||||
|
||||
|
||||
def calculate_episode_data_index(
|
||||
hf_dataset: datasets.Dataset,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = []
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lengths),
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
timestamps: np.ndarray,
|
||||
episode_indices: np.ndarray,
|
||||
episode_data_index: dict[str, np.ndarray],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
|
||||
account for possible numerical error.
|
||||
"""
|
||||
timestamps = torch.stack(hf_dataset["timestamp"])
|
||||
diffs = torch.diff(timestamps)
|
||||
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
|
||||
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
|
||||
to account for possible numerical error.
|
||||
|
||||
# We mask differences between the timestamp at the end of an episode
|
||||
# and the one at the start of the next episode since these are expected
|
||||
# to be outside tolerance.
|
||||
mask = torch.ones(len(diffs), dtype=torch.bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1
|
||||
Args:
|
||||
timestamps (np.ndarray): Array of timestamps in seconds.
|
||||
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
|
||||
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
|
||||
which identifies indices for the end of each episode.
|
||||
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
|
||||
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
|
||||
raise_value_error (bool): Whether to raise a ValueError if the check fails.
|
||||
|
||||
Returns:
|
||||
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If the check fails and `raise_value_error` is True.
|
||||
"""
|
||||
if timestamps.shape != episode_indices.shape:
|
||||
raise ValueError(
|
||||
"timestamps and episode_indices should have the same shape. "
|
||||
f"Found {timestamps.shape=} and {episode_indices.shape=}."
|
||||
)
|
||||
|
||||
# Consecutive differences
|
||||
diffs = np.diff(timestamps)
|
||||
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
|
||||
|
||||
# Mask to ignore differences at the boundaries between episodes
|
||||
mask = np.ones(len(diffs), dtype=bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
|
||||
mask[ignored_diffs] = False
|
||||
filtered_within_tolerance = within_tolerance[mask]
|
||||
|
||||
if not torch.all(filtered_within_tolerance):
|
||||
# Check if all remaining diffs are within tolerance
|
||||
if not np.all(filtered_within_tolerance):
|
||||
# Track original indices before masking
|
||||
original_indices = torch.arange(len(diffs))
|
||||
original_indices = np.arange(len(diffs))
|
||||
filtered_indices = original_indices[mask]
|
||||
outside_tolerance_filtered_indices = torch.nonzero(
|
||||
~filtered_within_tolerance
|
||||
) # .squeeze()
|
||||
outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
|
||||
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
||||
episode_indices = torch.stack(hf_dataset["episode_index"])
|
||||
|
||||
outside_tolerances = []
|
||||
for idx in outside_tolerance_indices:
|
||||
entry = {
|
||||
"timestamps": [timestamps[idx], timestamps[idx + 1]],
|
||||
"diff": diffs[idx],
|
||||
"episode_index": episode_indices[idx].item(),
|
||||
"episode_index": episode_indices[idx].item()
|
||||
if hasattr(episode_indices[idx], "item")
|
||||
else episode_indices[idx],
|
||||
}
|
||||
outside_tolerances.append(entry)
|
||||
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
|
||||
This might be due to synchronization issues with timestamps during data collection.
|
||||
This might be due to synchronization issues during data collection.
|
||||
\n{pformat(outside_tolerances)}"""
|
||||
)
|
||||
return False
|
||||
@@ -427,10 +540,7 @@ def check_timestamps_sync(
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
||||
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
||||
@@ -438,14 +548,10 @@ def check_delta_timestamps(
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [
|
||||
abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
|
||||
]
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts
|
||||
for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
|
||||
if not is_within
|
||||
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
||||
]
|
||||
|
||||
if len(outside_tolerance) > 0:
|
||||
@@ -463,12 +569,10 @@ def check_delta_timestamps(
|
||||
return True
|
||||
|
||||
|
||||
def get_delta_indices(
|
||||
delta_timestamps: dict[str, list[float]], fps: int
|
||||
) -> dict[str, list[int]]:
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
|
||||
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
||||
|
||||
return delta_indices
|
||||
|
||||
@@ -530,9 +634,7 @@ def create_lerobot_dataset_card(
|
||||
],
|
||||
)
|
||||
|
||||
card_template = (
|
||||
importlib.resources.files("lerobot.common.datasets") / "card_template.md"
|
||||
).read_text()
|
||||
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
|
||||
|
||||
return DatasetCard.from_template(
|
||||
card_data=card_data,
|
||||
@@ -594,3 +696,118 @@ class IterableNamespace(SimpleNamespace):
|
||||
|
||||
def keys(self):
|
||||
return vars(self).keys()
|
||||
|
||||
|
||||
def validate_frame(frame: dict, features: dict):
|
||||
optional_features = {"timestamp"}
|
||||
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
|
||||
actual_features = set(frame.keys())
|
||||
|
||||
error_message = validate_features_presence(actual_features, expected_features, optional_features)
|
||||
|
||||
if "task" in frame:
|
||||
error_message += validate_feature_string("task", frame["task"])
|
||||
|
||||
common_features = actual_features & (expected_features | optional_features)
|
||||
for name in common_features - {"task"}:
|
||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def validate_features_presence(
|
||||
actual_features: set[str], expected_features: set[str], optional_features: set[str]
|
||||
):
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
extra_features = actual_features - (expected_features | optional_features)
|
||||
|
||||
if missing_features or extra_features:
|
||||
error_message += "Feature mismatch in `frame` dictionary:\n"
|
||||
if missing_features:
|
||||
error_message += f"Missing features: {missing_features}\n"
|
||||
if extra_features:
|
||||
error_message += f"Extra features: {extra_features}\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
def validate_feature_numpy_array(
|
||||
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
|
||||
):
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_dtype = value.dtype
|
||||
actual_shape = value.shape
|
||||
|
||||
if actual_dtype != np.dtype(expected_dtype):
|
||||
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
|
||||
|
||||
if actual_shape != expected_shape:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str):
|
||||
if not isinstance(value, str):
|
||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
|
||||
if "size" not in episode_buffer:
|
||||
raise ValueError("size key not found in episode_buffer")
|
||||
|
||||
if "task" not in episode_buffer:
|
||||
raise ValueError("task key not found in episode_buffer")
|
||||
|
||||
if episode_buffer["episode_index"] != total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes already in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
if episode_buffer["size"] == 0:
|
||||
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
|
||||
|
||||
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
|
||||
if not buffer_keys == set(features):
|
||||
raise ValueError(
|
||||
f"Features from `episode_buffer` don't match the ones in `features`."
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
|
||||
@@ -26,16 +26,14 @@ from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import (
|
||||
convert_dataset,
|
||||
parse_robot_config,
|
||||
)
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
|
||||
from lerobot.common.robots.aloha.configuration_aloha import AlohaRobotConfig
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
|
||||
# spellchecker:off
|
||||
ALOHA_MOBILE_INFO = {
|
||||
"robot_config": parse_robot_config(ALOHA_CONFIG),
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"license": "mit",
|
||||
"url": "https://mobile-aloha.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.02117",
|
||||
@@ -48,7 +46,7 @@ ALOHA_MOBILE_INFO = {
|
||||
}""").lstrip(),
|
||||
}
|
||||
ALOHA_STATIC_INFO = {
|
||||
"robot_config": parse_robot_config(ALOHA_CONFIG),
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"license": "mit",
|
||||
"url": "https://tonyzhaozh.github.io/aloha/",
|
||||
"paper": "https://arxiv.org/abs/2304.13705",
|
||||
@@ -120,10 +118,7 @@ DATASETS = {
|
||||
"single_task": "Place the battery into the slot of the remote controller.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_candy": {
|
||||
"single_task": "Pick up the candy and unwrap it.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
|
||||
"aloha_static_coffee": {
|
||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
@@ -172,22 +167,13 @@ DATASETS = {
|
||||
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {
|
||||
"single_task": "Slide open the ziploc bag.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_scripted": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_scripted_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_human_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
@@ -208,19 +194,10 @@ DATASETS = {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"pusht": {
|
||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||
**PUSHT_INFO,
|
||||
},
|
||||
"pusht_image": {
|
||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||
**PUSHT_INFO,
|
||||
},
|
||||
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
||||
"unitreeh1_rearrange_objects": {
|
||||
"single_task": "Put the object into the bin.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
|
||||
"unitreeh1_two_robot_greeting": {
|
||||
"single_task": "Greet the other robot with a high five.",
|
||||
**UNITREEH_INFO,
|
||||
@@ -230,31 +207,13 @@ DATASETS = {
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_image": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_replay": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_replay_image": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_image": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_replay": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_replay_image": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"umi_cup_in_the_wild": {
|
||||
"single_task": "Put the cup on the plate.",
|
||||
"license": "apache-2.0",
|
||||
@@ -898,6 +857,7 @@ DATASETS = {
|
||||
}""").lstrip(),
|
||||
},
|
||||
}
|
||||
# spellchecker:on
|
||||
|
||||
|
||||
def batch_convert():
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
|
||||
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
|
||||
for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
|
||||
for each of the task performed in the dataset. This will allow to easily train models with task-conditioning.
|
||||
|
||||
We support 3 different scenarios for these tasks (see instructions below):
|
||||
1. Single task dataset: all episodes of your dataset have the same single task.
|
||||
@@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
get_hub_safe_version,
|
||||
get_safe_version,
|
||||
load_json,
|
||||
unflatten_dict,
|
||||
write_json,
|
||||
@@ -141,7 +141,8 @@ from lerobot.common.datasets.video_utils import (
|
||||
get_image_pixel_channels,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.common.robots import RobotConfig
|
||||
from lerobot.common.robots.utils import make_robot_config
|
||||
|
||||
V16 = "v1.6"
|
||||
V20 = "v2.0"
|
||||
@@ -152,21 +153,18 @@ V1_INFO_PATH = "meta_data/info.json"
|
||||
V1_STATS_PATH = "meta_data/stats.safetensors"
|
||||
|
||||
|
||||
def parse_robot_config(
|
||||
config_path: Path, config_overrides: list[str] | None = None
|
||||
) -> tuple[str, dict]:
|
||||
robot_cfg = init_hydra_config(config_path, config_overrides)
|
||||
if robot_cfg["robot_type"] in ["aloha", "koch"]:
|
||||
def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
|
||||
if robot_cfg.type in ["aloha", "koch"]:
|
||||
state_names = [
|
||||
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
|
||||
for arm in robot_cfg["follower_arms"]
|
||||
for motor in robot_cfg["follower_arms"][arm]["motors"]
|
||||
f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
|
||||
for arm in robot_cfg.follower_arms
|
||||
for motor in robot_cfg.follower_arms[arm].motors
|
||||
]
|
||||
action_names = [
|
||||
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
|
||||
for arm in robot_cfg["leader_arms"]
|
||||
for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
|
||||
for arm in robot_cfg.leader_arms
|
||||
for motor in robot_cfg.leader_arms[arm].motors
|
||||
]
|
||||
# elif robot_cfg["robot_type"] == "stretch3": TODO
|
||||
else:
|
||||
@@ -175,7 +173,7 @@ def parse_robot_config(
|
||||
)
|
||||
|
||||
return {
|
||||
"robot_type": robot_cfg["robot_type"],
|
||||
"robot_type": robot_cfg.type,
|
||||
"names": {
|
||||
"observation.state": state_names,
|
||||
"observation.effort": state_names,
|
||||
@@ -206,8 +204,9 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
|
||||
|
||||
def get_features_from_hf_dataset(
|
||||
dataset: Dataset, robot_config: dict | None = None
|
||||
dataset: Dataset, robot_config: RobotConfig | None = None
|
||||
) -> dict[str, list]:
|
||||
robot_config = parse_robot_config(robot_config)
|
||||
features = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if isinstance(ft, datasets.Value):
|
||||
@@ -219,9 +218,7 @@ def get_features_from_hf_dataset(
|
||||
dtype = ft.feature.dtype
|
||||
shape = (ft.length,)
|
||||
motor_names = (
|
||||
robot_config["names"][key]
|
||||
if robot_config
|
||||
else [f"motor_{i}" for i in range(ft.length)]
|
||||
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
|
||||
)
|
||||
assert len(motor_names) == shape[0]
|
||||
names = {"motors": motor_names}
|
||||
@@ -230,11 +227,11 @@ def get_features_from_hf_dataset(
|
||||
image = dataset[0][key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
shape = (image.height, image.width, channels)
|
||||
names = ["height", "width", "channel"]
|
||||
names = ["height", "width", "channels"]
|
||||
elif ft._type == "VideoFrame":
|
||||
dtype = "video"
|
||||
shape = None # Add shape later
|
||||
names = ["height", "width", "channel"]
|
||||
names = ["height", "width", "channels"]
|
||||
|
||||
features[key] = {
|
||||
"dtype": dtype,
|
||||
@@ -245,15 +242,11 @@ def get_features_from_hf_dataset(
|
||||
return features
|
||||
|
||||
|
||||
def add_task_index_by_episodes(
|
||||
dataset: Dataset, tasks_by_episodes: dict
|
||||
) -> tuple[Dataset, list[str]]:
|
||||
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
|
||||
df = dataset.to_pandas()
|
||||
tasks = list(set(tasks_by_episodes.values()))
|
||||
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
|
||||
episodes_to_task_index = {
|
||||
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
|
||||
|
||||
features = dataset.features
|
||||
@@ -270,19 +263,10 @@ def add_task_index_from_tasks_col(
|
||||
# HACK: This is to clean some of the instructions in our version of Open X datasets
|
||||
prefix_to_clean = "tf.Tensor(b'"
|
||||
suffix_to_clean = "', shape=(), dtype=string)"
|
||||
df[tasks_col] = (
|
||||
df[tasks_col]
|
||||
.str.removeprefix(prefix_to_clean)
|
||||
.str.removesuffix(suffix_to_clean)
|
||||
)
|
||||
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
|
||||
|
||||
# Create task_index col
|
||||
tasks_by_episode = (
|
||||
df.groupby("episode_index")[tasks_col]
|
||||
.unique()
|
||||
.apply(lambda x: x.tolist())
|
||||
.to_dict()
|
||||
)
|
||||
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
|
||||
tasks = df[tasks_col].unique().tolist()
|
||||
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
|
||||
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
|
||||
@@ -307,9 +291,7 @@ def split_parquet_by_episodes(
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
|
||||
episode_chunk=ep_chunk
|
||||
)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
@@ -341,9 +323,7 @@ def move_videos(
|
||||
videos_moved = False
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
|
||||
if len(video_files) == 0:
|
||||
video_files = [
|
||||
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
|
||||
]
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
|
||||
videos_moved = True # Videos have already been moved
|
||||
|
||||
assert len(video_files) == total_episodes * len(video_keys)
|
||||
@@ -374,9 +354,7 @@ def move_videos(
|
||||
target_path = DEFAULT_VIDEO_PATH.format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
video_file = V1_VIDEO_FILE.format(
|
||||
video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
|
||||
if len(video_dirs) == 1:
|
||||
video_path = video_dirs[0] / video_file
|
||||
else:
|
||||
@@ -393,9 +371,7 @@ def move_videos(
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_lfs_video_files_tracking(
|
||||
work_dir: Path, lfs_untracked_videos: list[str]
|
||||
) -> None:
|
||||
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
|
||||
"""
|
||||
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
|
||||
there's no other option than to download the actual files and reupload them with lfs tracking.
|
||||
@@ -403,12 +379,7 @@ def fix_lfs_video_files_tracking(
|
||||
for i in range(0, len(lfs_untracked_videos), 100):
|
||||
files = lfs_untracked_videos[i : i + 100]
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "rm", "--cached", *files],
|
||||
cwd=work_dir,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
)
|
||||
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("git rm --cached ERROR:")
|
||||
print(e.stderr)
|
||||
@@ -419,14 +390,10 @@ def fix_lfs_video_files_tracking(
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_gitattributes(
|
||||
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
|
||||
) -> None:
|
||||
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
|
||||
shutil.copyfile(clean_gittatributes, current_gittatributes)
|
||||
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
|
||||
)
|
||||
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
@@ -435,17 +402,7 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
repo_url = f"https://huggingface.co/datasets/{repo_id}"
|
||||
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
|
||||
subprocess.run(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"--branch",
|
||||
branch,
|
||||
"--single-branch",
|
||||
"--depth",
|
||||
"1",
|
||||
repo_url,
|
||||
str(work_dir),
|
||||
],
|
||||
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
|
||||
check=True,
|
||||
env=env,
|
||||
)
|
||||
@@ -453,19 +410,13 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
|
||||
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
|
||||
lfs_tracked_files = subprocess.run(
|
||||
["git", "lfs", "ls-files", "-n"],
|
||||
cwd=work_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
|
||||
)
|
||||
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
|
||||
return [f for f in video_files if f not in lfs_tracked_files]
|
||||
|
||||
|
||||
def get_videos_info(
|
||||
repo_id: str, local_dir: Path, video_keys: list[str], branch: str
|
||||
) -> dict:
|
||||
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
|
||||
# Assumes first episode
|
||||
video_files = [
|
||||
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
|
||||
@@ -473,11 +424,7 @@ def get_videos_info(
|
||||
]
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
revision=branch,
|
||||
allow_patterns=video_files,
|
||||
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
|
||||
)
|
||||
videos_info_dict = {}
|
||||
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
||||
@@ -492,11 +439,11 @@ def convert_dataset(
|
||||
single_task: str | None = None,
|
||||
tasks_path: Path | None = None,
|
||||
tasks_col: Path | None = None,
|
||||
robot_config: dict | None = None,
|
||||
robot_config: RobotConfig | None = None,
|
||||
test_branch: str | None = None,
|
||||
**card_kwargs,
|
||||
):
|
||||
v1 = get_hub_safe_version(repo_id, V16)
|
||||
v1 = get_safe_version(repo_id, V16)
|
||||
v1x_dir = local_dir / V16 / repo_id
|
||||
v20_dir = local_dir / V20 / repo_id
|
||||
v1x_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -504,11 +451,7 @@ def convert_dataset(
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
revision=v1,
|
||||
local_dir=v1x_dir,
|
||||
ignore_patterns="videos*/",
|
||||
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
|
||||
)
|
||||
branch = "main"
|
||||
if test_branch:
|
||||
@@ -540,31 +483,19 @@ def convert_dataset(
|
||||
if single_task:
|
||||
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {
|
||||
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_path:
|
||||
tasks_by_episodes = load_json(tasks_path)
|
||||
tasks_by_episodes = {
|
||||
int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {
|
||||
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_col:
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
|
||||
dataset, tasks_col
|
||||
)
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
assert set(tasks) == {
|
||||
task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
|
||||
}
|
||||
tasks = [
|
||||
{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
|
||||
]
|
||||
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
||||
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
||||
features["task_index"] = {
|
||||
"dtype": "int64",
|
||||
@@ -578,25 +509,14 @@ def convert_dataset(
|
||||
dataset = dataset.remove_columns(video_keys)
|
||||
clean_gitattr = Path(
|
||||
hub_api.hf_hub_download(
|
||||
repo_id=GITATTRIBUTES_REF,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
filename=".gitattributes",
|
||||
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
|
||||
)
|
||||
).absolute()
|
||||
with tempfile.TemporaryDirectory() as tmp_video_dir:
|
||||
move_videos(
|
||||
repo_id,
|
||||
video_keys,
|
||||
total_episodes,
|
||||
total_chunks,
|
||||
Path(tmp_video_dir),
|
||||
clean_gitattr,
|
||||
branch,
|
||||
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
|
||||
)
|
||||
videos_info = get_videos_info(
|
||||
repo_id, v1x_dir, video_keys=video_keys, branch=branch
|
||||
)
|
||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
||||
for key in video_keys:
|
||||
features[key]["shape"] = (
|
||||
videos_info[key].pop("video.height"),
|
||||
@@ -604,25 +524,18 @@ def convert_dataset(
|
||||
videos_info[key].pop("video.channels"),
|
||||
)
|
||||
features[key]["video_info"] = videos_info[key]
|
||||
assert math.isclose(
|
||||
videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
|
||||
)
|
||||
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
|
||||
if "encoding" in metadata_v1:
|
||||
assert (
|
||||
videos_info[key]["video.pix_fmt"]
|
||||
== metadata_v1["encoding"]["pix_fmt"]
|
||||
)
|
||||
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
|
||||
else:
|
||||
assert metadata_v1.get("video", 0) == 0
|
||||
videos_info = None
|
||||
|
||||
# Split data into 1 parquet file by episode
|
||||
episode_lengths = split_parquet_by_episodes(
|
||||
dataset, total_episodes, total_chunks, v20_dir
|
||||
)
|
||||
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
|
||||
|
||||
if robot_config is not None:
|
||||
robot_type = robot_config["robot_type"]
|
||||
robot_type = robot_config.type
|
||||
repo_tags = [robot_type]
|
||||
else:
|
||||
robot_type = "unknown"
|
||||
@@ -630,11 +543,7 @@ def convert_dataset(
|
||||
|
||||
# Episodes
|
||||
episodes = [
|
||||
{
|
||||
"episode_index": ep_idx,
|
||||
"tasks": tasks_by_episodes[ep_idx],
|
||||
"length": episode_lengths[ep_idx],
|
||||
}
|
||||
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||
for ep_idx in episode_indices
|
||||
]
|
||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
||||
@@ -657,27 +566,16 @@ def convert_dataset(
|
||||
}
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
|
||||
)
|
||||
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
|
||||
)
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="meta_data",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
|
||||
)
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
||||
|
||||
hub_api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
@@ -726,16 +624,10 @@ def main():
|
||||
help="The path to a .json file containing one language instruction for each episode_index",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-config",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Path to the robot's config yaml the dataset during conversion.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-overrides",
|
||||
"--robot",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override the robot config values (use dots for.nested=overrides)",
|
||||
default=None,
|
||||
help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
@@ -760,12 +652,10 @@ def main():
|
||||
if not args.local_dir:
|
||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||
|
||||
robot_config = (
|
||||
parse_robot_config(args.robot_config, args.robot_overrides)
|
||||
if args.robot_config
|
||||
else None
|
||||
)
|
||||
del args.robot_config, args.robot_overrides
|
||||
if args.robot is not None:
|
||||
robot_config = make_robot_config(args.robot)
|
||||
|
||||
del args.robot
|
||||
|
||||
convert_dataset(**vars(args), robot_config=robot_config)
|
||||
|
||||
|
||||
87
lerobot/common/datasets/v21/_remove_language_instruction.py
Normal file
87
lerobot/common/datasets/v21/_remove_language_instruction.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import get_dataset_config_info
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import INFO_PATH, write_info
|
||||
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
hub_api = HfApi()
|
||||
|
||||
|
||||
def fix_dataset(repo_id: str) -> str:
|
||||
if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
|
||||
return f"{repo_id}: skipped (not in {V20})."
|
||||
|
||||
dataset_info = get_dataset_config_info(repo_id, "default")
|
||||
with SuppressWarnings():
|
||||
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
|
||||
parquet_features = set(dataset_info.features)
|
||||
|
||||
diff_parquet_meta = parquet_features - meta_features
|
||||
diff_meta_parquet = meta_features - parquet_features
|
||||
|
||||
if diff_parquet_meta:
|
||||
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
|
||||
|
||||
if not diff_meta_parquet:
|
||||
return f"{repo_id}: skipped (no diff)"
|
||||
|
||||
if diff_meta_parquet:
|
||||
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
|
||||
assert diff_meta_parquet == {"language_instruction"}
|
||||
lerobot_metadata.features.pop("language_instruction")
|
||||
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
||||
commit_info = hub_api.upload_file(
|
||||
path_or_fileobj=lerobot_metadata.root / INFO_PATH,
|
||||
path_in_repo=INFO_PATH,
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
revision=V20,
|
||||
commit_message="Remove 'language_instruction'",
|
||||
create_pr=True,
|
||||
)
|
||||
return f"{repo_id}: success - PR: {commit_info.pr_url}"
|
||||
|
||||
|
||||
def batch_fix():
|
||||
status = {}
|
||||
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logfile = LOCAL_DIR / "fix_features_v20.txt"
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
status = fix_dataset(repo_id)
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
|
||||
logging.info(status)
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_fix()
|
||||
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logfile = LOCAL_DIR / "conversion_log_v21.txt"
|
||||
hub_api = HfApi()
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
|
||||
status = f"{repo_id}: success (already in {V21})."
|
||||
else:
|
||||
convert_dataset(repo_id)
|
||||
status = f"{repo_id}: success."
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_convert()
|
||||
114
lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
Normal file
114
lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
|
||||
--repo-id=aliberts/koch_tutorial
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
|
||||
|
||||
class SuppressWarnings:
|
||||
def __enter__(self):
|
||||
self.previous_level = logging.getLogger().getEffectiveLevel()
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
logging.getLogger().setLevel(self.previous_level)
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to push your dataset. Defaults to the main branch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
99
lerobot/common/datasets/v21/convert_stats.py
Normal file
99
lerobot/common/datasets/v21/convert_stats.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import write_episode_stats
|
||||
|
||||
|
||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||
ep_len = dataset.meta.episodes[episode_index]["length"]
|
||||
sampled_indices = sample_indices(ep_len)
|
||||
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
|
||||
video_frames = dataset._query_videos(query_timestamps, episode_index)
|
||||
return video_frames[ft_key].numpy()
|
||||
|
||||
|
||||
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
|
||||
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
|
||||
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
|
||||
|
||||
ep_stats = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if ft["dtype"] == "video":
|
||||
# We sample only for videos
|
||||
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
|
||||
else:
|
||||
ep_ft_data = np.array(ep_data[key])
|
||||
|
||||
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
||||
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
||||
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||
|
||||
|
||||
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
||||
assert dataset.episodes is None
|
||||
print("Computing episodes stats")
|
||||
total_episodes = dataset.meta.total_episodes
|
||||
if num_workers > 0:
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
|
||||
for ep_idx in range(total_episodes)
|
||||
}
|
||||
for future in tqdm(as_completed(futures), total=total_episodes):
|
||||
future.result()
|
||||
else:
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
convert_episode_stats(dataset, ep_idx)
|
||||
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||
|
||||
|
||||
def check_aggregate_stats(
|
||||
dataset: LeRobotDataset,
|
||||
reference_stats: dict[str, dict[str, np.ndarray]],
|
||||
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
|
||||
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
|
||||
):
|
||||
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
|
||||
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
|
||||
for key, ft in dataset.features.items():
|
||||
# These values might need some fine-tuning
|
||||
if ft["dtype"] == "video":
|
||||
# to account for image sub-sampling
|
||||
rtol, atol = video_rtol_atol
|
||||
else:
|
||||
rtol, atol = default_rtol_atol
|
||||
|
||||
for stat, val in agg_stats[key].items():
|
||||
if key in reference_stats and stat in reference_stats[key]:
|
||||
err_msg = f"feature='{key}' stats='{stat}'"
|
||||
np.testing.assert_allclose(
|
||||
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
|
||||
)
|
||||
@@ -69,11 +69,11 @@ def decode_video_frames_torchvision(
|
||||
|
||||
# set the first and last requested timestamps
|
||||
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
||||
first_ts = timestamps[0]
|
||||
last_ts = timestamps[-1]
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
|
||||
# access closest key frame of the first requested frame
|
||||
# Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
||||
reader.seek(first_ts, keyframes_only=keyframes_only)
|
||||
|
||||
@@ -227,9 +227,7 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(
|
||||
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
@@ -243,9 +241,7 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"])
|
||||
if audio_stream_info.get("bit_rate")
|
||||
else None,
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
@@ -267,9 +263,7 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(
|
||||
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -15,9 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILSerlConfig:
|
||||
pass
|
||||
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
||||
156
lerobot/common/envs/configs.py
Normal file
156
lerobot/common/envs/configs.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
task: str | None = None
|
||||
fps: int = 30
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
features_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractproperty
|
||||
def gym_kwargs(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaEnv(EnvConfig):
|
||||
task: str = "AlohaInsertion-v0"
|
||||
fps: int = 50
|
||||
episode_length: int = 400
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
||||
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("pusht")
|
||||
@dataclass
|
||||
class PushtEnv(EnvConfig):
|
||||
task: str = "PushT-v0"
|
||||
fps: int = 10
|
||||
episode_length: int = 300
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
|
||||
elif self.obs_type == "environment_state_agent_pos":
|
||||
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("xarm")
|
||||
@dataclass
|
||||
class XarmEnv(EnvConfig):
|
||||
task: str = "XarmLift-v0"
|
||||
fps: int = 15
|
||||
episode_length: int = 200
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
@@ -14,155 +14,56 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from collections import deque
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
# from mani_skill.utils import common
|
||||
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the evaluation config.
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
if env_type == "aloha":
|
||||
return AlohaEnv(**kwargs)
|
||||
elif env_type == "pusht":
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
||||
|
||||
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the config.
|
||||
|
||||
Args:
|
||||
cfg (EnvConfig): the config of the environment to instantiate.
|
||||
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
|
||||
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
||||
False.
|
||||
|
||||
Raises:
|
||||
ValueError: if n_envs < 1
|
||||
ModuleNotFoundError: If the requested env package is not installed
|
||||
|
||||
Returns:
|
||||
gym.vector.VectorEnv: The parallelized gym.env instance.
|
||||
"""
|
||||
if n_envs is not None and n_envs < 1:
|
||||
if n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
if cfg.env.name == "real_world":
|
||||
return
|
||||
|
||||
if "maniskill" in cfg.env.name:
|
||||
env = make_maniskill_env(
|
||||
cfg, n_envs if n_envs is not None else cfg.eval.batch_size
|
||||
)
|
||||
return env
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
package_name = f"gym_{cfg.type}"
|
||||
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
|
||||
)
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
raise e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||
gym_kwgs = dict(cfg.env.get("gym", {}))
|
||||
|
||||
if cfg.env.get("episode_length"):
|
||||
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = (
|
||||
gym.vector.AsyncVectorEnv
|
||||
if cfg.eval.use_async_envs
|
||||
else gym.vector.SyncVectorEnv
|
||||
)
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
]
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def make_maniskill_env(
|
||||
cfg: DictConfig, n_envs: int | None = None
|
||||
) -> gym.vector.VectorEnv | None:
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
env = gym.make(
|
||||
cfg.env.task,
|
||||
obs_mode=cfg.env.obs,
|
||||
control_mode=cfg.env.control_mode,
|
||||
render_mode=cfg.env.render_mode,
|
||||
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
|
||||
num_envs=n_envs,
|
||||
)
|
||||
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True)
|
||||
# state should have the size of 25
|
||||
# env = ConvertToLeRobotEnv(env, n_envs)
|
||||
# env = PixelWrapper(cfg, env, n_envs)
|
||||
env._max_episode_steps = env.max_episode_steps = (
|
||||
50 # gym_utils.find_max_episode_steps_value(env)
|
||||
)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class PixelWrapper(gym.Wrapper):
|
||||
"""
|
||||
Wrapper for pixel observations. Works with Maniskill vectorized environments
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, env, num_envs, num_frames=3):
|
||||
super().__init__(env)
|
||||
self.cfg = cfg
|
||||
self.env = env
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
self._frames = deque([], maxlen=num_frames)
|
||||
self._render_size = cfg.env.render_size
|
||||
|
||||
def _get_obs(self, obs):
|
||||
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
|
||||
self._frames.append(frame)
|
||||
return {
|
||||
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
|
||||
self.env.device
|
||||
)
|
||||
}
|
||||
|
||||
def reset(self, seed):
|
||||
obs, info = self.env.reset() # (seed=seed)
|
||||
for _ in range(self._frames.maxlen):
|
||||
obs_frames = self._get_obs(obs)
|
||||
return obs_frames, info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
|
||||
|
||||
# TODO: Remove this
|
||||
class ConvertToLeRobotEnv(gym.Wrapper):
|
||||
def __init__(self, env, num_envs):
|
||||
super().__init__(env)
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
obs, info = self.env.reset(seed=seed, options={})
|
||||
return self._get_obs(obs), info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
|
||||
def _get_obs(self, observation):
|
||||
sensor_data = observation.pop("sensor_data")
|
||||
del observation["sensor_param"]
|
||||
images = []
|
||||
for cam_data in sensor_data.values():
|
||||
images.append(cam_data["rgb"])
|
||||
|
||||
images = torch.concat(images, axis=-1)
|
||||
# flatten the rest of the data which should just be state data
|
||||
observation = common.flatten_state_dict(
|
||||
observation, use_torch=True, device=self.base_env.device
|
||||
)
|
||||
ret = dict()
|
||||
ret["state"] = observation
|
||||
ret["pixels"] = images
|
||||
return ret
|
||||
|
||||
@@ -18,8 +18,13 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.utils.utils import get_channel_first_image_shape
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
@@ -28,32 +33,29 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: You have to merge all tensors from agent key and extra key
|
||||
# You don't keep sensor param key in the observation
|
||||
# And you keep sensor data rgb
|
||||
for key, img in observations.items():
|
||||
if "images" not in key:
|
||||
continue
|
||||
if "pixels" in observations:
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"expect channel last images, but instead got {img.shape=}"
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
return_observations[key] = img
|
||||
# obs state agent qpos and qvel
|
||||
# image
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
return_observations[imgkey] = img
|
||||
|
||||
if "environment_state" in observations:
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
@@ -62,43 +64,25 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
# requirement for "agent_pos"
|
||||
# return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return_observations["observation.state"] = observations["observation.state"].float()
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return return_observations
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(
|
||||
observations: dict[str, np.ndarray],
|
||||
) -> dict[str, Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: You have to merge all tensors from agent key and extra key
|
||||
# You don't keep sensor param key in the observation
|
||||
# And you keep sensor data rgb
|
||||
q_pos = observations["agent"]["qpos"]
|
||||
q_vel = observations["agent"]["qvel"]
|
||||
tcp_pos = observations["extra"]["tcp_pose"]
|
||||
img = observations["sensor_data"]["base_camera"]["rgb"]
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||
policy_features = {}
|
||||
for key, ft in env_cfg.features.items():
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
if len(ft.shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
|
||||
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
shape = get_channel_first_image_shape(ft.shape)
|
||||
feature = PolicyFeature(type=ft.type, shape=shape)
|
||||
else:
|
||||
feature = ft
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
policy_key = env_cfg.features_map[key]
|
||||
policy_features[policy_key] = feature
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
||||
|
||||
return_observations["observation.image"] = img
|
||||
return_observations["observation.state"] = state
|
||||
return return_observations
|
||||
return policy_features
|
||||
|
||||
17
lerobot/common/errors.py
Normal file
17
lerobot/common/errors.py
Normal file
@@ -0,0 +1,17 @@
|
||||
class DeviceNotConnectedError(ConnectionError):
|
||||
"""Exception raised when the device is not connected."""
|
||||
|
||||
def __init__(self, message="This device is not connected. Try calling `connect()` first."):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class DeviceAlreadyConnectedError(ConnectionError):
|
||||
"""Exception raised when the device is already connected."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message="This device is already connected. Try not calling `connect()` twice.",
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
@@ -1,329 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
|
||||
|
||||
# TODO(rcadene, alexander-soare): clean this file
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import wandb
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
||||
|
||||
def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.name}",
|
||||
f"dataset:{cfg.dataset_repo_id}",
|
||||
f"env:{cfg.env.name}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
|
||||
# Get the WandB run ID.
|
||||
paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
|
||||
if len(paths) != 1:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
|
||||
if match is None:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
wandb_run_id = match.groups(0)[0]
|
||||
return wandb_run_id
|
||||
|
||||
|
||||
class Logger:
|
||||
"""Primary logger object. Logs either locally or using wandb.
|
||||
|
||||
The logger creates the following directory structure:
|
||||
|
||||
provided_log_dir
|
||||
├── .hydra # hydra's configuration cache
|
||||
├── checkpoints
|
||||
│ ├── specific_checkpoint_name
|
||||
│ │ ├── pretrained_model # Hugging Face pretrained model directory
|
||||
│ │ │ ├── ...
|
||||
│ │ └── training_state.pth # optimizer, scheduler, and random states + training step
|
||||
| ├── another_specific_checkpoint_name
|
||||
│ │ ├── ...
|
||||
| ├── ...
|
||||
│ └── last # a softlink to the last logged checkpoint
|
||||
"""
|
||||
|
||||
pretrained_model_dir_name = "pretrained_model"
|
||||
training_state_file_name = "training_state.pth"
|
||||
|
||||
def __init__(
|
||||
self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
log_dir: The directory to save all logs and training outputs to.
|
||||
job_name: The WandB job name.
|
||||
"""
|
||||
self._cfg = cfg
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
|
||||
self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
|
||||
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
|
||||
|
||||
# Set up WandB.
|
||||
self._group = cfg_to_group(cfg)
|
||||
project = cfg.get("wandb", {}).get("project")
|
||||
entity = cfg.get("wandb", {}).get("entity")
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project
|
||||
if run_offline:
|
||||
logging.info(
|
||||
colored("Logs will be saved locally.", "yellow", attrs=["bold"])
|
||||
)
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
wandb_run_id = None
|
||||
if cfg.resume:
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
|
||||
|
||||
wandb.init(
|
||||
id=wandb_run_id,
|
||||
project=project,
|
||||
entity=entity,
|
||||
name=wandb_job_name,
|
||||
notes=cfg.get("wandb", {}).get("notes"),
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
dir=log_dir,
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
# TODO(rcadene): try set to True
|
||||
save_code=False,
|
||||
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
|
||||
job_type="train_eval",
|
||||
resume="must" if cfg.resume else None,
|
||||
)
|
||||
# Handle custom step key for rl asynchronous training.
|
||||
self._wandb_custom_step_key: set[str] | None = None
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(
|
||||
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
|
||||
)
|
||||
self._wandb = wandb
|
||||
|
||||
@classmethod
|
||||
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""Given the log directory, get the sub-directory in which checkpoints will be saved."""
|
||||
return Path(log_dir) / "checkpoints"
|
||||
|
||||
@classmethod
|
||||
def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
|
||||
return cls.get_checkpoints_dir(log_dir) / "last"
|
||||
|
||||
@classmethod
|
||||
def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""
|
||||
Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
|
||||
be saved.
|
||||
"""
|
||||
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
|
||||
|
||||
def save_model(
|
||||
self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None
|
||||
):
|
||||
"""Save the weights of the Policy model using PyTorchModelHubMixin.
|
||||
|
||||
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
|
||||
|
||||
Optionally also upload the model to WandB.
|
||||
"""
|
||||
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
|
||||
policy.save_pretrained(save_dir)
|
||||
# Also save the full Hydra config for the env configuration.
|
||||
OmegaConf.save(self._cfg, save_dir / "config.yaml")
|
||||
if self._wandb and not self._cfg.wandb.disable_artifact:
|
||||
# note wandb artifact does not accept ":" or "/" in its name
|
||||
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
|
||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
if self.last_checkpoint_dir.exists():
|
||||
os.remove(self.last_checkpoint_dir)
|
||||
|
||||
def save_training_state(
|
||||
self,
|
||||
save_dir: Path,
|
||||
train_step: int,
|
||||
optimizer: Optimizer | dict,
|
||||
scheduler: LRScheduler | None,
|
||||
interaction_step: int | None = None,
|
||||
):
|
||||
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
|
||||
|
||||
All of these are saved as "training_state.pth" under the checkpoint directory.
|
||||
"""
|
||||
# In Sac, for example, we have a dictionary of torch.optim.Optimizer
|
||||
if type(optimizer) is dict:
|
||||
optimizer_state_dict = {}
|
||||
for k in optimizer:
|
||||
optimizer_state_dict[k] = optimizer[k].state_dict()
|
||||
else:
|
||||
optimizer_state_dict = optimizer.state_dict()
|
||||
|
||||
training_state = {
|
||||
"step": train_step,
|
||||
"optimizer": optimizer_state_dict,
|
||||
**get_global_random_state(),
|
||||
}
|
||||
# Interaction step is related to the distributed training code
|
||||
# In that setup, we have two kinds of steps, the online step of the env and the optimization step
|
||||
# We need to save both in order to resume the optimization properly and not break the logs dependant on the interaction step
|
||||
if interaction_step is not None:
|
||||
training_state["interaction_step"] = interaction_step
|
||||
if scheduler is not None:
|
||||
training_state["scheduler"] = scheduler.state_dict()
|
||||
torch.save(training_state, save_dir / self.training_state_file_name)
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
train_step: int,
|
||||
policy: Policy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None,
|
||||
identifier: str,
|
||||
interaction_step: int | None = None,
|
||||
):
|
||||
"""Checkpoint the model weights and the training state."""
|
||||
checkpoint_dir = self.checkpoints_dir / str(identifier)
|
||||
wandb_artifact_name = (
|
||||
None
|
||||
if self._wandb is None
|
||||
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
|
||||
)
|
||||
self.save_model(
|
||||
checkpoint_dir / self.pretrained_model_dir_name,
|
||||
policy,
|
||||
wandb_artifact_name=wandb_artifact_name,
|
||||
)
|
||||
self.save_training_state(
|
||||
checkpoint_dir, train_step, optimizer, scheduler, interaction_step
|
||||
)
|
||||
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
|
||||
|
||||
def load_last_training_state(
|
||||
self, optimizer: Optimizer | dict, scheduler: LRScheduler | None
|
||||
) -> int:
|
||||
"""
|
||||
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
|
||||
random state, and return the global training step.
|
||||
"""
|
||||
training_state = torch.load(
|
||||
self.last_checkpoint_dir / self.training_state_file_name
|
||||
)
|
||||
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
|
||||
if type(training_state["optimizer"]) is dict:
|
||||
assert set(training_state["optimizer"].keys()) == set(
|
||||
optimizer.keys()
|
||||
), "Optimizer dictionaries do not have the same keys during resume!"
|
||||
for k, v in training_state["optimizer"].items():
|
||||
optimizer[k].load_state_dict(v)
|
||||
else:
|
||||
optimizer.load_state_dict(training_state["optimizer"])
|
||||
if scheduler is not None:
|
||||
scheduler.load_state_dict(training_state["scheduler"])
|
||||
elif "scheduler" in training_state:
|
||||
raise ValueError(
|
||||
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
|
||||
)
|
||||
# Small hack to get the expected keys: use `get_global_random_state`.
|
||||
set_global_random_state(
|
||||
{k: training_state[k] for k in get_global_random_state()}
|
||||
)
|
||||
return training_state["step"]
|
||||
|
||||
def log_dict(
|
||||
self,
|
||||
d,
|
||||
step: int | None = None,
|
||||
mode="train",
|
||||
custom_step_key: str | None = None,
|
||||
):
|
||||
"""Log a dictionary of metrics to WandB."""
|
||||
assert mode in {"train", "eval"}
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
if step is None and custom_step_key is None:
|
||||
raise ValueError("Either step or custom_step_key must be provided.")
|
||||
|
||||
if self._wandb is not None:
|
||||
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
|
||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||
# multiple time steps is possible for example, the interaction step with the environment,
|
||||
# the training step, the evaluation step, etc. So we need to define a custom step key
|
||||
# to log the correct step for each metric.
|
||||
if custom_step_key is not None:
|
||||
if self._wandb_custom_step_key is None:
|
||||
self._wandb_custom_step_key = set()
|
||||
new_custom_key = f"{mode}/{custom_step_key}"
|
||||
if new_custom_key not in self._wandb_custom_step_key:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str, wandb.Table)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if (
|
||||
self._wandb_custom_step_key is not None
|
||||
and k in self._wandb_custom_step_key
|
||||
):
|
||||
continue
|
||||
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
self._wandb.log(
|
||||
{
|
||||
f"{mode}/{k}": v,
|
||||
f"{mode}/{custom_step_key}": value_custom_step,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
assert self._wandb is not None
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
3
lerobot/common/motors/__init__.py
Normal file
3
lerobot/common/motors/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .motors_bus import MotorsBus
|
||||
|
||||
__all__ = ["MotorsBus"]
|
||||
41
lerobot/common/motors/configs.py
Normal file
41
lerobot/common/motors/configs.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass
|
||||
class MotorsBusConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@MotorsBusConfig.register_subclass("dynamixel")
|
||||
@dataclass
|
||||
class DynamixelMotorsBusConfig(MotorsBusConfig):
|
||||
port: str
|
||||
motors: dict[str, tuple[int, str]]
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@MotorsBusConfig.register_subclass("feetech")
|
||||
@dataclass
|
||||
class FeetechMotorsBusConfig(MotorsBusConfig):
|
||||
port: str
|
||||
motors: dict[str, tuple[int, str]]
|
||||
mock: bool = False
|
||||
4
lerobot/common/motors/dynamixel/__init__.py
Normal file
4
lerobot/common/motors/dynamixel/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .dynamixel import DynamixelMotorsBus, TorqueMode, set_operating_mode
|
||||
from .dynamixel_calibration import run_arm_calibration
|
||||
|
||||
__all__ = ["DynamixelMotorsBus", "TorqueMode", "set_operating_mode", "run_arm_calibration"]
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
@@ -8,10 +22,7 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 2.0
|
||||
@@ -146,9 +157,7 @@ NUM_READ_RETRY = 10
|
||||
NUM_WRITE_RETRY = 10
|
||||
|
||||
|
||||
def convert_degrees_to_steps(
|
||||
degrees: float | np.ndarray, models: str | list[str]
|
||||
) -> np.ndarray:
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
@@ -246,7 +255,7 @@ class DriveMode(enum.Enum):
|
||||
class CalibrationMode(enum.Enum):
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
DEGREE = 0
|
||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
LINEAR = 1
|
||||
|
||||
|
||||
@@ -257,7 +266,6 @@ class JointOutOfRangeError(Exception):
|
||||
|
||||
|
||||
class DynamixelMotorsBus:
|
||||
# TODO(rcadene): Add a script to find the motor indices without DynamixelWizzard2
|
||||
"""
|
||||
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
|
||||
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
|
||||
@@ -300,21 +308,14 @@ class DynamixelMotorsBus:
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
extra_model_resolution: dict[str, int] | None = None,
|
||||
mock=False,
|
||||
mock: bool = False,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.mock = mock
|
||||
|
||||
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
if extra_model_resolution:
|
||||
self.model_resolution.update(extra_model_resolution)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
@@ -326,12 +327,12 @@ class DynamixelMotorsBus:
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -355,7 +356,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -383,9 +384,7 @@ class DynamixelMotorsBus:
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(
|
||||
self.motor_models, [idx], "ID", num_retry=num_retry
|
||||
)[0]
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
@@ -401,9 +400,7 @@ class DynamixelMotorsBus:
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(
|
||||
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
|
||||
)
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
@@ -424,9 +421,7 @@ class DynamixelMotorsBus:
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration_autocorrect(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
|
||||
|
||||
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
||||
@@ -439,9 +434,7 @@ class DynamixelMotorsBus:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
return values
|
||||
|
||||
def apply_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
@@ -516,9 +509,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def autocorrect_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
||||
|
||||
Some motors might have values outside of expected maximum bounds after calibration.
|
||||
@@ -560,23 +551,15 @@ class DynamixelMotorsBus:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from initial range to range [-180, 180] degrees
|
||||
calib_val = (
|
||||
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
|
||||
calib_val < UPPER_BOUND_DEGREE
|
||||
)
|
||||
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
||||
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
||||
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (
|
||||
-(resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
upp_factor = (
|
||||
(resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
|
||||
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
@@ -584,9 +567,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
# Convert from initial range to range [0, 100] in %
|
||||
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
|
||||
calib_val < UPPER_BOUND_LINEAR
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
||||
@@ -602,27 +583,19 @@ class DynamixelMotorsBus:
|
||||
factor = math.ceil(low_factor)
|
||||
|
||||
if factor > upp_factor:
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
else:
|
||||
factor = math.ceil(upp_factor)
|
||||
|
||||
if factor > low_factor:
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
in_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
@@ -632,9 +605,7 @@ class DynamixelMotorsBus:
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
||||
|
||||
def revert_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
@@ -653,7 +624,7 @@ class DynamixelMotorsBus:
|
||||
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
||||
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
||||
|
||||
# Substract the homing offsets to come back to actual motor range of values
|
||||
# Subtract the homing offsets to come back to actual motor range of values
|
||||
# which can be arbitrary.
|
||||
values[i] -= homing_offset
|
||||
|
||||
@@ -673,11 +644,9 @@ class DynamixelMotorsBus:
|
||||
values = np.round(values).astype(np.int32)
|
||||
return values
|
||||
|
||||
def read_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
|
||||
):
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -715,14 +684,14 @@ class DynamixelMotorsBus:
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -777,9 +746,7 @@ class DynamixelMotorsBus:
|
||||
values = self.apply_calibration_autocorrect(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "read", data_name, motor_names
|
||||
)
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
@@ -788,11 +755,9 @@ class DynamixelMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def write_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
|
||||
):
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -819,21 +784,16 @@ class DynamixelMotorsBus:
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name,
|
||||
values: int | float | np.ndarray,
|
||||
motor_names: str | list[str] | None = None,
|
||||
):
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -885,9 +845,7 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "write", data_name, motor_names
|
||||
)
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
@@ -897,7 +855,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
|
||||
)
|
||||
|
||||
@@ -913,3 +871,25 @@ class DynamixelMotorsBus:
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
|
||||
|
||||
def set_operating_mode(arm: DynamixelMotorsBus):
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
|
||||
|
||||
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
|
||||
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
|
||||
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
|
||||
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
|
||||
if len(all_motors_except_gripper) > 0:
|
||||
# 4 corresponds to Extended Position on Koch motors
|
||||
arm.write("Operating_Mode", 4, all_motors_except_gripper)
|
||||
|
||||
# Use 'position control current based' for gripper to be limited by the limit of the current.
|
||||
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
|
||||
# it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
|
||||
# to make it move, and it will move back to its original target position when we release the force.
|
||||
# 5 corresponds to Current Controlled Position on Koch gripper motors "xl330-m077, xl330-m288"
|
||||
arm.write("Operating_Mode", 5, "gripper")
|
||||
@@ -1,16 +1,32 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Logic to calibrate a robot arm built with dynamixel motors"""
|
||||
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
from ..motors_bus import MotorsBus
|
||||
from .dynamixel import (
|
||||
CalibrationMode,
|
||||
TorqueMode,
|
||||
convert_degrees_to_steps,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
|
||||
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
|
||||
# The following positions are provided in nominal degree range ]-180, +180[
|
||||
# For more info on these constants, see comments in the code where they get used.
|
||||
@@ -21,9 +37,7 @@ ROTATED_POSITION_DEGREE = 90
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
raise ValueError(
|
||||
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
|
||||
)
|
||||
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
|
||||
|
||||
def apply_drive_mode(position, drive_mode):
|
||||
@@ -64,16 +78,12 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
```
|
||||
"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
@@ -91,18 +101,13 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
|
||||
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
|
||||
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
rotated_target_pos = convert_degrees_to_steps(
|
||||
ROTATED_POSITION_DEGREE, arm.motor_models
|
||||
)
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
@@ -111,15 +116,11 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
|
||||
# Re-compute homing offset to take into account drive mode
|
||||
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
|
||||
rotated_nearest_pos = compute_nearest_rounded_position(
|
||||
rotated_drived_pos, arm.motor_models
|
||||
)
|
||||
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
|
||||
homing_offset = rotated_target_pos - rotated_nearest_pos
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
|
||||
@@ -128,7 +129,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
|
||||
# TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
|
||||
if robot_type in ["aloha"] and "gripper" in arm.motor_names:
|
||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
calib_idx = arm.motor_names.index("gripper")
|
||||
calib_mode[calib_idx] = CalibrationMode.LINEAR.name
|
||||
|
||||
9
lerobot/common/motors/feetech/__init__.py
Normal file
9
lerobot/common/motors/feetech/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .feetech import FeetechMotorsBus, TorqueMode
|
||||
from .feetech_calibration import apply_feetech_offsets_from_calibration, run_full_arm_calibration
|
||||
|
||||
__all__ = [
|
||||
"FeetechMotorsBus",
|
||||
"TorqueMode",
|
||||
"apply_feetech_offsets_from_calibration",
|
||||
"run_full_arm_calibration",
|
||||
]
|
||||
@@ -1,6 +1,18 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
@@ -8,10 +20,7 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 0
|
||||
@@ -20,13 +29,6 @@ TIMEOUT_MS = 1000
|
||||
|
||||
MAX_ID_RANGE = 252
|
||||
|
||||
# The following bounds define the lower and upper joints range (after calibration).
|
||||
# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees
|
||||
# which corresponds to a half rotation on the left and half rotation on the right.
|
||||
# Some joints might require higher range, so we allow up to [-270, 270] degrees until
|
||||
# an error is raised.
|
||||
LOWER_BOUND_DEGREE = -270
|
||||
UPPER_BOUND_DEGREE = 270
|
||||
# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper),
|
||||
# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully
|
||||
# closed, and 100% is fully open. To account for slight calibration issue, we allow up to
|
||||
@@ -36,7 +38,6 @@ UPPER_BOUND_LINEAR = 110
|
||||
|
||||
HALF_TURN_DEGREE = 180
|
||||
|
||||
|
||||
# See this link for STS3215 Memory Table:
|
||||
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
|
||||
# data_name: (address, size_byte)
|
||||
@@ -102,8 +103,6 @@ SCS_SERIES_BAUDRATE_TABLE = {
|
||||
}
|
||||
|
||||
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
|
||||
|
||||
MODEL_CONTROL_TABLE = {
|
||||
"scs_series": SCS_SERIES_CONTROL_TABLE,
|
||||
@@ -125,17 +124,63 @@ NUM_READ_RETRY = 20
|
||||
NUM_WRITE_RETRY = 20
|
||||
|
||||
|
||||
def convert_degrees_to_steps(
|
||||
degrees: float | np.ndarray, models: str | list[str]
|
||||
) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
def convert_ticks_to_degrees(ticks, model):
|
||||
resolutions = MODEL_RESOLUTION[model]
|
||||
# Convert the ticks to degrees
|
||||
return ticks * (360.0 / resolutions)
|
||||
|
||||
|
||||
def convert_degrees_to_ticks(degrees, model):
|
||||
resolutions = MODEL_RESOLUTION[model]
|
||||
# Convert degrees to motor ticks
|
||||
return int(degrees * (resolutions / 360.0))
|
||||
|
||||
|
||||
def adjusted_to_homing_ticks(raw_motor_ticks: int, model: str, motorbus, motor_id: int) -> int:
|
||||
"""
|
||||
resolutions = [MODEL_RESOLUTION[model] for model in models]
|
||||
steps = degrees / 180 * np.array(resolutions) / 2
|
||||
steps = steps.astype(int)
|
||||
return steps
|
||||
Takes a raw reading [0..(res-1)] (e.g. 0..4095) and shifts it so that '2048'
|
||||
becomes 0 in the homed coordinate system ([-2048..+2047] for 4096 resolution).
|
||||
"""
|
||||
resolutions = MODEL_RESOLUTION[model]
|
||||
|
||||
# Shift raw ticks by half-resolution so 2048 -> 0, then wrap [0..res-1].
|
||||
ticks = (raw_motor_ticks - (resolutions // 2)) % resolutions
|
||||
|
||||
# If above halfway, fold it into negative territory => [-2048..+2047].
|
||||
if ticks > (resolutions // 2):
|
||||
ticks -= resolutions
|
||||
|
||||
# Flip sign if drive_mode is set.
|
||||
drive_mode = 0
|
||||
if motorbus.calibration is not None:
|
||||
drive_mode = motorbus.calibration["drive_mode"][motor_id - 1]
|
||||
|
||||
if drive_mode:
|
||||
ticks *= -1
|
||||
|
||||
return ticks
|
||||
|
||||
|
||||
def adjusted_to_motor_ticks(adjusted_pos: int, model: str, motorbus, motor_id: int) -> int:
|
||||
"""
|
||||
Inverse of adjusted_to_homing_ticks(). Takes a 'homed' position in [-2048..+2047]
|
||||
and recovers the raw [0..(res-1)] ticks with 2048 as midpoint.
|
||||
"""
|
||||
# Flip sign if drive_mode was set.
|
||||
drive_mode = 0
|
||||
if motorbus.calibration is not None:
|
||||
drive_mode = motorbus.calibration["drive_mode"][motor_id - 1]
|
||||
|
||||
if drive_mode:
|
||||
adjusted_pos *= -1
|
||||
|
||||
resolutions = MODEL_RESOLUTION[model]
|
||||
|
||||
# Shift by +half-resolution and wrap into [0..res-1].
|
||||
# This undoes the earlier shift by -half-resolution.
|
||||
ticks = (adjusted_pos + (resolutions // 2)) % resolutions
|
||||
|
||||
return ticks
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes, mock=False):
|
||||
@@ -225,7 +270,7 @@ class DriveMode(enum.Enum):
|
||||
class CalibrationMode(enum.Enum):
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
DEGREE = 0
|
||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
LINEAR = 1
|
||||
|
||||
|
||||
@@ -278,21 +323,14 @@ class FeetechMotorsBus:
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
extra_model_resolution: dict[str, int] | None = None,
|
||||
mock=False,
|
||||
mock: bool = False,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.mock = mock
|
||||
|
||||
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
if extra_model_resolution:
|
||||
self.model_resolution.update(extra_model_resolution)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
@@ -302,16 +340,14 @@ class FeetechMotorsBus:
|
||||
self.group_writers = {}
|
||||
self.logs = {}
|
||||
|
||||
self.track_positions = {}
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -335,7 +371,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -363,9 +399,7 @@ class FeetechMotorsBus:
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(
|
||||
self.motor_models, [idx], "ID", num_retry=num_retry
|
||||
)[0]
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
@@ -381,9 +415,7 @@ class FeetechMotorsBus:
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(
|
||||
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
|
||||
)
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
@@ -404,37 +436,7 @@ class FeetechMotorsBus:
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration_autocorrect(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
|
||||
|
||||
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
||||
"""
|
||||
try:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
except JointOutOfRangeError as e:
|
||||
print(e)
|
||||
self.autocorrect_calibration(values, motor_names)
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
return values
|
||||
|
||||
def apply_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
|
||||
rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.
|
||||
|
||||
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
|
||||
when given a goal position that is + or - their resolution. For instance, feetech xl330-m077 have a resolution of 4096, and
|
||||
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
|
||||
or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
|
||||
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
|
||||
in the centered nominal degree range ]-180, 180[.
|
||||
"""
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
@@ -446,34 +448,11 @@ class FeetechMotorsBus:
|
||||
calib_mode = self.calibration["calib_mode"][calib_idx]
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
drive_mode = self.calibration["drive_mode"][calib_idx]
|
||||
homing_offset = self.calibration["homing_offset"][calib_idx]
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
motor_idx, model = self.motors[name]
|
||||
|
||||
# Update direction of rotation of the motor to match between leader and follower.
|
||||
# In fact, the motor of the leader for a given joint can be assembled in an
|
||||
# opposite direction in term of rotation than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from range [-2**31, 2**31[ to
|
||||
# nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[)
|
||||
values[i] += homing_offset
|
||||
|
||||
# Convert from range ]-resolution, resolution[ to
|
||||
# universal float32 centered degree range ]-180, 180[
|
||||
values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE
|
||||
|
||||
if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE):
|
||||
raise JointOutOfRangeError(
|
||||
f"Wrong motor position range detected for {name}. "
|
||||
f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), "
|
||||
f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, "
|
||||
f"but present value is {values[i]} degree. "
|
||||
"This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
|
||||
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
|
||||
)
|
||||
# Convert raw motor ticks to homed ticks, then convert the homed ticks to degrees
|
||||
values[i] = adjusted_to_homing_ticks(values[i], model, self, motor_idx)
|
||||
values[i] = convert_ticks_to_degrees(values[i], model)
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
@@ -495,126 +474,7 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def autocorrect_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
||||
|
||||
Some motors might have values outside of expected maximum bounds after calibration.
|
||||
For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given
|
||||
a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position.
|
||||
|
||||
Known issues:
|
||||
#1: Motor value randomly shifts of a full turn, caused by hardware/connection errors.
|
||||
#2: Motor internal homing offset is shifted of a full turn, caused by using default calibration (e.g Aloha).
|
||||
#3: motor internal homing offset is shifted of less or more than a full turn, caused by using default calibration
|
||||
or by human error during manual calibration.
|
||||
|
||||
Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn.
|
||||
Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`,
|
||||
that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue.
|
||||
|
||||
Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
"""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from unsigned int32 original range [0, 2**32] to signed float32 range
|
||||
values = values.astype(np.float32)
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
calib_idx = self.calibration["motor_names"].index(name)
|
||||
calib_mode = self.calibration["calib_mode"][calib_idx]
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
drive_mode = self.calibration["drive_mode"][calib_idx]
|
||||
homing_offset = self.calibration["homing_offset"][calib_idx]
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from initial range to range [-180, 180] degrees
|
||||
calib_val = (
|
||||
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
|
||||
calib_val < UPPER_BOUND_DEGREE
|
||||
)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
||||
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
||||
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (
|
||||
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
|
||||
- values[i]
|
||||
- homing_offset
|
||||
) / resolution
|
||||
upp_factor = (
|
||||
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
|
||||
- values[i]
|
||||
- homing_offset
|
||||
) / resolution
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
end_pos = self.calibration["end_pos"][calib_idx]
|
||||
|
||||
# Convert from initial range to range [0, 100] in %
|
||||
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
|
||||
calib_val < UPPER_BOUND_LINEAR
|
||||
)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100
|
||||
# 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100
|
||||
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
||||
low_factor = (start_pos - values[i]) / resolution
|
||||
upp_factor = (end_pos - values[i]) / resolution
|
||||
|
||||
if not in_range:
|
||||
# Get first integer between the two bounds
|
||||
if low_factor < upp_factor:
|
||||
factor = math.ceil(low_factor)
|
||||
|
||||
if factor > upp_factor:
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
else:
|
||||
factor = math.ceil(upp_factor)
|
||||
|
||||
if factor > low_factor:
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
in_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
f"from '{out_of_range_str}' to '{in_range_str}'."
|
||||
)
|
||||
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
||||
|
||||
def revert_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
@@ -624,23 +484,11 @@ class FeetechMotorsBus:
|
||||
calib_mode = self.calibration["calib_mode"][calib_idx]
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
drive_mode = self.calibration["drive_mode"][calib_idx]
|
||||
homing_offset = self.calibration["homing_offset"][calib_idx]
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
motor_idx, model = self.motors[name]
|
||||
|
||||
# Convert from nominal 0-centered degree range [-180, 180] to
|
||||
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
||||
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
||||
|
||||
# Substract the homing offsets to come back to actual motor range of values
|
||||
# which can be arbitrary.
|
||||
values[i] -= homing_offset
|
||||
|
||||
# Remove drive mode, which is the rotation direction of the motor, to come back to
|
||||
# actual motor rotation direction which can be arbitrary.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
# Convert degrees to homed ticks, then convert the homed ticks to raw ticks
|
||||
values[i] = convert_degrees_to_ticks(values[i], model)
|
||||
values[i] = adjusted_to_motor_ticks(values[i], model, self, motor_idx)
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
@@ -653,48 +501,9 @@ class FeetechMotorsBus:
|
||||
values = np.round(values).astype(np.int32)
|
||||
return values
|
||||
|
||||
def avoid_rotation_reset(self, values, motor_names, data_name):
|
||||
if data_name not in self.track_positions:
|
||||
self.track_positions[data_name] = {
|
||||
"prev": [None] * len(self.motor_names),
|
||||
# Assume False at initialization
|
||||
"below_zero": [False] * len(self.motor_names),
|
||||
"above_max": [False] * len(self.motor_names),
|
||||
}
|
||||
|
||||
track = self.track_positions[data_name]
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
idx = self.motor_names.index(name)
|
||||
|
||||
if track["prev"][idx] is None:
|
||||
track["prev"][idx] = values[i]
|
||||
continue
|
||||
|
||||
# Detect a full rotation occured
|
||||
if abs(track["prev"][idx] - values[i]) > 2048:
|
||||
# Position went below 0 and got reset to 4095
|
||||
if track["prev"][idx] < values[i]:
|
||||
# So we set negative value by adding a full rotation
|
||||
values[i] -= 4096
|
||||
|
||||
# Position went above 4095 and got reset to 0
|
||||
elif track["prev"][idx] > values[i]:
|
||||
# So we add a full rotation
|
||||
values[i] += 4096
|
||||
|
||||
track["prev"][idx] = values[i]
|
||||
|
||||
return values
|
||||
|
||||
def read_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
|
||||
):
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -732,12 +541,12 @@ class FeetechMotorsBus:
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
)
|
||||
|
||||
@@ -761,7 +570,11 @@ class FeetechMotorsBus:
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
if data_name not in self.group_readers:
|
||||
# create new group reader
|
||||
# Very Important to flush the buffer!
|
||||
self.port_handler.ser.reset_output_buffer()
|
||||
self.port_handler.ser.reset_input_buffer()
|
||||
|
||||
# Create new group reader
|
||||
self.group_readers[group_key] = scs.GroupSyncRead(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
)
|
||||
@@ -786,20 +599,11 @@ class FeetechMotorsBus:
|
||||
|
||||
values = np.array(values)
|
||||
|
||||
# Convert to signed int to use range [-2048, 2048] for our motor positions.
|
||||
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
|
||||
values = values.astype(np.int32)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED:
|
||||
values = self.avoid_rotation_reset(values, motor_names, data_name)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
values = self.apply_calibration_autocorrect(values, motor_names)
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "read", data_name, motor_names
|
||||
)
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
@@ -808,11 +612,9 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def write_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
|
||||
):
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -839,21 +641,16 @@ class FeetechMotorsBus:
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name,
|
||||
values: int | float | np.ndarray,
|
||||
motor_names: str | list[str] | None = None,
|
||||
):
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -905,9 +702,7 @@ class FeetechMotorsBus:
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "write", data_name, motor_names
|
||||
)
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
@@ -917,7 +712,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
|
||||
)
|
||||
|
||||
254
lerobot/common/motors/feetech/feetech_calibration.py
Normal file
254
lerobot/common/motors/feetech/feetech_calibration.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
from ..motors_bus import MotorsBus
|
||||
from .feetech import (
|
||||
CalibrationMode,
|
||||
FeetechMotorsBus,
|
||||
TorqueMode,
|
||||
)
|
||||
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
|
||||
|
||||
def disable_torque(arm: MotorsBus):
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
|
||||
|
||||
def get_calibration_modes(arm: MotorsBus):
|
||||
"""Returns calibration modes for each motor (DEGREE for rotational, LINEAR for gripper)."""
|
||||
return [
|
||||
CalibrationMode.LINEAR.name if name == "gripper" else CalibrationMode.DEGREE.name
|
||||
for name in arm.motor_names
|
||||
]
|
||||
|
||||
|
||||
def reset_offset(motor_id, motor_bus):
|
||||
# Open the write lock, changes to EEPROM do NOT persist yet
|
||||
motor_bus.write("Lock", 1)
|
||||
|
||||
# Set offset to 0
|
||||
motor_name = motor_bus.motor_names[motor_id - 1]
|
||||
motor_bus.write("Offset", 0, motor_names=[motor_name])
|
||||
|
||||
# Close the write lock, changes to EEPROM do persist
|
||||
motor_bus.write("Lock", 0)
|
||||
|
||||
# Confirm that the offset is zero by reading it back
|
||||
confirmed_offset = motor_bus.read("Offset")[motor_id - 1]
|
||||
print(f"Offset for motor {motor_id} reset to: {confirmed_offset}")
|
||||
return confirmed_offset
|
||||
|
||||
|
||||
def calibrate_homing_motor(motor_id, motor_bus):
|
||||
reset_offset(motor_id, motor_bus)
|
||||
|
||||
home_ticks = motor_bus.read("Present_Position")[motor_id - 1] # Read index starts at 0
|
||||
print(f"Encoder offset (present position in homing position): {home_ticks}")
|
||||
|
||||
return home_ticks
|
||||
|
||||
|
||||
def calibrate_linear_motor(motor_id, motor_bus):
|
||||
motor_names = motor_bus.motor_names
|
||||
motor_name = motor_names[motor_id - 1]
|
||||
|
||||
reset_offset(motor_id, motor_bus)
|
||||
|
||||
input(f"Close the {motor_name}, then press Enter...")
|
||||
start_pos = motor_bus.read("Present_Position")[motor_id - 1] # Read index starts ar 0
|
||||
print(f" [Motor {motor_id}] start position recorded: {start_pos}")
|
||||
|
||||
input(f"Open the {motor_name} fully, then press Enter...")
|
||||
end_pos = motor_bus.read("Present_Position")[motor_id - 1] # Read index starts ar 0
|
||||
print(f" [Motor {motor_id}] end position recorded: {end_pos}")
|
||||
|
||||
return start_pos, end_pos
|
||||
|
||||
|
||||
def single_motor_calibration(arm: MotorsBus, motor_id: int):
|
||||
"""Calibrates a single motor and returns its calibration data for updating the calibration file."""
|
||||
|
||||
disable_torque(arm)
|
||||
print(f"\n--- Calibrating Motor {motor_id} ---")
|
||||
|
||||
start_pos = 0
|
||||
end_pos = 0
|
||||
encoder_offset = 0
|
||||
|
||||
if motor_id == 6:
|
||||
start_pos, end_pos = calibrate_linear_motor(motor_id, arm)
|
||||
else:
|
||||
input("Move the motor to (zero) position, then press Enter...")
|
||||
encoder_offset = calibrate_homing_motor(motor_id, arm)
|
||||
|
||||
print(f"Calibration for motor ID:{motor_id} done.")
|
||||
|
||||
# Create a calibration dictionary for the single motor
|
||||
calib_dict = {
|
||||
"homing_offset": int(encoder_offset),
|
||||
"drive_mode": 0,
|
||||
"start_pos": int(start_pos),
|
||||
"end_pos": int(end_pos),
|
||||
"calib_mode": get_calibration_modes(arm)[motor_id - 1],
|
||||
"motor_name": arm.motor_names[motor_id - 1],
|
||||
}
|
||||
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_full_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
"""
|
||||
Runs a full calibration process for all motors in a robotic arm.
|
||||
|
||||
This function calibrates each motor in the arm, determining encoder offsets and
|
||||
start/end positions for linear and rotational motors. The calibration data is then
|
||||
stored in a dictionary for later use.
|
||||
|
||||
**Calibration Process:**
|
||||
- The user is prompted to move the arm to its homing position before starting.
|
||||
- Motors with rotational motion are calibrated using a homing method.
|
||||
- Linear actuators (e.g., grippers) are calibrated separately.
|
||||
- Encoder offsets, start positions, and end positions are recorded.
|
||||
|
||||
**Example Usage:**
|
||||
```python
|
||||
run_full_arm_calibration(arm, "so100", "left", "follower")
|
||||
```
|
||||
"""
|
||||
disable_torque(arm)
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to homing position (middle)")
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
|
||||
) # TODO(pepijn): replace with new instruction homing pos (all motors in middle) in tutorial
|
||||
input("Press Enter to continue...")
|
||||
|
||||
start_positions = np.zeros(len(arm.motor_indices))
|
||||
end_positions = np.zeros(len(arm.motor_indices))
|
||||
encoder_offsets = np.zeros(len(arm.motor_indices))
|
||||
|
||||
modes = get_calibration_modes(arm)
|
||||
|
||||
for i, motor_id in enumerate(arm.motor_indices):
|
||||
if modes[i] == CalibrationMode.DEGREE.name:
|
||||
encoder_offsets[i] = calibrate_homing_motor(motor_id, arm)
|
||||
start_positions[i] = 0
|
||||
end_positions[i] = 0
|
||||
|
||||
for i, motor_id in enumerate(arm.motor_indices):
|
||||
if modes[i] == CalibrationMode.LINEAR.name:
|
||||
start_positions[i], end_positions[i] = calibrate_linear_motor(motor_id, arm)
|
||||
encoder_offsets[i] = 0
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
input("Press Enter to continue...")
|
||||
|
||||
print(f"\n calibration of {robot_type} {arm_name} {arm_type} done!")
|
||||
|
||||
# Force drive_mode values (can be static)
|
||||
drive_modes = [0, 1, 0, 0, 1, 0]
|
||||
|
||||
calib_dict = {
|
||||
"homing_offset": encoder_offsets.astype(int).tolist(),
|
||||
"drive_mode": drive_modes,
|
||||
"start_pos": start_positions.astype(int).tolist(),
|
||||
"end_pos": end_positions.astype(int).tolist(),
|
||||
"calib_mode": get_calibration_modes(arm),
|
||||
"motor_names": arm.motor_names,
|
||||
}
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_full_auto_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
"""TODO(pepijn): Add this method later as extra
|
||||
Example of usage:
|
||||
```python
|
||||
run_full_auto_arm_calibration(arm, "so100", "left", "follower")
|
||||
```
|
||||
"""
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
|
||||
def apply_feetech_offsets_from_calibration(motorsbus: FeetechMotorsBus, calibration_dict: dict):
|
||||
"""
|
||||
Reads 'calibration_dict' containing 'homing_offset' and 'motor_names',
|
||||
then writes each motor's offset to the servo's internal Offset (0x1F) in EPROM.
|
||||
|
||||
This version is modified so each homed position (originally 0) will now read
|
||||
2047, i.e. 180° away from 0 in the 4096-count circle. Offsets are permanently
|
||||
stored in EEPROM, so the servo's Present_Position is hardware-shifted even
|
||||
after power cycling.
|
||||
|
||||
Steps:
|
||||
1) Subtract 2047 from the old offset (so 0 -> 2047).
|
||||
2) Clamp to [-2047..+2047].
|
||||
3) Encode sign bit and magnitude into a 12-bit number.
|
||||
"""
|
||||
|
||||
homing_offsets = calibration_dict["homing_offset"]
|
||||
motor_names = calibration_dict["motor_names"]
|
||||
start_pos = calibration_dict["start_pos"]
|
||||
|
||||
# Open the write lock, changes to EEPROM do NOT persist yet
|
||||
motorsbus.write("Lock", 1)
|
||||
|
||||
# For each motor, set the 'Offset' parameter
|
||||
for m_name, old_offset in zip(motor_names, homing_offsets, strict=False):
|
||||
# If bus doesn’t have a motor named m_name, skip
|
||||
if m_name not in motorsbus.motors:
|
||||
print(f"Warning: '{m_name}' not found in motorsbus.motors; skipping offset.")
|
||||
continue
|
||||
|
||||
if m_name == "gripper":
|
||||
old_offset = start_pos # If gripper set the offset to the start position of the gripper
|
||||
continue
|
||||
|
||||
# Shift the offset so the homed position reads 2047
|
||||
new_offset = old_offset - 2047
|
||||
|
||||
# Clamp to [-2047..+2047]
|
||||
if new_offset > 2047:
|
||||
new_offset = 2047
|
||||
print(
|
||||
f"Warning: '{new_offset}' is getting clamped because its larger then 2047; This should not happen!"
|
||||
)
|
||||
elif new_offset < -2047:
|
||||
new_offset = -2047
|
||||
print(
|
||||
f"Warning: '{new_offset}' is getting clamped because its smaller then -2047; This should not happen!"
|
||||
)
|
||||
|
||||
# Determine the direction (sign) bit and magnitude
|
||||
direction_bit = 1 if new_offset < 0 else 0
|
||||
magnitude = abs(new_offset)
|
||||
|
||||
# Combine sign bit (bit 11) with the magnitude (bits 0..10)
|
||||
servo_offset = (direction_bit << 11) | magnitude
|
||||
|
||||
# Write offset to servo
|
||||
motorsbus.write("Offset", servo_offset, motor_names=m_name)
|
||||
print(
|
||||
f"Set offset for {m_name}: "
|
||||
f"old_offset={old_offset}, new_offset={new_offset}, servo_encoded={magnitude} + direction={direction_bit}"
|
||||
)
|
||||
|
||||
motorsbus.write("Lock", 0)
|
||||
print("Offsets have been saved to EEPROM successfully.")
|
||||
46
lerobot/common/motors/motors_bus.py
Normal file
46
lerobot/common/motors/motors_bus.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import abc
|
||||
|
||||
|
||||
class MotorsBus(abc.ABC):
|
||||
"""The main LeRobot class for implementing motors buses."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
):
|
||||
self.motors = motors
|
||||
|
||||
def __len__(self):
|
||||
return len(self.motors)
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def reconnect(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def apply_calibration(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def revert_calibration(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self):
|
||||
pass
|
||||
56
lerobot/common/motors/utils.py
Normal file
56
lerobot/common/motors/utils.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configs import MotorsBusConfig
|
||||
from .motors_bus import MotorsBus
|
||||
|
||||
|
||||
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
|
||||
motors_buses = {}
|
||||
|
||||
for key, cfg in motors_bus_configs.items():
|
||||
if cfg.type == "dynamixel":
|
||||
from .dynamixel import DynamixelMotorsBus
|
||||
|
||||
motors_buses[key] = DynamixelMotorsBus(cfg)
|
||||
|
||||
elif cfg.type == "feetech":
|
||||
from lerobot.common.motors.feetech.feetech import FeetechMotorsBus
|
||||
|
||||
motors_buses[key] = FeetechMotorsBus(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
|
||||
return motors_buses
|
||||
|
||||
|
||||
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
if motor_type == "dynamixel":
|
||||
from .configs import DynamixelMotorsBusConfig
|
||||
from .dynamixel import DynamixelMotorsBus
|
||||
|
||||
config = DynamixelMotorsBusConfig(**kwargs)
|
||||
return DynamixelMotorsBus(config)
|
||||
|
||||
elif motor_type == "feetech":
|
||||
from feetech import FeetechMotorsBus
|
||||
|
||||
from .configs import FeetechMotorsBusConfig
|
||||
|
||||
config = FeetechMotorsBusConfig(**kwargs)
|
||||
return FeetechMotorsBus(config)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
||||
15
lerobot/common/optim/__init__.py
Normal file
15
lerobot/common/optim/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .optimizers import OptimizerConfig as OptimizerConfig
|
||||
40
lerobot/common/optim/factory.py
Normal file
40
lerobot/common/optim/factory.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def make_optimizer_and_scheduler(
|
||||
cfg: TrainPipelineConfig, policy: PreTrainedPolicy
|
||||
) -> tuple[Optimizer, LRScheduler | None]:
|
||||
"""Generates the optimizer and scheduler based on configs.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs
|
||||
policy (PreTrainedPolicy): The policy config from which parameters and presets must be taken from.
|
||||
|
||||
Returns:
|
||||
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
||||
"""
|
||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||
optimizer = cfg.optimizer.build(params)
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
return optimizer, lr_scheduler
|
||||
118
lerobot/common/optim/optimizers.py
Normal file
118
lerobot/common/optim/optimizers.py
Normal file
@@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.common.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
)
|
||||
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
lr: float
|
||||
weight_decay: float
|
||||
grad_clip_norm: float
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@classmethod
|
||||
def default_choice_name(cls) -> str | None:
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self) -> torch.optim.Optimizer:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("adam")
|
||||
@dataclass
|
||||
class AdamConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
betas: tuple[float, float] = (0.9, 0.999)
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.Adam(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("adamw")
|
||||
@dataclass
|
||||
class AdamWConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
betas: tuple[float, float] = (0.9, 0.999)
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.AdamW(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("sgd")
|
||||
@dataclass
|
||||
class SGDConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
momentum: float = 0.0
|
||||
dampening: float = 0.0
|
||||
nesterov: bool = False
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
state = optimizer.state_dict()
|
||||
param_groups = state.pop("param_groups")
|
||||
flat_state = flatten_dict(state)
|
||||
save_file(flat_state, save_dir / OPTIMIZER_STATE)
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
|
||||
if "param_groups" in current_state_dict:
|
||||
param_groups = deserialize_json_into_object(
|
||||
save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
|
||||
)
|
||||
loaded_state_dict["param_groups"] = param_groups
|
||||
|
||||
optimizer.load_state_dict(loaded_state_dict)
|
||||
return optimizer
|
||||
122
lerobot/common/optim/schedulers.py
Normal file
122
lerobot/common/optim/schedulers.py
Normal file
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from lerobot.common.constants import SCHEDULER_STATE
|
||||
from lerobot.common.datasets.utils import write_json
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@dataclass
|
||||
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
num_warmup_steps: int
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("diffuser")
|
||||
@dataclass
|
||||
class DiffuserSchedulerConfig(LRSchedulerConfig):
|
||||
name: str = "cosine"
|
||||
num_warmup_steps: int | None = None
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
|
||||
return get_scheduler(**kwargs)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("vqbet")
|
||||
@dataclass
|
||||
class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
num_warmup_steps: int
|
||||
num_vqvae_training_steps: int
|
||||
num_cycles: float = 0.5
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
def lr_lambda(current_step):
|
||||
if current_step < self.num_vqvae_training_steps:
|
||||
return float(1)
|
||||
else:
|
||||
adjusted_step = current_step - self.num_vqvae_training_steps
|
||||
if adjusted_step < self.num_warmup_steps:
|
||||
return float(adjusted_step) / float(max(1, self.num_warmup_steps))
|
||||
progress = float(adjusted_step - self.num_warmup_steps) / float(
|
||||
max(1, num_training_steps - self.num_warmup_steps)
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Used by Physical Intelligence to train Pi0"""
|
||||
|
||||
num_warmup_steps: int
|
||||
num_decay_steps: int
|
||||
peak_lr: float
|
||||
decay_lr: float
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
del num_training_steps
|
||||
|
||||
def lr_lambda(current_step):
|
||||
def linear_warmup_schedule(current_step):
|
||||
if current_step <= 0:
|
||||
return 1 / (self.num_warmup_steps + 1)
|
||||
frac = 1 - current_step / self.num_warmup_steps
|
||||
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
|
||||
|
||||
def cosine_decay_schedule(current_step):
|
||||
step = min(current_step, self.num_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||
alpha = self.decay_lr / self.peak_lr
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return decayed
|
||||
|
||||
if current_step < self.num_warmup_steps:
|
||||
return linear_warmup_schedule(current_step)
|
||||
|
||||
return cosine_decay_schedule(current_step)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
|
||||
state_dict = scheduler.state_dict()
|
||||
write_json(state_dict, save_dir / SCHEDULER_STATE)
|
||||
|
||||
|
||||
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
|
||||
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
|
||||
scheduler.load_state_dict(state_dict)
|
||||
return scheduler
|
||||
19
lerobot/common/policies/__init__.py
Normal file
19
lerobot/common/policies/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
@@ -15,9 +15,14 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("act")
|
||||
@dataclass
|
||||
class ACTConfig:
|
||||
class ACTConfig(PreTrainedConfig):
|
||||
"""Configuration class for the Action Chunking Transformers policy.
|
||||
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
@@ -59,7 +64,7 @@ class ACTConfig:
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
||||
convolution.
|
||||
@@ -90,28 +95,11 @@ class ACTConfig:
|
||||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": [3, 480, 640],
|
||||
"observation.state": [14],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [14],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "mean_std",
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -144,7 +132,14 @@ class ACTConfig:
|
||||
dropout: float = 0.1
|
||||
kl_weight: float = 10.0
|
||||
|
||||
# Training preset
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_weight_decay: float = 1e-4
|
||||
optimizer_lr_backbone: float = 1e-5
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
@@ -164,10 +159,28 @@ class ACTConfig:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if (
|
||||
not any(k.startswith("observation.image") for k in self.input_shapes)
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features and not self.env_state_feature:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -29,32 +29,27 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class ACTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "act"],
|
||||
):
|
||||
class ACTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
"""
|
||||
|
||||
config_class = ACTConfig
|
||||
name = "act"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig | None = None,
|
||||
config: ACTConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -64,34 +59,46 @@ class ACTPolicy(
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = ACTConfig()
|
||||
self.config: ACTConfig = config
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
self.expected_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(
|
||||
config.temporal_ensemble_coeff, config.chunk_size
|
||||
)
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
# TODO(aliberts, rcadene): As of now, lr_backbone == lr
|
||||
# Should we remove this and just `return self.parameters()`?
|
||||
return [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": self.config.optimizer_lr_backbone,
|
||||
},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
@@ -110,12 +117,10 @@ class ACTPolicy(
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
@@ -139,22 +144,19 @@ class ACTPolicy(
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none")
|
||||
* ~batch["action_is_pad"].unsqueeze(-1)
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
@@ -164,19 +166,14 @@ class ACTPolicy(
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(
|
||||
-0.5
|
||||
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
|
||||
)
|
||||
.sum(-1)
|
||||
.mean()
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
loss = l1_loss + mean_kld * self.config.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
loss = l1_loss
|
||||
|
||||
return loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
@@ -223,9 +220,7 @@ class ACTTemporalEnsembler:
|
||||
```
|
||||
"""
|
||||
self.chunk_size = chunk_size
|
||||
self.ensemble_weights = torch.exp(
|
||||
-temporal_ensemble_coeff * torch.arange(chunk_size)
|
||||
)
|
||||
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||
self.reset()
|
||||
|
||||
@@ -241,9 +236,7 @@ class ACTTemporalEnsembler:
|
||||
time steps, and pop/return the next batch of actions in the sequence.
|
||||
"""
|
||||
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
|
||||
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
|
||||
device=actions.device
|
||||
)
|
||||
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
|
||||
if self.ensembled_actions is None:
|
||||
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||
# time step of the episode.
|
||||
@@ -251,34 +244,19 @@ class ACTTemporalEnsembler:
|
||||
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
||||
# operations later.
|
||||
self.ensembled_actions_count = torch.ones(
|
||||
(self.chunk_size, 1),
|
||||
dtype=torch.long,
|
||||
device=self.ensembled_actions.device,
|
||||
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
|
||||
)
|
||||
else:
|
||||
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||
# the online update for those entries.
|
||||
self.ensembled_actions *= self.ensemble_weights_cumsum[
|
||||
self.ensembled_actions_count - 1
|
||||
]
|
||||
self.ensembled_actions += (
|
||||
actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||
)
|
||||
self.ensembled_actions /= self.ensemble_weights_cumsum[
|
||||
self.ensembled_actions_count
|
||||
]
|
||||
self.ensembled_actions_count = torch.clamp(
|
||||
self.ensembled_actions_count + 1, max=self.chunk_size
|
||||
)
|
||||
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
|
||||
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
|
||||
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
|
||||
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
||||
self.ensembled_actions = torch.cat(
|
||||
[self.ensembled_actions, actions[:, -1:]], dim=1
|
||||
)
|
||||
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
||||
self.ensembled_actions_count = torch.cat(
|
||||
[
|
||||
self.ensembled_actions_count,
|
||||
torch.ones_like(self.ensembled_actions_count[-1:]),
|
||||
]
|
||||
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
|
||||
)
|
||||
# "Consume" the first action.
|
||||
action, self.ensembled_actions, self.ensembled_actions_count = (
|
||||
@@ -325,60 +303,47 @@ class ACT(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
self.use_robot_state = "observation.state" in config.input_shapes
|
||||
self.use_images = any(
|
||||
k.startswith("observation.image") for k in config.input_shapes
|
||||
)
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if self.config.use_vae:
|
||||
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
self.config.robot_state_feature.shape[0], config.dim_model
|
||||
)
|
||||
# Projection layer for action (joint-space target) to hidden dimension.
|
||||
self.vae_encoder_action_input_proj = nn.Linear(
|
||||
config.output_shapes["action"][0], config.dim_model
|
||||
self.config.action_feature.shape[0],
|
||||
config.dim_model,
|
||||
)
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(
|
||||
config.dim_model, config.latent_dim * 2
|
||||
)
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
||||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
num_input_token_encoder = 1 + config.chunk_size
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
num_input_token_encoder += 1
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
create_sinusoidal_pos_embedding(
|
||||
num_input_token_encoder, config.dim_model
|
||||
).unsqueeze(0),
|
||||
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[
|
||||
False,
|
||||
False,
|
||||
config.replace_final_stride_with_dilation,
|
||||
],
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
||||
# feature map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(
|
||||
backbone_model, return_layers={"layer4": "feature_map"}
|
||||
)
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
|
||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||
self.encoder = ACTEncoder(config)
|
||||
@@ -386,40 +351,35 @@ class ACT(nn.Module):
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
self.config.robot_state_feature.shape[0], config.dim_model
|
||||
)
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
self.encoder_env_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0],
|
||||
config.dim_model,
|
||||
self.config.env_state_feature.shape[0], config.dim_model
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
n_1d_tokens = 1 # for the latent
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
n_1d_tokens += 1
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
n_1d_tokens += 1
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.use_images:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
|
||||
config.dim_model // 2
|
||||
)
|
||||
if self.config.image_features:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(
|
||||
config.dim_model, config.output_shapes["action"][0]
|
||||
)
|
||||
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
@@ -429,20 +389,18 @@ class ACT(nn.Module):
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(
|
||||
self, batch: dict[str, Tensor]
|
||||
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
||||
|
||||
`batch` should have the following structure:
|
||||
{
|
||||
"observation.state" (optional): (B, state_dim) batch of robot states.
|
||||
[robot_state_feature] (optional): (B, state_dim) batch of robot states.
|
||||
|
||||
"observation.images": (B, n_cameras, C, H, W) batch of images.
|
||||
[image_features]: (B, n_cameras, C, H, W) batch of images.
|
||||
AND/OR
|
||||
"observation.environment_state": (B, env_dim) batch of environment states.
|
||||
[env_state_feature]: (B, env_dim) batch of environment states.
|
||||
|
||||
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
[action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
}
|
||||
|
||||
Returns:
|
||||
@@ -451,9 +409,9 @@ class ACT(nn.Module):
|
||||
latent dimension.
|
||||
"""
|
||||
if self.config.use_vae and self.training:
|
||||
assert (
|
||||
"action" in batch
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
assert "action" in batch, (
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
|
||||
batch_size = (
|
||||
batch["observation.images"]
|
||||
@@ -467,21 +425,13 @@ class ACT(nn.Module):
|
||||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.use_robot_state:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(
|
||||
batch["observation.state"]
|
||||
)
|
||||
if self.config.robot_state_feature:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(
|
||||
batch["action"]
|
||||
) # (B, S, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
|
||||
if self.use_robot_state:
|
||||
vae_encoder_input = [
|
||||
cls_embed,
|
||||
robot_state_embed,
|
||||
action_embed,
|
||||
] # (B, S+2, D)
|
||||
if self.config.robot_state_feature:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
||||
@@ -494,7 +444,7 @@ class ACT(nn.Module):
|
||||
# sequence depending whether we use the input states or not (cls and robot state)
|
||||
# False means not a padding token.
|
||||
cls_joint_is_pad = torch.full(
|
||||
(batch_size, 2 if self.use_robot_state else 1),
|
||||
(batch_size, 2 if self.config.robot_state_feature else 1),
|
||||
False,
|
||||
device=batch["observation.state"].device,
|
||||
)
|
||||
@@ -519,57 +469,41 @@ class ACT(nn.Module):
|
||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||
mu = log_sigma_x2 = None
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
latent_sample = torch.zeros(
|
||||
[batch_size, self.config.latent_dim], dtype=torch.float32
|
||||
).to(batch["observation.state"].device)
|
||||
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
||||
batch["observation.state"].device
|
||||
)
|
||||
|
||||
# Prepare transformer encoder inputs.
|
||||
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
|
||||
encoder_in_pos_embed = list(
|
||||
self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
|
||||
)
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
# Robot state token.
|
||||
if self.use_robot_state:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_robot_state_input_proj(batch["observation.state"])
|
||||
)
|
||||
if self.config.robot_state_feature:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
# Environment state token.
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(
|
||||
batch["observation.environment_state"]
|
||||
)
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
)
|
||||
|
||||
# Camera observation features and positional embeddings.
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
|
||||
for cam_index in range(batch["observation.images"].shape[-4]):
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])[
|
||||
"feature_map"
|
||||
]
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
|
||||
# buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
|
||||
dtype=cam_features.dtype
|
||||
)
|
||||
cam_features = self.encoder_img_feat_input_proj(
|
||||
cam_features
|
||||
) # (B, C, h, w)
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
|
||||
# and move to (sequence, batch, dim).
|
||||
all_cam_features = torch.cat(all_cam_features, axis=-1)
|
||||
encoder_in_tokens.extend(
|
||||
einops.rearrange(all_cam_features, "b c h w -> (h w) b c")
|
||||
)
|
||||
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
|
||||
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
encoder_in_pos_embed.extend(
|
||||
einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c")
|
||||
)
|
||||
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
|
||||
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
@@ -604,21 +538,12 @@ class ACTEncoder(nn.Module):
|
||||
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
|
||||
super().__init__()
|
||||
self.is_vae_encoder = is_vae_encoder
|
||||
num_layers = (
|
||||
config.n_vae_encoder_layers
|
||||
if self.is_vae_encoder
|
||||
else config.n_encoder_layers
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[ACTEncoderLayer(config) for _ in range(num_layers)]
|
||||
)
|
||||
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
|
||||
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
|
||||
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
pos_embed: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
|
||||
@@ -629,9 +554,7 @@ class ACTEncoder(nn.Module):
|
||||
class ACTEncoderLayer(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
||||
@@ -646,9 +569,7 @@ class ACTEncoderLayer(nn.Module):
|
||||
self.activation = get_activation_fn(config.feedforward_activation)
|
||||
self.pre_norm = config.pre_norm
|
||||
|
||||
def forward(
|
||||
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||
) -> Tensor:
|
||||
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
|
||||
skip = x
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
@@ -673,9 +594,7 @@ class ACTDecoder(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
"""Convenience module for running multiple decoder layers followed by normalization."""
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
|
||||
)
|
||||
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
|
||||
self.norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
def forward(
|
||||
@@ -687,10 +606,7 @@ class ACTDecoder(nn.Module):
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(
|
||||
x,
|
||||
encoder_out,
|
||||
decoder_pos_embed=decoder_pos_embed,
|
||||
encoder_pos_embed=encoder_pos_embed,
|
||||
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||
)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
@@ -700,12 +616,8 @@ class ACTDecoder(nn.Module):
|
||||
class ACTDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
self.multihead_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
||||
@@ -746,9 +658,7 @@ class ACTDecoderLayer(nn.Module):
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
||||
x = self.self_attn(q, k, value=x)[
|
||||
0
|
||||
] # select just the output, not the attention weights
|
||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||
x = skip + self.dropout1(x)
|
||||
if self.pre_norm:
|
||||
skip = x
|
||||
@@ -785,14 +695,9 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
|
||||
"""
|
||||
|
||||
def get_position_angle_vec(position):
|
||||
return [
|
||||
position / np.power(10000, 2 * (hid_j // 2) / dimension)
|
||||
for hid_j in range(dimension)
|
||||
]
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
|
||||
|
||||
sinusoid_table = np.array(
|
||||
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
|
||||
)
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
return torch.from_numpy(sinusoid_table).float()
|
||||
@@ -837,9 +742,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
||||
|
||||
inverse_frequency = self._temperature ** (
|
||||
2
|
||||
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
|
||||
/ self.dimension
|
||||
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
|
||||
)
|
||||
|
||||
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||
@@ -847,15 +750,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
|
||||
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
|
||||
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
|
||||
pos_embed_x = torch.stack(
|
||||
(x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
|
||||
).flatten(3)
|
||||
pos_embed_y = torch.stack(
|
||||
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
|
||||
).flatten(3)
|
||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
|
||||
0, 3, 1, 2
|
||||
) # (1, C, H, W)
|
||||
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
|
||||
|
||||
return pos_embed
|
||||
|
||||
|
||||
@@ -16,9 +16,15 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("diffusion")
|
||||
@dataclass
|
||||
class DiffusionConfig:
|
||||
class DiffusionConfig(PreTrainedConfig):
|
||||
"""Configuration class for DiffusionPolicy.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
@@ -62,7 +68,7 @@ class DiffusionConfig:
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
@@ -93,7 +99,7 @@ class DiffusionConfig:
|
||||
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
||||
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
||||
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
|
||||
`LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
|
||||
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
|
||||
to False as the original Diffusion Policy implementation does the same.
|
||||
"""
|
||||
|
||||
@@ -102,28 +108,17 @@ class DiffusionConfig:
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"}
|
||||
)
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -156,44 +151,23 @@ class DiffusionConfig:
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-6
|
||||
scheduler_name: str = "cosine"
|
||||
scheduler_warmup_steps: int = 500
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
|
||||
if (
|
||||
len(image_keys) == 0
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
|
||||
if len(image_keys) > 0:
|
||||
if self.crop_shape is not None:
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
raise ValueError(
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
)
|
||||
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
raise ValueError(
|
||||
@@ -214,3 +188,50 @@ class DiffusionConfig:
|
||||
"The horizon should be an integer multiple of the downsampling factor (which is determined "
|
||||
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||
return DiffuserSchedulerConfig(
|
||||
name=self.scheduler_name,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -31,35 +31,32 @@ import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
|
||||
|
||||
class DiffusionPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "diffusion-policy"],
|
||||
):
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
"""
|
||||
|
||||
config_class = DiffusionConfig
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig | None = None,
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -69,18 +66,16 @@ class DiffusionPolicy(
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = DiffusionConfig()
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
@@ -88,25 +83,21 @@ class DiffusionPolicy(
|
||||
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
self.expected_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.diffusion.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(
|
||||
maxlen=self.config.n_obs_steps
|
||||
)
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -131,23 +122,17 @@ class DiffusionPolicy(
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {
|
||||
k: torch.stack(list(self._queues[k]), dim=1)
|
||||
for k in batch
|
||||
if k in self._queues
|
||||
}
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
@@ -158,19 +143,18 @@ class DiffusionPolicy(
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
@@ -192,14 +176,9 @@ class DiffusionModel(nn.Module):
|
||||
self.config = config
|
||||
|
||||
# Build observation encoders (depending on which observations are provided).
|
||||
global_cond_dim = config.input_shapes["observation.state"][0]
|
||||
num_images = len(
|
||||
[k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
)
|
||||
self._use_images = False
|
||||
self._use_env_state = False
|
||||
if num_images > 0:
|
||||
self._use_images = True
|
||||
global_cond_dim = self.config.robot_state_feature.shape[0]
|
||||
if self.config.image_features:
|
||||
num_images = len(self.config.image_features)
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
|
||||
self.rgb_encoder = nn.ModuleList(encoders)
|
||||
@@ -207,13 +186,10 @@ class DiffusionModel(nn.Module):
|
||||
else:
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||
if self.config.env_state_feature:
|
||||
global_cond_dim += self.config.env_state_feature.shape[0]
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(
|
||||
config, global_cond_dim=global_cond_dim * config.n_obs_steps
|
||||
)
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
@@ -233,21 +209,14 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self,
|
||||
batch_size: int,
|
||||
global_cond: Tensor | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
||||
) -> Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
dtype = get_dtype_from_parameters(self)
|
||||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(
|
||||
batch_size,
|
||||
self.config.horizon,
|
||||
self.config.output_shapes["action"][0],
|
||||
),
|
||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
@@ -263,58 +232,44 @@ class DiffusionModel(nn.Module):
|
||||
global_cond=global_cond,
|
||||
)
|
||||
# Compute previous image: x_t -> x_t-1
|
||||
sample = self.noise_scheduler.step(
|
||||
model_output, t, sample, generator=generator
|
||||
).prev_sample
|
||||
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
||||
|
||||
return sample
|
||||
|
||||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
global_cond_feats = [batch["observation.state"]]
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
global_cond_feats = [batch[OBS_STATE]]
|
||||
# Extract image features.
|
||||
if self._use_images:
|
||||
if self.config.image_features:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(
|
||||
batch["observation.images"], "b s n ... -> n (b s) ..."
|
||||
)
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
img_features_list = torch.cat(
|
||||
[
|
||||
encoder(images)
|
||||
for encoder, images in zip(
|
||||
self.rgb_encoder, images_per_camera, strict=True
|
||||
)
|
||||
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
|
||||
]
|
||||
)
|
||||
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features_list,
|
||||
"(n b s) ... -> b s (n ...)",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(
|
||||
batch["observation.images"], "b s n ... -> (b s n) ..."
|
||||
)
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features,
|
||||
"(b s n) ... -> b s (n ...)",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
if self._use_env_state:
|
||||
global_cond_feats.append(batch["observation.environment_state"])
|
||||
if self.config.env_state_feature:
|
||||
global_cond_feats.append(batch[OBS_ENV_STATE])
|
||||
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
@@ -395,9 +350,7 @@ class DiffusionModel(nn.Module):
|
||||
elif self.config.prediction_type == "sample":
|
||||
target = batch["action"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported prediction type {self.config.prediction_type}"
|
||||
)
|
||||
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
|
||||
@@ -457,9 +410,7 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(
|
||||
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
|
||||
)
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
@@ -488,7 +439,7 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
|
||||
class DiffusionRgbEncoder(nn.Module):
|
||||
"""Encoder an RGB image into a 1D feature vector.
|
||||
"""Encodes an RGB image into a 1D feature vector.
|
||||
|
||||
Includes the ability to normalize and crop the image first.
|
||||
"""
|
||||
@@ -501,9 +452,7 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(
|
||||
config.crop_shape
|
||||
)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -524,35 +473,22 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||
),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape
|
||||
if config.crop_shape is not None
|
||||
else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(
|
||||
size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)
|
||||
)
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -579,9 +515,7 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module,
|
||||
predicate: Callable[[nn.Module], bool],
|
||||
func: Callable[[nn.Module], nn.Module],
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
@@ -594,11 +528,7 @@ def _replace_submodules(
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [
|
||||
k.split(".")
|
||||
for k, m in root_module.named_modules(remove_duplicate=True)
|
||||
if predicate(m)
|
||||
]
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
@@ -613,9 +543,7 @@ def _replace_submodules(
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(
|
||||
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
|
||||
)
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
return root_module
|
||||
|
||||
|
||||
@@ -643,9 +571,7 @@ class DiffusionConv1dBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(
|
||||
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
|
||||
),
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
nn.Mish(),
|
||||
)
|
||||
@@ -668,13 +594,9 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
# Encoder for the diffusion timestep.
|
||||
self.diffusion_step_encoder = nn.Sequential(
|
||||
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
|
||||
nn.Linear(
|
||||
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
|
||||
),
|
||||
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(
|
||||
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
|
||||
),
|
||||
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
|
||||
)
|
||||
|
||||
# The FiLM conditioning dimension.
|
||||
@@ -682,7 +604,7 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
|
||||
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||
# just reverse these.
|
||||
in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
|
||||
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
|
||||
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
|
||||
)
|
||||
|
||||
@@ -699,16 +621,10 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
self.down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_in, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_out, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Downsample as long as it is not the last block.
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
|
||||
if not is_last
|
||||
else nn.Identity(),
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -717,14 +633,10 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1],
|
||||
config.down_dims[-1],
|
||||
**common_res_block_kwargs,
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1],
|
||||
config.down_dims[-1],
|
||||
**common_res_block_kwargs,
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
]
|
||||
)
|
||||
@@ -737,25 +649,17 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
nn.ModuleList(
|
||||
[
|
||||
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_in * 2, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_out, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Upsample as long as it is not the last block.
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
|
||||
if not is_last
|
||||
else nn.Identity(),
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
DiffusionConv1dBlock(
|
||||
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
|
||||
),
|
||||
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
||||
@@ -822,23 +726,17 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
|
||||
self.use_film_scale_modulation = use_film_scale_modulation
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.conv1 = DiffusionConv1dBlock(
|
||||
in_channels, out_channels, kernel_size, n_groups=n_groups
|
||||
)
|
||||
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||
|
||||
self.conv2 = DiffusionConv1dBlock(
|
||||
out_channels, out_channels, kernel_size, n_groups=n_groups
|
||||
)
|
||||
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# A final convolution for dimension matching the residual (if needed).
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
|
||||
|
||||
@@ -13,114 +13,138 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
|
||||
import logging
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from torch import nn
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
|
||||
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
||||
if not set(hydra_cfg.policy).issuperset(expected_kwargs):
|
||||
logging.warning(
|
||||
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||
)
|
||||
|
||||
# OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid
|
||||
# issues with mutable defaults. This filter changes all lists to tuples.
|
||||
def list_to_tuple(item):
|
||||
return tuple(item) if isinstance(item, list) else item
|
||||
|
||||
policy_cfg = policy_cfg_class(
|
||||
**{
|
||||
k: list_to_tuple(v)
|
||||
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
|
||||
if k in expected_kwargs
|
||||
}
|
||||
)
|
||||
return policy_cfg
|
||||
|
||||
|
||||
def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
return TDMPCPolicy, TDMPCConfig
|
||||
return TDMPCPolicy
|
||||
elif name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import (
|
||||
DiffusionConfig,
|
||||
)
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
return DiffusionPolicy, DiffusionConfig
|
||||
return DiffusionPolicy
|
||||
elif name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
return ACTPolicy, ACTConfig
|
||||
return ACTPolicy
|
||||
elif name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy, VQBeTConfig
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
return VQBeTPolicy
|
||||
elif name == "pi0":
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return SACPolicy, SACConfig
|
||||
return PI0Policy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
if policy_type == "tdmpc":
|
||||
return TDMPCConfig(**kwargs)
|
||||
elif policy_type == "diffusion":
|
||||
return DiffusionConfig(**kwargs)
|
||||
elif policy_type == "act":
|
||||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
def make_policy(
|
||||
hydra_cfg: DictConfig,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
dataset_stats=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Policy:
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
"""Make an instance of a policy class.
|
||||
|
||||
This function exists because (for now) we need to parse features from either a dataset or an environment
|
||||
in order to properly dimension and instantiate a policy for that dataset or environment.
|
||||
|
||||
Args:
|
||||
hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is
|
||||
provided, only `hydra_cfg.policy.name` is used while everything else is ignored.
|
||||
pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
|
||||
directory containing weights saved using `Policy.save_pretrained`. Note that providing this
|
||||
argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`.
|
||||
dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must
|
||||
be provided when initializing a new policy, and must not be provided when loading a pretrained
|
||||
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
|
||||
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
||||
be loaded with the weights from that path.
|
||||
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
||||
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
||||
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
||||
provided if ds_meta is not. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: Either ds_meta or env and env_cfg must be provided.
|
||||
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
|
||||
|
||||
Returns:
|
||||
PreTrainedPolicy: _description_
|
||||
"""
|
||||
# if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||
# raise ValueError(
|
||||
# "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
|
||||
# )
|
||||
if bool(ds_meta) == bool(env_cfg):
|
||||
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
|
||||
|
||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
if pretrained_policy_name_or_path is None:
|
||||
# Make a fresh policy.
|
||||
# HACK: We pass *args and **kwargs to the policy constructor to allow for additional arguments
|
||||
# for example device for the sac policy.
|
||||
policy = policy_cls(config=policy_cfg, dataset_stats=dataset_stats)
|
||||
else:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
|
||||
# pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
|
||||
# huggingface_hub should make it possible to avoid the hack:
|
||||
# https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.load_state_dict(
|
||||
policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()
|
||||
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
|
||||
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
|
||||
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
|
||||
# you want this op to be added in priority during the prototype phase of this feature, please comment on
|
||||
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
||||
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
||||
# slower than running natively on MPS.
|
||||
if cfg.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
policy_cls = get_policy_class(cfg.type)
|
||||
|
||||
kwargs = {}
|
||||
if ds_meta is not None:
|
||||
features = dataset_to_policy_features(ds_meta.features)
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
else:
|
||||
if not cfg.pretrained_path:
|
||||
logging.warning(
|
||||
"You are instantiating a policy from scratch and its features are parsed from an environment "
|
||||
"rather than a dataset. Normalization modules inside the policy will have infinite values "
|
||||
"by default without stats from a dataset."
|
||||
)
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(**kwargs)
|
||||
|
||||
policy.to(cfg.device)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
return policy
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user