Compare commits
152 Commits
Cadene-pat
...
fix/lint_w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e511e7eda5 | ||
|
|
32fffd4bbb | ||
|
|
03c7cf8a63 | ||
|
|
f5ed3723f0 | ||
|
|
b104be0d04 | ||
|
|
f9e4a1f5c4 | ||
|
|
0eb56cec14 | ||
|
|
e59ef036e1 | ||
|
|
9b380eaf67 | ||
|
|
1187604ba0 | ||
|
|
5c6f2d2cd0 | ||
|
|
652fedf69c | ||
|
|
85214ec303 | ||
|
|
dffa5a18db | ||
|
|
301f152a34 | ||
|
|
0ed08c0b1f | ||
|
|
254bc707e7 | ||
|
|
074f0ac8fe | ||
|
|
25c63ccf63 | ||
|
|
5e9473806c | ||
|
|
10706ed753 | ||
|
|
0b8205a8a0 | ||
|
|
57ae509823 | ||
|
|
5d24ce3160 | ||
|
|
d694ea1d38 | ||
|
|
a00936686f | ||
|
|
2feb5edc65 | ||
|
|
b80e55ca44 | ||
|
|
e8ce388109 | ||
|
|
a4c1da25de | ||
|
|
a003e7c081 | ||
|
|
a27411022d | ||
|
|
3827974b58 | ||
|
|
b299cfea8a | ||
|
|
bf6f89a5b5 | ||
|
|
8861546ad8 | ||
|
|
9c1a893ee3 | ||
|
|
e81c36cf74 | ||
|
|
ed83cbd4f2 | ||
|
|
2a33b9ad87 | ||
|
|
6e85aa13ec | ||
|
|
af05a1725c | ||
|
|
800c4a847f | ||
|
|
bba8c4c0d4 | ||
|
|
68b369e321 | ||
|
|
8d60ac3ffc | ||
|
|
659ec4434d | ||
|
|
da265ca920 | ||
|
|
a1809ad3de | ||
|
|
8699a28be0 | ||
|
|
65db5afe1c | ||
|
|
75d5fa4604 | ||
|
|
e64fad2224 | ||
|
|
eecf32e77a | ||
|
|
3354d919fc | ||
|
|
aca464ca72 | ||
|
|
fe483b1d0d | ||
|
|
ddeade077e | ||
|
|
c4c2ce04e7 | ||
|
|
2cb0bf5d41 | ||
|
|
b86a2c0b47 | ||
|
|
c574eb4984 | ||
|
|
1e49cc4d60 | ||
|
|
e71095960f | ||
|
|
90e099b39f | ||
|
|
334deb985d | ||
|
|
8548a87bd4 | ||
|
|
638d411cd3 | ||
|
|
dd974529cf | ||
|
|
43e079f73e | ||
|
|
6674e36824 | ||
|
|
ae9605f03c | ||
|
|
3c0a209f9f | ||
|
|
1ee1acf8ad | ||
|
|
c4d912a241 | ||
|
|
4323bdce22 | ||
|
|
5daa45436d | ||
|
|
4def6d6ac2 | ||
|
|
d8560b8d5f | ||
|
|
380b836eee | ||
|
|
eec6796cb8 | ||
|
|
25a8597680 | ||
|
|
b8b368310c | ||
|
|
5097cd900e | ||
|
|
bc16e1b497 | ||
|
|
8f821ecad0 | ||
|
|
4519016e67 | ||
|
|
59e2757434 | ||
|
|
73b64c3089 | ||
|
|
66f8736598 | ||
|
|
4c41f6fcc6 | ||
|
|
44f9b21e74 | ||
|
|
03f49ceaf0 | ||
|
|
8e7d6970ea | ||
|
|
286bca37cc | ||
|
|
a2c181992a | ||
|
|
32eb0cec8f | ||
|
|
96c7052777 | ||
|
|
975c1c25c3 | ||
|
|
20f466768e | ||
|
|
8af693548e | ||
|
|
963738d983 | ||
|
|
e0df56de62 | ||
|
|
538455a965 | ||
|
|
172809a502 | ||
|
|
55e4ff6742 | ||
|
|
07e8716315 | ||
|
|
114870d703 | ||
|
|
2efee45ef1 | ||
|
|
c351e1fff9 | ||
|
|
cd0fc261c0 | ||
|
|
77478d50e5 | ||
|
|
97b1feb0b3 | ||
|
|
c29e70e5a1 | ||
|
|
d5b669634a | ||
|
|
1a343c3591 | ||
|
|
26f97cfd17 | ||
|
|
72f402d44b | ||
|
|
92573486a8 | ||
|
|
c712d68f6a | ||
|
|
f431a08efa | ||
|
|
beaa427504 | ||
|
|
a88dd602d9 | ||
|
|
6c0324f467 | ||
|
|
a60d27b132 | ||
|
|
9c463661c1 | ||
|
|
4255655618 | ||
|
|
f17d9a2ba1 | ||
|
|
9ff829a3a1 | ||
|
|
d6516f0e03 | ||
|
|
b0b8612eff | ||
|
|
1072a055db | ||
|
|
9c9f5cac90 | ||
|
|
9d0c6fe419 | ||
|
|
54ac25cfc9 | ||
|
|
150a292795 | ||
|
|
429a463aff | ||
|
|
27ba2951d1 | ||
|
|
b2896d38f5 | ||
|
|
c0da806232 | ||
|
|
114e09f570 | ||
|
|
04a995e7d1 | ||
|
|
4806336816 | ||
|
|
1ce418e4a1 | ||
|
|
eb4c505cff | ||
|
|
aad59e6b6b | ||
|
|
9ce98bb93c | ||
|
|
97086cdcdf | ||
|
|
9c7649f140 | ||
|
|
a2592a5563 | ||
|
|
b5ad79a7d3 | ||
|
|
996468bcce |
68
.cache/calibration/aloha_default/left_follower.json
Normal file
68
.cache/calibration/aloha_default/left_follower.json
Normal file
@@ -0,0 +1,68 @@
|
||||
{
|
||||
"homing_offset": [
|
||||
2048,
|
||||
3072,
|
||||
3072,
|
||||
-1024,
|
||||
-1024,
|
||||
2048,
|
||||
-2048,
|
||||
2048,
|
||||
-2048
|
||||
],
|
||||
"drive_mode": [
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0
|
||||
],
|
||||
"start_pos": [
|
||||
2015,
|
||||
3058,
|
||||
3061,
|
||||
1071,
|
||||
1071,
|
||||
2035,
|
||||
2152,
|
||||
2029,
|
||||
2499
|
||||
],
|
||||
"end_pos": [
|
||||
-1008,
|
||||
-1963,
|
||||
-1966,
|
||||
2141,
|
||||
2143,
|
||||
-971,
|
||||
3043,
|
||||
-1077,
|
||||
3144
|
||||
],
|
||||
"calib_mode": [
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"LINEAR"
|
||||
],
|
||||
"motor_names": [
|
||||
"waist",
|
||||
"shoulder",
|
||||
"shoulder_shadow",
|
||||
"elbow",
|
||||
"elbow_shadow",
|
||||
"forearm_roll",
|
||||
"wrist_angle",
|
||||
"wrist_rotate",
|
||||
"gripper"
|
||||
]
|
||||
}
|
||||
68
.cache/calibration/aloha_default/left_leader.json
Normal file
68
.cache/calibration/aloha_default/left_leader.json
Normal file
@@ -0,0 +1,68 @@
|
||||
{
|
||||
"homing_offset": [
|
||||
2048,
|
||||
3072,
|
||||
3072,
|
||||
-1024,
|
||||
-1024,
|
||||
2048,
|
||||
-2048,
|
||||
2048,
|
||||
-1024
|
||||
],
|
||||
"drive_mode": [
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0
|
||||
],
|
||||
"start_pos": [
|
||||
2035,
|
||||
3024,
|
||||
3019,
|
||||
979,
|
||||
981,
|
||||
1982,
|
||||
2166,
|
||||
2124,
|
||||
1968
|
||||
],
|
||||
"end_pos": [
|
||||
-990,
|
||||
-2017,
|
||||
-2015,
|
||||
2078,
|
||||
2076,
|
||||
-1030,
|
||||
3117,
|
||||
-1016,
|
||||
2556
|
||||
],
|
||||
"calib_mode": [
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"LINEAR"
|
||||
],
|
||||
"motor_names": [
|
||||
"waist",
|
||||
"shoulder",
|
||||
"shoulder_shadow",
|
||||
"elbow",
|
||||
"elbow_shadow",
|
||||
"forearm_roll",
|
||||
"wrist_angle",
|
||||
"wrist_rotate",
|
||||
"gripper"
|
||||
]
|
||||
}
|
||||
68
.cache/calibration/aloha_default/right_follower.json
Normal file
68
.cache/calibration/aloha_default/right_follower.json
Normal file
@@ -0,0 +1,68 @@
|
||||
{
|
||||
"homing_offset": [
|
||||
2048,
|
||||
3072,
|
||||
3072,
|
||||
-1024,
|
||||
-1024,
|
||||
2048,
|
||||
-2048,
|
||||
2048,
|
||||
-2048
|
||||
],
|
||||
"drive_mode": [
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0
|
||||
],
|
||||
"start_pos": [
|
||||
2056,
|
||||
2895,
|
||||
2896,
|
||||
1191,
|
||||
1190,
|
||||
2018,
|
||||
2051,
|
||||
2056,
|
||||
2509
|
||||
],
|
||||
"end_pos": [
|
||||
-1040,
|
||||
-2004,
|
||||
-2006,
|
||||
2126,
|
||||
2127,
|
||||
-1010,
|
||||
3050,
|
||||
-1117,
|
||||
3143
|
||||
],
|
||||
"calib_mode": [
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"LINEAR"
|
||||
],
|
||||
"motor_names": [
|
||||
"waist",
|
||||
"shoulder",
|
||||
"shoulder_shadow",
|
||||
"elbow",
|
||||
"elbow_shadow",
|
||||
"forearm_roll",
|
||||
"wrist_angle",
|
||||
"wrist_rotate",
|
||||
"gripper"
|
||||
]
|
||||
}
|
||||
68
.cache/calibration/aloha_default/right_leader.json
Normal file
68
.cache/calibration/aloha_default/right_leader.json
Normal file
@@ -0,0 +1,68 @@
|
||||
{
|
||||
"homing_offset": [
|
||||
2048,
|
||||
3072,
|
||||
3072,
|
||||
-1024,
|
||||
-1024,
|
||||
2048,
|
||||
-2048,
|
||||
2048,
|
||||
-2048
|
||||
],
|
||||
"drive_mode": [
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0
|
||||
],
|
||||
"start_pos": [
|
||||
2068,
|
||||
3034,
|
||||
3030,
|
||||
1038,
|
||||
1041,
|
||||
1991,
|
||||
1948,
|
||||
2090,
|
||||
1985
|
||||
],
|
||||
"end_pos": [
|
||||
-1025,
|
||||
-2014,
|
||||
-2015,
|
||||
2058,
|
||||
2060,
|
||||
-955,
|
||||
3091,
|
||||
-940,
|
||||
2576
|
||||
],
|
||||
"calib_mode": [
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"LINEAR"
|
||||
],
|
||||
"motor_names": [
|
||||
"waist",
|
||||
"shoulder",
|
||||
"shoulder_shadow",
|
||||
"elbow",
|
||||
"elbow_shadow",
|
||||
"forearm_roll",
|
||||
"wrist_angle",
|
||||
"wrist_rotate",
|
||||
"gripper"
|
||||
]
|
||||
}
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Misc
|
||||
.git
|
||||
tmp
|
||||
@@ -65,7 +79,6 @@ htmlcov/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
@@ -73,6 +86,11 @@ coverage.xml
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Ignore .cache except calibration
|
||||
.cache/*
|
||||
!.cache/calibration/
|
||||
!.cache/calibration/**
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
16
.gitattributes
vendored
16
.gitattributes
vendored
@@ -1,6 +1,20 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
*.memmap filter=lfs diff=lfs merge=lfs -text
|
||||
*.stl filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.json filter=lfs diff=lfs merge=lfs -text
|
||||
*.json !text !filter !merge !diff
|
||||
|
||||
14
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
14
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "\U0001F41B Bug Report"
|
||||
description: Submit a bug report to help us improve LeRobot
|
||||
body:
|
||||
|
||||
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -21,7 +21,7 @@ Provide a simple way for the reviewer to try out your changes.
|
||||
|
||||
Examples:
|
||||
```bash
|
||||
DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
|
||||
pytest -sx tests/test_stuff.py::test_something
|
||||
```
|
||||
```bash
|
||||
python lerobot/scripts/train.py --some.option=true
|
||||
|
||||
26
.github/workflows/build-docker-images.yml
vendored
26
.github/workflows/build-docker-images.yml
vendored
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
|
||||
name: Builds
|
||||
@@ -8,6 +22,8 @@ on:
|
||||
schedule:
|
||||
- cron: "0 1 * * *"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -25,11 +41,14 @@ jobs:
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
@@ -60,11 +79,14 @@ jobs:
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
@@ -89,9 +111,13 @@ jobs:
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
24
.github/workflows/nightly-tests.yml
vendored
24
.github/workflows/nightly-tests.yml
vendored
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
|
||||
name: Nightly
|
||||
@@ -7,10 +21,10 @@ on:
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
|
||||
permissions: {}
|
||||
|
||||
# env:
|
||||
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
|
||||
jobs:
|
||||
run_all_tests_cpu:
|
||||
name: CPU
|
||||
@@ -30,13 +44,9 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Tests
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
run: pytest -v --cov=./lerobot --disable-warnings tests
|
||||
|
||||
- name: Tests end-to-end
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
run: make test-end-to-end
|
||||
|
||||
|
||||
|
||||
74
.github/workflows/quality.yml
vendored
74
.github/workflows/quality.yml
vendored
@@ -1,15 +1,29 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: Quality
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -19,7 +33,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
@@ -30,55 +46,27 @@ jobs:
|
||||
id: get-ruff-version
|
||||
run: |
|
||||
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
|
||||
echo "RUFF_VERSION=${RUFF_VERSION}" >> $GITHUB_ENV
|
||||
echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Install Ruff
|
||||
run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
|
||||
env:
|
||||
RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
|
||||
run: python -m pip install "ruff==${RUFF_VERSION}"
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check
|
||||
run: ruff check --output-format=github
|
||||
|
||||
- name: Ruff format
|
||||
run: ruff format --diff
|
||||
|
||||
|
||||
poetry_check:
|
||||
name: Poetry check
|
||||
typos:
|
||||
name: Typos
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install poetry
|
||||
|
||||
- name: Poetry check
|
||||
run: poetry check
|
||||
|
||||
|
||||
poetry_relax:
|
||||
name: Poetry relax
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install poetry
|
||||
|
||||
- name: Install poetry-relax
|
||||
run: poetry self add poetry-relax
|
||||
|
||||
- name: Poetry relax
|
||||
id: poetry_relax
|
||||
run: |
|
||||
output=$(poetry relax --check 2>&1)
|
||||
if echo "$output" | grep -q "Proposing updates"; then
|
||||
echo "$output"
|
||||
echo ""
|
||||
echo "Some dependencies have caret '^' version requirement added by poetry by default."
|
||||
echo "Please replace them with '>='. You can do this by hand or use poetry-relax to do this."
|
||||
exit 1
|
||||
else
|
||||
echo "$output"
|
||||
fi
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.29.10
|
||||
|
||||
31
.github/workflows/test-docker-build.yml
vendored
31
.github/workflows/test-docker-build.yml
vendored
@@ -1,15 +1,29 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
|
||||
name: Test Dockerfiles
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
# Run only when DockerFile files are modified
|
||||
- "docker/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -22,6 +36,8 @@ jobs:
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
@@ -30,21 +46,18 @@ jobs:
|
||||
files: docker/**
|
||||
json: "true"
|
||||
|
||||
- name: Run step if only the files listed above change
|
||||
- name: Run step if only the files listed above change # zizmor: ignore[template-injection]
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: set-matrix
|
||||
env:
|
||||
ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
|
||||
run: |
|
||||
echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
|
||||
|
||||
|
||||
build_modified_dockerfiles:
|
||||
name: Build modified Docker images
|
||||
needs: get_changed_files
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
|
||||
if: needs.get_changed_files.outputs.matrix != ''
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -52,9 +65,13 @@ jobs:
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
|
||||
101
.github/workflows/test.yml
vendored
101
.github/workflows/test.yml
vendored
@@ -1,16 +1,30 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "lerobot/**"
|
||||
- "tests/**"
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
- "pyproject.toml"
|
||||
- ".pre-commit-config.yaml"
|
||||
- "Makefile"
|
||||
- ".cache/**"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
@@ -19,113 +33,116 @@ on:
|
||||
- "tests/**"
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
- "pyproject.toml"
|
||||
- ".pre-commit-config.yaml"
|
||||
- "Makefile"
|
||||
- ".cache/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
UV_VERSION: "0.6.0"
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
name: Pytest
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev ffmpeg
|
||||
|
||||
- name: Install poetry
|
||||
# portaudio19-dev is needed to install pyaudio
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
cache: "poetry"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
run: |
|
||||
poetry install --all-extras
|
||||
- name: Install lerobot (all extras)
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest tests -v --cov=./lerobot --durations=0 \
|
||||
uv run pytest tests -v --cov=./lerobot --durations=0 \
|
||||
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
&& rm -rf tests/outputs outputs
|
||||
|
||||
|
||||
pytest-minimal:
|
||||
name: Pytest (minimal install)
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ffmpeg
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
run: |
|
||||
poetry install --extras "test"
|
||||
- name: Install lerobot
|
||||
run: uv sync --extra "test"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest tests -v --cov=./lerobot --durations=0 \
|
||||
uv run pytest tests -v --cov=./lerobot --durations=0 \
|
||||
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
&& rm -rf tests/outputs outputs
|
||||
|
||||
|
||||
end-to-end:
|
||||
name: End-to-end
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
|
||||
- name: Install poetry
|
||||
# portaudio19-dev is needed to install pyaudio
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
cache: "poetry"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
- name: Install lerobot (all extras)
|
||||
run: |
|
||||
poetry install --all-extras
|
||||
uv venv
|
||||
uv sync --all-extras
|
||||
|
||||
- name: venv
|
||||
run: |
|
||||
echo "PYTHON_PATH=${{ github.workspace }}/.venv/bin/python" >> $GITHUB_ENV
|
||||
|
||||
- name: Test end-to-end
|
||||
run: |
|
||||
|
||||
19
.github/workflows/trufflehog.yml
vendored
19
.github/workflows/trufflehog.yml
vendored
@@ -1,10 +1,23 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
on:
|
||||
push:
|
||||
|
||||
name: Secret Leaks
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
@@ -14,6 +27,8 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
with:
|
||||
|
||||
24
.gitignore
vendored
24
.gitignore
vendored
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Logging
|
||||
logs
|
||||
tmp
|
||||
@@ -49,6 +63,10 @@ share/python-wheels/
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# uv/poetry lock files
|
||||
poetry.lock
|
||||
uv.lock
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
@@ -66,7 +84,6 @@ htmlcov/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
@@ -74,6 +91,11 @@ coverage.xml
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Ignore .cache except calibration
|
||||
.cache/*
|
||||
!.cache/calibration/
|
||||
!.cache/calibration/**
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
@@ -1,9 +1,24 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
exclude: ^(tests/data)
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
repos:
|
||||
##### Style / Misc. #####
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.6.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
- id: debug-statements
|
||||
@@ -13,25 +28,34 @@ repos:
|
||||
- id: check-toml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.30.0
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.16.0
|
||||
rev: v3.19.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.5.2
|
||||
rev: v0.9.9
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/python-poetry/poetry
|
||||
rev: 1.8.0
|
||||
hooks:
|
||||
- id: poetry-check
|
||||
- id: poetry-lock
|
||||
args:
|
||||
- "--check"
|
||||
- "--no-update"
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.18.4
|
||||
rev: v8.24.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.4.1
|
||||
hooks:
|
||||
- id: zizmor
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.8.3
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: ["-c", "pyproject.toml"]
|
||||
additional_dependencies: ["bandit[toml]"]
|
||||
|
||||
@@ -20,7 +20,7 @@ Some of the ways you can contribute to 🤗 LeRobot:
|
||||
* Contributing to the examples or to the documentation.
|
||||
* Submitting issues related to bugs or desired new features.
|
||||
|
||||
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
|
||||
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](mailto:remi.cadene@huggingface.co).
|
||||
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
|
||||
|
||||
@@ -129,38 +129,71 @@ Follow these steps to start contributing:
|
||||
|
||||
🚨 **Do not** work on the `main` branch.
|
||||
|
||||
4. for development, we use `poetry` instead of just `pip` to easily track our dependencies.
|
||||
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
|
||||
Set up a development environment with conda or miniconda:
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
If you're using `uv`, it can manage python versions so you can instead do:
|
||||
```bash
|
||||
poetry install --sync --extras "dev test"
|
||||
uv venv --python 3.10 && source .venv/bin/activate
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry sync --extras "dev test"
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv sync --extra dev --extra test
|
||||
```
|
||||
|
||||
You can also install the project with all its dependencies (including environments):
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry install --sync --all-extras
|
||||
poetry sync --all-extras
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
> **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they *will* be tested in the CI. In general, we advise you to install everything and test locally before pushing.
|
||||
|
||||
Whichever command you chose to install the project (e.g. `poetry install --sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
|
||||
The equivalent of `pip install some-package`, would just be:
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
using `uv`
|
||||
```bash
|
||||
poetry lock --no-update
|
||||
uv add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry lock
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv lock
|
||||
```
|
||||
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
@@ -195,7 +228,7 @@ Follow these steps to start contributing:
|
||||
git commit
|
||||
```
|
||||
|
||||
Note, if you already commited some changes that have a wrong formatting, you can use:
|
||||
Note, if you already committed some changes that have a wrong formatting, you can use:
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
@@ -267,7 +300,7 @@ We use `pytest` in order to run the tests. From the root of the
|
||||
repository, here's how to run tests with `pytest` for the library:
|
||||
|
||||
```bash
|
||||
DATA_DIR="tests/data" python -m pytest -sv ./tests
|
||||
python -m pytest -sv ./tests
|
||||
```
|
||||
|
||||
|
||||
|
||||
250
Makefile
250
Makefile
@@ -1,11 +1,25 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
.PHONY: tests
|
||||
|
||||
PYTHON_PATH := $(shell which python)
|
||||
|
||||
# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
|
||||
POETRY_CHECK := $(shell command -v poetry)
|
||||
ifneq ($(POETRY_CHECK),)
|
||||
PYTHON_PATH := $(shell poetry run which python)
|
||||
# If uv is installed and a virtual environment exists, use it
|
||||
UV_CHECK := $(shell command -v uv)
|
||||
ifneq ($(UV_CHECK),)
|
||||
PYTHON_PATH := $(shell .venv/bin/python)
|
||||
endif
|
||||
|
||||
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
|
||||
@@ -20,171 +34,109 @@ build-gpu:
|
||||
|
||||
test-end-to-end:
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train-resume
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
|
||||
|
||||
test-act-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
policy.dim_model=64 \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/act/
|
||||
--policy.type=act \
|
||||
--policy.dim_model=64 \
|
||||
--policy.n_action_steps=20 \
|
||||
--policy.chunk_size=20 \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=aloha \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=4 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_freq=2 \
|
||||
--save_checkpoint=true \
|
||||
--log_freq=1 \
|
||||
--wandb.enable=false \
|
||||
--output_dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-train-resume:
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
|
||||
test-act-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/act/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-act-ete-train-amp:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
policy.dim_model=64 \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/act_amp/ \
|
||||
training.image_transforms.enable=true \
|
||||
use_amp=true
|
||||
|
||||
test-act-ete-eval-amp:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/act_amp/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
use_amp=true
|
||||
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=aloha \
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
test-diffusion-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=diffusion \
|
||||
policy.down_dims=\[64,128,256\] \
|
||||
policy.diffusion_step_embed_dim=32 \
|
||||
policy.num_inference_steps=10 \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/diffusion/
|
||||
--policy.type=diffusion \
|
||||
--policy.down_dims='[64,128,256]' \
|
||||
--policy.diffusion_step_embed_dim=32 \
|
||||
--policy.num_inference_steps=10 \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=pusht \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=2 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2 \
|
||||
--log_freq=1 \
|
||||
--wandb.enable=false \
|
||||
--output_dir=tests/outputs/diffusion/
|
||||
|
||||
test-diffusion-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=pusht \
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
test-tdmpc-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=tdmpc \
|
||||
env=xarm \
|
||||
env.task=XarmLift-v0 \
|
||||
dataset_repo_id=lerobot/xarm_lift_medium \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-train-with-online:
|
||||
python lerobot/scripts/train.py \
|
||||
env=pusht \
|
||||
env.gym.obs_type=environment_state_agent_pos \
|
||||
policy=tdmpc_pusht_keypoints \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=10 \
|
||||
device=$(DEVICE) \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=20 \
|
||||
training.save_checkpoint=false \
|
||||
training.save_freq=10 \
|
||||
training.batch_size=2 \
|
||||
training.online_rollout_n_episodes=2 \
|
||||
training.online_rollout_batch_size=2 \
|
||||
training.online_steps_between_rollouts=10 \
|
||||
training.online_buffer_capacity=15 \
|
||||
eval.use_async_envs=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc_online/
|
||||
|
||||
--policy.type=tdmpc \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=xarm \
|
||||
--env.task=XarmLift-v0 \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/xarm_lift_medium \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=2 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2 \
|
||||
--log_freq=1 \
|
||||
--wandb.enable=false \
|
||||
--output_dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-default-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--config lerobot/configs/default.yaml \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-act-pusht-tutorial:
|
||||
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
|
||||
python lerobot/scripts/train.py \
|
||||
policy=created_by_Makefile.yaml \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
device=$(DEVICE) \
|
||||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/act_pusht/
|
||||
rm lerobot/configs/policy/created_by_Makefile.yaml
|
||||
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=xarm \
|
||||
--env.episode_length=5 \
|
||||
--env.task=XarmLift-v0 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
122
README.md
122
README.md
@@ -23,20 +23,30 @@
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">Hot new tutorial: Getting started with real-world robots</a></p>
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">
|
||||
Build Your Own SO-100 Robot!</a></p>
|
||||
</h2>
|
||||
|
||||
<div align="center">
|
||||
<img src="media/tutorial/koch_v1_1_leader_follower.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="50%">
|
||||
<p>We just dropped an in-depth tutorial on how to build your own robot!</p>
|
||||
<p>Teach it new skills by showing it a few moves with just a laptop.</p>
|
||||
<p>Then watch your homemade robot act autonomously 🤯</p>
|
||||
<p>For more info, see [our thread on X](https://x.com/RemiCadene/status/1825455895561859185) or [our tutorial page](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md).</p>
|
||||
<img src="media/so100/leader_follower.webp?raw=true" alt="SO-100 leader and follower arms" title="SO-100 leader and follower arms" width="50%">
|
||||
|
||||
<p><strong>Meet the SO-100 – Just $110 per arm!</strong></p>
|
||||
<p>Train it in minutes with a few simple moves on your laptop.</p>
|
||||
<p>Then sit back and watch your creation act autonomously! 🤯</p>
|
||||
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">
|
||||
Get the full SO-100 tutorial here.</a></p>
|
||||
|
||||
<p>Want to take it to the next level? Make your SO-100 mobile by building LeKiwi!</p>
|
||||
<p>Check out the <a href="https://github.com/huggingface/lerobot/blob/main/examples/11_use_lekiwi.md">LeKiwi tutorial</a> and bring your robot to life on wheels.</p>
|
||||
|
||||
<img src="media/lekiwi/kiwi.webp?raw=true" alt="LeKiwi mobile robot" title="LeKiwi mobile robot" width="50%">
|
||||
</div>
|
||||
|
||||
<br/>
|
||||
|
||||
<h3 align="center">
|
||||
<p>State-of-the-art AI for real-world robotics</p>
|
||||
<p>LeRobot: State-of-the-art AI for real-world robotics</p>
|
||||
</h3>
|
||||
|
||||
---
|
||||
@@ -54,9 +64,9 @@
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img src="http://remicadene.com/assets/gif/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
<td><img src="media/gym/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="media/gym/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="media/gym/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">ACT policy on ALOHA env</td>
|
||||
@@ -67,7 +77,7 @@
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
@@ -121,10 +131,7 @@ wandb login
|
||||
├── examples # contains demonstration examples, start here to learn about LeRobot
|
||||
| └── advanced # contains even more examples for those who have mastered the basics
|
||||
├── lerobot
|
||||
| ├── configs # contains hydra yaml files with all options that you can override in the command line
|
||||
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
|
||||
| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, xarm.yaml
|
||||
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
|
||||
| ├── configs # contains config classes with all options that you can override in the command line
|
||||
| ├── common # contains classes and utilities
|
||||
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
|
||||
| | ├── envs # various sim environments: aloha, pusht, xarm
|
||||
@@ -143,7 +150,7 @@ wandb login
|
||||
|
||||
### Visualize datasets
|
||||
|
||||
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically download data from the Hugging Face hub.
|
||||
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
```bash
|
||||
@@ -152,10 +159,12 @@ python lerobot/scripts/visualize_dataset.py \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
or from a dataset in a local folder with the root `DATA_DIR` environment variable (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
```bash
|
||||
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--local-files-only 1 \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
@@ -207,12 +216,10 @@ dataset attributes:
|
||||
|
||||
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
|
||||
- hf_dataset stored using Hugging Face datasets library serialization to parquet
|
||||
- videos are stored in mp4 format to save space or png files
|
||||
- episode_data_index saved using `safetensor` tensor serialization format
|
||||
- stats saved using `safetensor` tensor serialization format
|
||||
- info are saved using JSON
|
||||
- videos are stored in mp4 format to save space
|
||||
- metadata are stored in plain json/jsonl files
|
||||
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder as illustrated in the above section on dataset visualization.
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
@@ -221,80 +228,48 @@ Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrat
|
||||
We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
-p lerobot/diffusion_pusht \
|
||||
eval.n_episodes=10 \
|
||||
eval.batch_size=10
|
||||
--policy.path=lerobot/diffusion_pusht \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
--eval.n_episodes=10 \
|
||||
--use_amp=false \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
Note: After training your own policy, you can re-evaluate the checkpoints with:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/eval.py -p {OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
python lerobot/scripts/eval.py --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||
Check out [example 3](./examples/3_train_policy.py) that illustrate how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||
|
||||
In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
|
||||
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
env=aloha \
|
||||
env.task=AlohaInsertion-v0 \
|
||||
dataset_repo_id=lerobot/aloha_sim_insertion_human \
|
||||
```
|
||||
|
||||
The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
|
||||
```bash
|
||||
hydra.run.dir=your/new/experiment/dir
|
||||
```
|
||||
|
||||
In the experiment directory there will be a folder called `checkpoints` which will have the following structure:
|
||||
|
||||
```bash
|
||||
checkpoints
|
||||
├── 000250 # checkpoint_dir for training step 250
|
||||
│ ├── pretrained_model # Hugging Face pretrained model dir
|
||||
│ │ ├── config.json # Hugging Face pretrained model config
|
||||
│ │ ├── config.yaml # consolidated Hydra config
|
||||
│ │ ├── model.safetensors # model weights
|
||||
│ │ └── README.md # Hugging Face model card
|
||||
│ └── training_state.pth # optimizer/scheduler/rng state and training step
|
||||
```
|
||||
|
||||
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
|
||||
|
||||
```bash
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](./examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
|
||||
|
||||

|
||||
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
#### Reproduce state-of-the-art (SOTA)
|
||||
|
||||
We have organized our configuration files (found under [`lerobot/configs`](./lerobot/configs)) such that they reproduce SOTA results from a given model variant in their respective original works. Simply running:
|
||||
|
||||
We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances.
|
||||
You can reproduce their training by loading the config from their run. Simply running:
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=diffusion env=pusht
|
||||
python lerobot/scripts/train.py --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||
|
||||
Pretrained policies, along with reproduction details, can be found under the "Models" section of https://huggingface.co/lerobot.
|
||||
|
||||
## Contribute
|
||||
|
||||
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Add a new dataset
|
||||
<!-- ### Add a new dataset
|
||||
|
||||
To add a dataset to the hub, you need to login using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
@@ -312,7 +287,7 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||
|
||||
See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
|
||||
|
||||
If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py).
|
||||
If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py). -->
|
||||
|
||||
|
||||
### Add a pretrained policy
|
||||
@@ -322,7 +297,7 @@ Once you have trained a policy you may upload it to the Hugging Face hub using a
|
||||
You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
|
||||
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
|
||||
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
|
||||
- `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
|
||||
- `train_config.json`: A consolidated configuration containing all parameter userd for training. The policy configuration should match `config.json` exactly. Thisis useful for anyone who wants to evaluate your policy or for reproducibility.
|
||||
|
||||
To upload these to the hub, run the following:
|
||||
```bash
|
||||
@@ -409,3 +384,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
@@ -21,7 +21,7 @@ How to decode videos?
|
||||
|
||||
## Variables
|
||||
**Image content & size**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
@@ -63,7 +63,7 @@ This of course is affected by the `-g` parameter during encoding, which specifie
|
||||
|
||||
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
|
||||
|
||||
Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
|
||||
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
|
||||
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
|
||||
|
||||
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
|
||||
@@ -85,8 +85,8 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc
|
||||
**Average Structural Similarity Index Measure (higher is better)**
|
||||
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
|
||||
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
- `yuv420p` is more widely supported across various platforms, including web browsers.
|
||||
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
|
||||
|
||||
@@ -114,9 +114,9 @@ We tried to measure the most impactful parameters for both encoding and decoding
|
||||
|
||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
|
||||
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
|
||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
|
||||
- `torchaudio`
|
||||
|
||||
@@ -67,6 +67,7 @@ def parse_int_or_none(value) -> int | None:
|
||||
def check_datasets_formats(repo_ids: list) -> None:
|
||||
for repo_id in repo_ids:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
# TODO(Steven): Seems this API has changed
|
||||
if dataset.video:
|
||||
raise ValueError(
|
||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||
@@ -266,7 +267,7 @@ def benchmark_encoding_decoding(
|
||||
)
|
||||
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
|
||||
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
||||
num_pixels = width * height
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
images_size_bytes = get_directory_size(imgs_dir)
|
||||
|
||||
@@ -1,32 +1,29 @@
|
||||
# Configure image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
FROM python:${PYTHON_VERSION}-slim
|
||||
|
||||
# Configure environment variables
|
||||
ARG PYTHON_VERSION
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV MUJOCO_GL="egl"
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
# Install dependencies and set up Python in a single layer
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher libgeos-dev \
|
||||
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
|
||||
&& python -m venv /opt/venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Clone repository and install LeRobot in a single layer
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, koch]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Execute in bash shell rather than python
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
sed gawk grep curl wget zip unzip \
|
||||
tcpdump sysstat screen tmux \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
speech-dispatcher \
|
||||
speech-dispatcher portaudio19-dev libgeos-dev \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -1,30 +1,24 @@
|
||||
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||
|
||||
# Configure image
|
||||
# Configure environment variables
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher \
|
||||
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV MUJOCO_GL="egl"
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
# Install dependencies and set up Python in a single layer
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher libgeos-dev \
|
||||
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
|
||||
&& python -m venv /opt/venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Clone repository and install LeRobot in a single layer
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, koch]"
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
|
||||
|
||||
614
examples/10_use_so100.md
Normal file
614
examples/10_use_so100.md
Normal file
@@ -0,0 +1,614 @@
|
||||
# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [A. Source the parts](#a-source-the-parts)
|
||||
- [B. Install LeRobot](#b-install-lerobot)
|
||||
- [C. Configure the Motors](#c-configure-the-motors)
|
||||
- [D. Step-by-Step Assembly Instructions](#d-step-by-step-assembly-instructions)
|
||||
- [E. Calibrate](#e-calibrate)
|
||||
- [F. Teleoperate](#f-teleoperate)
|
||||
- [G. Record a dataset](#g-record-a-dataset)
|
||||
- [H. Visualize a dataset](#h-visualize-a-dataset)
|
||||
- [I. Replay an episode](#i-replay-an-episode)
|
||||
- [J. Train a policy](#j-train-a-policy)
|
||||
- [K. Evaluate your policy](#k-evaluate-your-policy)
|
||||
- [L. More Information](#l-more-information)
|
||||
|
||||
## A. Source the parts
|
||||
|
||||
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts,
|
||||
and advice if it's your first time printing or if you don't own a 3D printer.
|
||||
|
||||
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## B. Install LeRobot
|
||||
|
||||
> [!TIP]
|
||||
> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up using the command line you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line)
|
||||
|
||||
On your computer:
|
||||
|
||||
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
|
||||
|
||||
#### 2. Restart shell
|
||||
Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell
|
||||
|
||||
#### 3. Create and activate a fresh conda environment for lerobot
|
||||
|
||||
<details>
|
||||
<summary><strong>Video install instructions</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
```
|
||||
|
||||
Then activate your conda environment (do this each time you open a shell to use lerobot!):
|
||||
```bash
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
#### 4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
|
||||
Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
|
||||
|
||||
## C. Configure the motors
|
||||
|
||||
> [!NOTE]
|
||||
> Throughout this tutorial you will find videos on how to do the steps, the full video tutorial can be found here: [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I).
|
||||
|
||||
### 1. Find the USB ports associated to each arm
|
||||
|
||||
Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm. It's convenient to label them and write on each motor if it's for the follower `F` or for the leader `L` and it's ID from 1 to 6 (F1...F6 and L1...L6).
|
||||
|
||||
#### a. Run the script to find port
|
||||
|
||||
<details>
|
||||
<summary><strong>Video finding port</strong></summary>
|
||||
<video src="https://github.com/user-attachments/assets/4a21a14d-2046-4805-93c4-ee97a30ba33f"></video>
|
||||
<video src="https://github.com/user-attachments/assets/1cc3aecf-c16d-4ff9-aec7-8c175afbbce2"></video>
|
||||
</details>
|
||||
|
||||
To find the port for each bus servo adapter, run the utility script:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
#### b. Example outputs
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
#### c. Troubleshooting
|
||||
On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
#### d. Update config file
|
||||
|
||||
IMPORTANTLY: Now that you have your ports, update the **port** default values of [`SO100RobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("so100")
|
||||
@dataclass
|
||||
class So100RobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/so100"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Assembling the Base
|
||||
Let's begin with assembling the follower arm base
|
||||
|
||||
#### a. Set IDs for all 12 motors
|
||||
|
||||
<details>
|
||||
<summary><strong>Video configuring motor</strong></summary>
|
||||
<video src="https://github.com/user-attachments/assets/ef9b3317-2e11-4858-b9d3-f0a02fb48ecf"></video>
|
||||
<video src="https://github.com/user-attachments/assets/f36b5ed5-c803-4ebe-8947-b39278776a0d"></video>
|
||||
</details>
|
||||
|
||||
Plug your first motor F1 and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate. Replace the text after --port to the corresponding follower control board port and run this command in cmd:
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
|
||||
|
||||
#### b. Remove the gears of the 6 leader motors
|
||||
|
||||
<details>
|
||||
<summary><strong>Video removing gears</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/0c95b88c-5b85-413d-ba19-aee2f864f2a7"></video>
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
## D. Step-by-Step Assembly Instructions
|
||||
|
||||
**Step 1: Clean Parts**
|
||||
- Remove all support material from the 3D-printed parts.
|
||||
---
|
||||
|
||||
### Additional Guidance
|
||||
|
||||
<details>
|
||||
<summary><strong>Video assembling arms</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/488a39de-0189-4461-9de3-05b015f90cca"></video>
|
||||
|
||||
</details>
|
||||
|
||||
**Note:**
|
||||
This video provides visual guidance for assembling the arms, but it doesn't specify when or how to do the wiring. Inserting the cables beforehand is much easier than doing it afterward. The first arm may take a bit more than 1 hour to assemble, but once you get used to it, you can assemble the second arm in under 1 hour.
|
||||
|
||||
---
|
||||
|
||||
### First Motor
|
||||
|
||||
**Step 2: Insert Wires**
|
||||
- Insert two wires into the first motor.
|
||||
|
||||
<img src="../media/tutorial/img1.jpg" style="height:300px;">
|
||||
|
||||
**Step 3: Install in Base**
|
||||
- Place the first motor into the base.
|
||||
|
||||
<img src="../media/tutorial/img2.jpg" style="height:300px;">
|
||||
|
||||
**Step 4: Secure Motor**
|
||||
- Fasten the motor with 4 screws. Two from the bottom and two from top.
|
||||
|
||||
**Step 5: Attach Motor Holder**
|
||||
- Slide over the first motor holder and fasten it using two screws (one on each side).
|
||||
|
||||
<img src="../media/tutorial/img4.jpg" style="height:300px;">
|
||||
|
||||
**Step 6: Attach Motor Horns**
|
||||
- Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears.
|
||||
|
||||
<img src="../media/tutorial/img5.jpg" style="height:300px;">
|
||||
<details>
|
||||
<summary><strong>Video adding motor horn</strong></summary>
|
||||
<video src="https://github.com/user-attachments/assets/ef3391a4-ad05-4100-b2bd-1699bf86c969"></video>
|
||||
</details>
|
||||
|
||||
**Step 7: Attach Shoulder Part**
|
||||
- Route one wire to the back of the robot and the other to the left or in photo towards you (see photo).
|
||||
- Attach the shoulder part.
|
||||
|
||||
<img src="../media/tutorial/img6.jpg" style="height:300px;">
|
||||
|
||||
**Step 8: Secure Shoulder**
|
||||
- Tighten the shoulder part with 4 screws on top and 4 on the bottom
|
||||
*(access bottom holes by turning the shoulder).*
|
||||
|
||||
---
|
||||
|
||||
### Second Motor Assembly
|
||||
|
||||
**Step 9: Install Motor 2**
|
||||
- Slide the second motor in from the top and link the wire from motor 1 to motor 2.
|
||||
|
||||
<img src="../media/tutorial/img8.jpg" style="height:300px;">
|
||||
|
||||
**Step 10: Attach Shoulder Holder**
|
||||
- Add the shoulder motor holder.
|
||||
- Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo).
|
||||
- This part can be tight to assemble, you can use a workbench like the image or a similar setup to push the part around the motor.
|
||||
|
||||
<div style="display: flex;">
|
||||
<img src="../media/tutorial/img9.jpg" style="height:250px;">
|
||||
<img src="../media/tutorial/img10.jpg" style="height:250px;">
|
||||
<img src="../media/tutorial/img12.jpg" style="height:250px;">
|
||||
</div>
|
||||
|
||||
**Step 11: Secure Motor 2**
|
||||
- Fasten the second motor with 4 screws.
|
||||
|
||||
**Step 12: Attach Motor Horn**
|
||||
- Attach both motor horns to motor 2, again use the horn screw.
|
||||
|
||||
**Step 13: Attach Base**
|
||||
- Install the base attachment using 2 screws.
|
||||
|
||||
<img src="../media/tutorial/img11.jpg" style="height:300px;">
|
||||
|
||||
**Step 14: Attach Upper Arm**
|
||||
- Attach the upper arm with 4 screws on each side.
|
||||
|
||||
<img src="../media/tutorial/img13.jpg" style="height:300px;">
|
||||
|
||||
---
|
||||
|
||||
### Third Motor Assembly
|
||||
|
||||
**Step 15: Install Motor 3**
|
||||
- Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws.
|
||||
|
||||
**Step 16: Attach Motor Horn**
|
||||
- Attach both motor horns to motor 3 and secure one again with a horn screw.
|
||||
|
||||
<img src="../media/tutorial/img14.jpg" style="height:300px;">
|
||||
|
||||
**Step 17: Attach Forearm**
|
||||
- Connect the forearm to motor 3 using 4 screws on each side.
|
||||
|
||||
<img src="../media/tutorial/img15.jpg" style="height:300px;">
|
||||
|
||||
---
|
||||
|
||||
### Fourth Motor Assembly
|
||||
|
||||
**Step 18: Install Motor 4**
|
||||
- Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw.
|
||||
|
||||
<div style="display: flex;">
|
||||
<img src="../media/tutorial/img16.jpg" style="height:300px;">
|
||||
<img src="../media/tutorial/img19.jpg" style="height:300px;">
|
||||
</div>
|
||||
|
||||
**Step 19: Attach Motor Holder 4**
|
||||
- Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo).
|
||||
|
||||
<img src="../media/tutorial/img17.jpg" style="height:300px;">
|
||||
|
||||
**Step 20: Secure Motor 4 & Attach Horn**
|
||||
- Fasten motor 4 with 4 screws and attach its motor horns, use for one a horn screw.
|
||||
|
||||
<img src="../media/tutorial/img18.jpg" style="height:300px;">
|
||||
|
||||
---
|
||||
|
||||
### Wrist Assembly
|
||||
|
||||
**Step 21: Install Motor 5**
|
||||
- Insert motor 5 into the wrist holder and secure it with 2 front screws.
|
||||
|
||||
<img src="../media/tutorial/img20.jpg" style="height:300px;">
|
||||
|
||||
**Step 22: Attach Wrist**
|
||||
- Connect the wire from motor 4 to motor 5. And already insert the other wire for the gripper.
|
||||
- Secure the wrist to motor 4 using 4 screws on both sides.
|
||||
|
||||
<img src="../media/tutorial/img22.jpg" style="height:300px;">
|
||||
|
||||
**Step 23: Attach Wrist Horn**
|
||||
- Install only one motor horn on the wrist motor and secure it with a horn screw.
|
||||
|
||||
<img src="../media/tutorial/img23.jpg" style="height:300px;">
|
||||
|
||||
---
|
||||
|
||||
### Follower Configuration
|
||||
|
||||
**Step 24: Attach Gripper**
|
||||
- Attach the gripper to motor 5.
|
||||
|
||||
<img src="../media/tutorial/img24.jpg" style="height:300px;">
|
||||
|
||||
**Step 25: Install Gripper Motor**
|
||||
- Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side.
|
||||
|
||||
<img src="../media/tutorial/img25.jpg" style="height:300px;">
|
||||
|
||||
**Step 26: Attach Gripper Horn & Claw**
|
||||
- Attach the motor horns and again use a horn screw.
|
||||
- Install the gripper claw and secure it with 4 screws on both sides.
|
||||
|
||||
<img src="../media/tutorial/img26.jpg" style="height:300px;">
|
||||
|
||||
**Step 27: Mount Controller**
|
||||
- Attach the motor controller on the back.
|
||||
|
||||
<div style="display: flex;">
|
||||
<img src="../media/tutorial/img27.jpg" style="height:300px;">
|
||||
<img src="../media/tutorial/img28.jpg" style="height:300px;">
|
||||
</div>
|
||||
|
||||
*Assembly complete – proceed to Leader arm assembly.*
|
||||
|
||||
---
|
||||
|
||||
### Leader Configuration
|
||||
|
||||
For the leader configuration, perform **Steps 1–23**. Make sure that you removed the motor gears from the motors.
|
||||
|
||||
**Step 24: Attach Leader Holder**
|
||||
- Mount the leader holder onto the wrist and secure it with a screw.
|
||||
|
||||
<img src="../media/tutorial/img29.jpg" style="height:300px;">
|
||||
|
||||
**Step 25: Attach Handle**
|
||||
- Attach the handle to motor 5 using 4 screws.
|
||||
|
||||
<img src="../media/tutorial/img30.jpg" style="height:300px;">
|
||||
|
||||
**Step 26: Install Gripper Motor**
|
||||
- Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire.
|
||||
|
||||
<img src="../media/tutorial/img31.jpg" style="height:300px;">
|
||||
|
||||
**Step 27: Attach Trigger**
|
||||
- Attach the follower trigger with 4 screws.
|
||||
|
||||
<img src="../media/tutorial/img32.jpg" style="height:300px;">
|
||||
|
||||
**Step 28: Mount Controller**
|
||||
- Attach the motor controller on the back.
|
||||
|
||||
<div style="display: flex;">
|
||||
<img src="../media/tutorial/img27.jpg" style="height:300px;">
|
||||
<img src="../media/tutorial/img28.jpg" style="height:300px;">
|
||||
</div>
|
||||
|
||||
*Assembly complete – proceed to calibration.*
|
||||
|
||||
|
||||
## E. Calibrate
|
||||
|
||||
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
|
||||
|
||||
#### a. Manual calibration of follower arm
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/follower_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/so100/follower_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/so100/follower_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_follower"]'
|
||||
```
|
||||
|
||||
#### b. Manual calibration of leader arm
|
||||
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_leader"]'
|
||||
```
|
||||
|
||||
## F. Teleoperate
|
||||
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
|
||||
#### a. Teleop with displaying cameras
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
## G. Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/so100_test \
|
||||
--control.tags='["so100","tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
Note: You can resume recording by adding `--control.resume=true`.
|
||||
|
||||
## H. Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/so100_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with (a window can be opened in the browser `http://127.0.0.1:9090` with the visualization tool):
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/so100_test \
|
||||
--local-files-only 1
|
||||
```
|
||||
|
||||
## I. Replay an episode
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/so100_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
## J. Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/so100_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so100_test \
|
||||
--job_name=act_so100_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
||||
|
||||
## K. Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/eval_act_so100_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_so100_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so100_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so100_test`).
|
||||
|
||||
## L. More Information
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).
|
||||
585
examples/11_use_lekiwi.md
Normal file
585
examples/11_use_lekiwi.md
Normal file
@@ -0,0 +1,585 @@
|
||||
# Using the [LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi) Robot with LeRobot
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [A. Source the parts](#a-source-the-parts)
|
||||
- [B. Install software Pi](#b-install-software-on-pi)
|
||||
- [C. Setup LeRobot laptop/pc](#c-install-lerobot-on-laptop)
|
||||
- [D. Assemble the arms](#d-assembly)
|
||||
- [E. Calibrate](#e-calibration)
|
||||
- [F. Teleoperate](#f-teleoperate)
|
||||
- [G. Record a dataset](#g-record-a-dataset)
|
||||
- [H. Visualize a dataset](#h-visualize-a-dataset)
|
||||
- [I. Replay an episode](#i-replay-an-episode)
|
||||
- [J. Train a policy](#j-train-a-policy)
|
||||
- [K. Evaluate your policy](#k-evaluate-your-policy)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#mobile-so-100-arm`](https://discord.com/channels/1216765309076115607/1318390825528332371).
|
||||
|
||||
## A. Source the parts
|
||||
|
||||
Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts, and advice if it's your first time printing or if you don't own a 3D printer.
|
||||
|
||||
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
### Wired version
|
||||
If you have the **wired** LeKiwi version you can skip the installation of the Raspberry Pi and setting up SSH. You can also run all commands directly on your PC for both the LeKiwi scripts and the leader arm scripts for teleoperating.
|
||||
|
||||
## B. Install software on Pi
|
||||
Now we have to setup the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board.
|
||||
|
||||
### Install OS
|
||||
For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu.
|
||||
|
||||
### Setup SSH
|
||||
After setting up your Pi, you should enable and setup [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can login into the Pi from your laptop without requiring a screen, keyboard and mouse in the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension.
|
||||
|
||||
### Install LeRobot
|
||||
|
||||
On your Raspberry Pi:
|
||||
|
||||
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
|
||||
|
||||
#### 2. Restart shell
|
||||
Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell
|
||||
|
||||
#### 3. Create and activate a fresh conda environment for lerobot
|
||||
|
||||
<details>
|
||||
<summary><strong>Video install instructions</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
```
|
||||
|
||||
Then activate your conda environment (do this each time you open a shell to use lerobot!):
|
||||
```bash
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
#### 4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
## C. Install LeRobot on laptop
|
||||
If you already have install LeRobot on your laptop you can skip this step, otherwise please follow along as we do the same steps we did on the Pi.
|
||||
|
||||
> [!TIP]
|
||||
> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up using the command line you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line)
|
||||
|
||||
On your computer:
|
||||
|
||||
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
|
||||
|
||||
#### 2. Restart shell
|
||||
Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell
|
||||
|
||||
#### 3. Create and activate a fresh conda environment for lerobot
|
||||
|
||||
<details>
|
||||
<summary><strong>Video install instructions</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
```
|
||||
|
||||
Then activate your conda environment (do this each time you open a shell to use lerobot!):
|
||||
```bash
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
#### 4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
|
||||
Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
|
||||
|
||||
# D. Assembly
|
||||
|
||||
First we will assemble the two SO100 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base.
|
||||
|
||||
## SO100 Arms
|
||||
### Configure motors
|
||||
The instructions for configuring the motors can be found [Here](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#c-configure-the-motors) in step C of the SO100 tutorial. Besides the ID's for the arm motors we also need to set the motor ID's for the mobile base. These needs to be in a specific order to work. Below an image of the motor ID's and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ID's for the wheels are 7, 8 and 9.
|
||||
|
||||
<img src="../media/lekiwi/motor_ids.webp?raw=true" alt="Motor ID's for mobile robot" title="Motor ID's for mobile robot" width="60%">
|
||||
|
||||
### Assemble arms
|
||||
[Assemble arms instruction](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#d-assemble-the-arms)
|
||||
|
||||
## Mobile base (LeKiwi)
|
||||
[Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi)
|
||||
|
||||
### Update config
|
||||
Both config files on the LeKiwi LeRobot and on the laptop should be the same. First we should find the Ip address of the Raspberry Pi of the mobile manipulator. This is the same Ip address used in SSH. We also need the usb port of the control board of the leader arm on the laptop and the port of the control board on LeKiwi. We can find these ports with the following script.
|
||||
|
||||
#### a. Run the script to find port
|
||||
|
||||
<details>
|
||||
<summary><strong>Video finding port</strong></summary>
|
||||
<video src="https://github.com/user-attachments/assets/4a21a14d-2046-4805-93c4-ee97a30ba33f"></video>
|
||||
<video src="https://github.com/user-attachments/assets/1cc3aecf-c16d-4ff9-aec7-8c175afbbce2"></video>
|
||||
</details>
|
||||
|
||||
To find the port for each bus servo adapter, run the utility script:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
#### b. Example outputs
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
#### c. Troubleshooting
|
||||
On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
#### d. Update config file
|
||||
|
||||
IMPORTANTLY: Now that you have your ports of leader and follower arm and ip address of the mobile-so100, update the **ip** in Network configuration, **port** in leader_arms and **port** in lekiwi. In the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file. Where you will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("lekiwi")
|
||||
@dataclass
|
||||
class LeKiwiRobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# Network Configuration
|
||||
ip: str = "172.17.133.91"
|
||||
port: int = 5555
|
||||
video_port: int = 5556
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"mobile": OpenCVCameraConfig(camera_index="/dev/video0", fps=30, width=640, height=480),
|
||||
"mobile2": OpenCVCameraConfig(camera_index="/dev/video2", fps=30, width=640, height=480),
|
||||
}
|
||||
)
|
||||
|
||||
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/ttyACM0",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
"left_wheel": (7, "sts3215"),
|
||||
"back_wheel": (8, "sts3215"),
|
||||
"right_wheel": (9, "sts3215"),
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
teleop_keys: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
# Movement
|
||||
"forward": "w",
|
||||
"backward": "s",
|
||||
"left": "a",
|
||||
"right": "d",
|
||||
"rotate_left": "z",
|
||||
"rotate_right": "x",
|
||||
# Speed control
|
||||
"speed_up": "r",
|
||||
"speed_down": "f",
|
||||
# quit teleop
|
||||
"quit": "q",
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
```
|
||||
|
||||
## Wired version
|
||||
|
||||
For the wired LeKiwi version your configured IP address should refer to your own laptop (127.0.0.1), because leader arm and LeKiwi are in this case connected to own laptop. Below and example configuration for this wired setup:
|
||||
```python
|
||||
@RobotConfig.register_subclass("lekiwi")
|
||||
@dataclass
|
||||
class LeKiwiRobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# Network Configuration
|
||||
ip: str = "127.0.0.1"
|
||||
port: int = 5555
|
||||
video_port: int = 5556
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"front": OpenCVCameraConfig(
|
||||
camera_index=0, fps=30, width=640, height=480, rotation=90
|
||||
),
|
||||
"wrist": OpenCVCameraConfig(
|
||||
camera_index=1, fps=30, width=640, height=480, rotation=180
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431061",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
"left_wheel": (7, "sts3215"),
|
||||
"back_wheel": (8, "sts3215"),
|
||||
"right_wheel": (9, "sts3215"),
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
teleop_keys: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
# Movement
|
||||
"forward": "w",
|
||||
"backward": "s",
|
||||
"left": "a",
|
||||
"right": "d",
|
||||
"rotate_left": "z",
|
||||
"rotate_right": "x",
|
||||
# Speed control
|
||||
"speed_up": "r",
|
||||
"speed_down": "f",
|
||||
# quit teleop
|
||||
"quit": "q",
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
```
|
||||
|
||||
# E. Calibration
|
||||
Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated.
|
||||
|
||||
|
||||
### Calibrate follower arm (on mobile base)
|
||||
> [!IMPORTANT]
|
||||
> Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/lekiwi/mobile_calib_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_follower"]'
|
||||
```
|
||||
|
||||
### Wired version
|
||||
If you have the **wired** LeKiwi version please run all commands including this calibration command on your laptop.
|
||||
|
||||
### Calibrate leader arm
|
||||
Then to calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script (on your laptop/pc) to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_leader"]'
|
||||
```
|
||||
|
||||
# F. Teleoperate
|
||||
To teleoperate SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=remote_robot
|
||||
```
|
||||
|
||||
Then on your laptop, also run `conda activate lerobot` and this script:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=teleoperate \
|
||||
--control.fps=30
|
||||
```
|
||||
|
||||
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
||||
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
||||
|------------|-------------------|-----------------------|
|
||||
| Fast | 0.4 | 90 |
|
||||
| Medium | 0.25 | 60 |
|
||||
| Slow | 0.1 | 30 |
|
||||
|
||||
|
||||
| Key | Action |
|
||||
|------|--------------------------------|
|
||||
| W | Move forward |
|
||||
| A | Move left |
|
||||
| S | Move backward |
|
||||
| D | Move right |
|
||||
| Z | Turn left |
|
||||
| X | Turn right |
|
||||
| R | Increase speed |
|
||||
| F | Decrease speed |
|
||||
|
||||
> [!TIP]
|
||||
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
|
||||
|
||||
### Wired version
|
||||
If you have the **wired** LeKiwi version please run all commands including both these teleoperation commands on your laptop.
|
||||
|
||||
## Troubleshoot communication
|
||||
|
||||
If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue.
|
||||
|
||||
### 1. Verify IP Address Configuration
|
||||
Make sure that the correct ip for the Pi is set in the configuration file. To check the Raspberry Pi's IP address, run (on the Pi command line):
|
||||
```bash
|
||||
hostname -I
|
||||
```
|
||||
|
||||
### 2. Check if Pi is reachable from laptop/pc
|
||||
Try pinging the Raspberry Pi from your laptop:
|
||||
```bach
|
||||
ping <your_pi_ip_address>
|
||||
```
|
||||
|
||||
If the ping fails:
|
||||
- Ensure the Pi is powered on and connected to the same network.
|
||||
- Check if SSH is enabled on the Pi.
|
||||
|
||||
### 3. Try SSH connection
|
||||
If you can't SSH into the Pi, it might not be properly connected. Use:
|
||||
```bash
|
||||
ssh <your_pi_user_name>@<your_pi_ip_address>
|
||||
```
|
||||
If you get a connection error:
|
||||
- Ensure SSH is enabled on the Pi by running:
|
||||
```bash
|
||||
sudo raspi-config
|
||||
```
|
||||
Then navigate to: **Interfacing Options -> SSH** and enable it.
|
||||
|
||||
### 4. Same config file
|
||||
Make sure the configuration file on both your laptop/pc and the Raspberry Pi is the same.
|
||||
|
||||
# G. Record a dataset
|
||||
Once you're familiar with teleoperation, you can record your first dataset with LeKiwi.
|
||||
|
||||
To start the program on LeKiwi, SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=remote_robot
|
||||
```
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
On your laptop then run this command to record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/lekiwi_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
Note: You can resume recording by adding `--control.resume=true`.
|
||||
|
||||
### Wired version
|
||||
If you have the **wired** LeKiwi version please run all commands including both these record dataset commands on your laptop.
|
||||
|
||||
# H. Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/lekiwi_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with (a window can be opened in the browser `http://127.0.0.1:9090` with the visualization tool):
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/lekiwi_test \
|
||||
--local-files-only 1
|
||||
```
|
||||
|
||||
# I. Replay an episode
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/lekiwi_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
## J. Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/lekiwi_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_lekiwi_test \
|
||||
--job_name=act_lekiwi_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`.
|
||||
|
||||
## K. Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Drive to the red block and pick it up" \
|
||||
--control.repo_id=${HF_USER}/eval_act_lekiwi_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_lekiwi_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_lekiwi_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_lekiwi_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_lekiwi_test`).
|
||||
335
examples/11_use_moss.md
Normal file
335
examples/11_use_moss.md
Normal file
@@ -0,0 +1,335 @@
|
||||
This tutorial explains how to use [Moss v1](https://github.com/jess-moss/moss-robot-arms) with LeRobot.
|
||||
|
||||
## Source the parts
|
||||
|
||||
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials with link to source the parts, as well as the instructions to 3D print the parts and advice if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
On your computer:
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
mkdir -p ~/miniconda3
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
2. Restart shell or `source ~/.bashrc`
|
||||
|
||||
3. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
```
|
||||
|
||||
4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
For Linux only (not Mac), install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
## Configure the motors
|
||||
|
||||
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the use of our scripts below.
|
||||
|
||||
**Find USB ports associated to your arms**
|
||||
To find the correct ports for each arm, run the utility script twice:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
#### Update config file
|
||||
|
||||
IMPORTANTLY: Now that you have your ports, update the **port** default values of [`MossRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("moss")
|
||||
@dataclass
|
||||
class MossRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/moss"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**Configure your motors**
|
||||
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
|
||||
**Remove the gears of the 6 leader motors**
|
||||
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
**Add motor horn to the motors**
|
||||
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). For Moss v1, you need to align the holes on the motor horn to the motor spline to be approximately 3, 6, 9 and 12 o'clock.
|
||||
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
||||
|
||||
## Assemble the arms
|
||||
|
||||
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your Moss v1 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one Moss v1 robot to work on another.
|
||||
|
||||
**Manual calibration of follower arm**
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/moss/follower_zero.webp?raw=true" alt="Moss v1 follower arm zero position" title="Moss v1 follower arm zero position" style="width:100%;"> | <img src="../media/moss/follower_rotated.webp?raw=true" alt="Moss v1 follower arm rotated position" title="Moss v1 follower arm rotated position" style="width:100%;"> | <img src="../media/moss/follower_rest.webp?raw=true" alt="Moss v1 follower arm rest position" title="Moss v1 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_follower"]'
|
||||
```
|
||||
|
||||
**Manual calibration of leader arm**
|
||||
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/moss/leader_zero.webp?raw=true" alt="Moss v1 leader arm zero position" title="Moss v1 leader arm zero position" style="width:100%;"> | <img src="../media/moss/leader_rotated.webp?raw=true" alt="Moss v1 leader arm rotated position" title="Moss v1 leader arm rotated position" style="width:100%;"> | <img src="../media/moss/leader_rest.webp?raw=true" alt="Moss v1 leader arm rest position" title="Moss v1 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_leader"]'
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
|
||||
**Teleop with displaying cameras**
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with Moss v1.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/moss_test \
|
||||
--control.tags='["moss","tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
Note: You can resume recording by adding `--control.resume=true`.
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/moss_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/moss_test \
|
||||
--local-files-only 1
|
||||
```
|
||||
|
||||
## Replay an episode
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/moss_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/moss_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_moss_test \
|
||||
--job_name=act_moss_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/eval_act_moss_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_moss_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_moss_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_moss_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_moss_test`).
|
||||
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel [`#moss-arm`](https://discord.com/channels/1216765309076115607/1275374638985252925).
|
||||
@@ -1,80 +1,136 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||
|
||||
Features included in this script:
|
||||
- Loading a dataset and accessing its properties.
|
||||
- Filtering data by episode number.
|
||||
- Converting tensor data for visualization.
|
||||
- Saving video files from dataset frames.
|
||||
- Viewing a dataset's metadata and exploring its properties.
|
||||
- Loading an existing dataset from the hub or a subset of it.
|
||||
- Accessing frames by episode number.
|
||||
- Using advanced dataset features like timestamp-based frame selection.
|
||||
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
|
||||
|
||||
The script ends with examples of how to batch process data using PyTorch's DataLoader.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
import imageio
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
pprint(lerobot.available_datasets)
|
||||
|
||||
# Let's take one for this example
|
||||
repo_id = "lerobot/pusht"
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
|
||||
pprint(repo_ids)
|
||||
|
||||
# You can easily load a dataset from a Hugging Face repository
|
||||
# Or simply explore them in your web browser directly at:
|
||||
# https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
# Let's take this one for this example
|
||||
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||
# We can have a look and fetch its metadata to know more about it:
|
||||
ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
|
||||
# By instantiating just this class, you can quickly access useful information about the content and the
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
|
||||
print("Tasks:")
|
||||
print(ds_meta.tasks)
|
||||
print("Features:")
|
||||
pprint(ds_meta.features)
|
||||
|
||||
# You can also get a short summary by simply printing the object:
|
||||
print(ds_meta)
|
||||
|
||||
# You can then load the actual dataset from the hub.
|
||||
# Either load any subset of episodes:
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
|
||||
|
||||
# And see how many frames you have:
|
||||
print(f"Selected episodes: {dataset.episodes}")
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# Or simply load the entire dataset:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets/index for more information).
|
||||
print(dataset)
|
||||
# The previous metadata class is contained in the 'meta' attribute of the dataset:
|
||||
print(dataset.meta)
|
||||
|
||||
# LeRobotDataset actually wraps an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets for more information).
|
||||
print(dataset.hf_dataset)
|
||||
|
||||
# And provides additional utilities for robotics and compatibility with Pytorch
|
||||
print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
|
||||
print(f"frames per second used during data collection: {dataset.fps=}")
|
||||
print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
|
||||
|
||||
# Access frame indexes associated to first episode
|
||||
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
|
||||
# frame indices associated to the first episode:
|
||||
episode_index = 0
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
|
||||
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset. Here we grab all the image frames.
|
||||
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
|
||||
# Then we grab all the image frames from the first camera:
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
|
||||
|
||||
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
|
||||
# them, we convert to uint8 in range [0,255]
|
||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
||||
# and to channel last (h,w,c).
|
||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||
# The objects returned by the dataset are all torch.Tensors
|
||||
print(type(frames[0]))
|
||||
print(frames[0].shape)
|
||||
|
||||
# Finally, we save the frames to a mp4 video for visualization.
|
||||
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
|
||||
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
|
||||
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
|
||||
# We can compare this shape with the information available for that feature
|
||||
pprint(dataset.features[camera_key])
|
||||
# In particular:
|
||||
print(dataset.features[camera_key]["shape"])
|
||||
# The shape is in (h, w, c) which is a more universal format.
|
||||
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of
|
||||
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
|
||||
# differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
"observation.image": [-1, -0.5, -0.20, 0],
|
||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
|
||||
# timestamp, you still get a valid timestamp.
|
||||
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64,c)
|
||||
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
|
||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
||||
# PyTorch datasets.
|
||||
@@ -84,8 +140,9 @@ dataloader = torch.utils.data.DataLoader(
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32,8,c)
|
||||
print(f"{batch['action'].shape=}") # (32,64,c)
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 5, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
@@ -1,6 +1,25 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||
|
||||
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
|
||||
```bash
|
||||
pip install -e ".[pusht]"`
|
||||
```
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -10,7 +29,6 @@ import gymnasium as gym
|
||||
import imageio
|
||||
import numpy
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
@@ -18,25 +36,15 @@ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
output_directory = Path("outputs/eval/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# Select your device
|
||||
device = "cuda"
|
||||
|
||||
# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
|
||||
pretrained_policy_path = "lerobot/diffusion_pusht"
|
||||
# OR a path to a local outputs/train folder.
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
|
||||
# Check if GPU is available
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
print("GPU is available. Device set to:", device)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
|
||||
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
|
||||
policy.diffusion.num_inference_steps = 10
|
||||
|
||||
policy.to(device)
|
||||
|
||||
# Initialize evaluation environment to render two observation types:
|
||||
# an image of the scene and state/position of the agent. The environment
|
||||
@@ -47,7 +55,17 @@ env = gym.make(
|
||||
max_episode_steps=300,
|
||||
)
|
||||
|
||||
# Reset the policy and environmens to prepare for rollout
|
||||
# We can verify that the shapes of the features expected by the policy match the ones from the observations
|
||||
# produced by the environment
|
||||
print(policy.config.input_features)
|
||||
print(env.observation_space)
|
||||
|
||||
# Similarly, we can check that the actions produced by the policy will match the actions expected by the
|
||||
# environment
|
||||
print(policy.config.output_features)
|
||||
print(env.action_space)
|
||||
|
||||
# Reset the policy and environments to prepare for rollout
|
||||
policy.reset()
|
||||
numpy_observation, info = env.reset(seed=42)
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||
|
||||
Once you have trained a model with this script, you can try to evaluate it on
|
||||
@@ -8,72 +22,99 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
# Create a directory to store the training checkpoint.
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Number of offline training steps (we'll only do offline training for this example.)
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
training_steps = 5000
|
||||
device = torch.device("cuda")
|
||||
log_freq = 250
|
||||
def main():
|
||||
# Create a directory to store the training checkpoint.
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Set up the dataset.
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to supervise the policy.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
# # Select your device
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Set up the the policy.
|
||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
|
||||
# For this example, no arguments need to be passed because the defaults are set up for PushT.
|
||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
# Number of offline training steps (we'll only do offline training for this example.)
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
training_steps = 5000
|
||||
log_freq = 1
|
||||
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
# When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
|
||||
# creating the policy:
|
||||
# - input/output shapes: to properly size the policy
|
||||
# - dataset stats: for normalization and denormalization of input/outputs
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
# Create dataloader for offline training.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=True,
|
||||
)
|
||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
|
||||
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
|
||||
# Run training loop.
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
loss = output_dict["loss"]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
# We can now instantiate our policy with this config and the dataset stats.
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
|
||||
# which can differ for inputs, outputs and rewards (if there are some).
|
||||
delta_timestamps = {
|
||||
"observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
||||
"observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
||||
"action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
|
||||
}
|
||||
|
||||
# Save a policy checkpoint.
|
||||
policy.save_pretrained(output_directory)
|
||||
# In this case with the standard configuration for Diffusion Policy, it is equivalent to this:
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to supervise the policy.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# We can then instantiate the dataset with these delta_timestamps configuration.
|
||||
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
|
||||
# Then we create our optimizer and dataloader for offline training.
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Run training loop.
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save a policy checkpoint.
|
||||
policy.save_pretrained(output_directory)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,183 +1,274 @@
|
||||
This tutorial will explain the training script, how to use it, and particularly the use of Hydra to configure everything needed for the training run.
|
||||
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
|
||||
> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--device=cpu` (`--device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
||||
|
||||
|
||||
## The training script
|
||||
|
||||
LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
|
||||
|
||||
- Loads a Hydra configuration file for the following steps (more on Hydra in a moment).
|
||||
- Makes a simulation environment.
|
||||
- Makes a dataset corresponding to that simulation environment.
|
||||
- Makes a policy.
|
||||
- Initialize/load a configuration for the following steps using.
|
||||
- Instantiates a dataset.
|
||||
- (Optional) Instantiates a simulation environment corresponding to that dataset.
|
||||
- Instantiates a policy.
|
||||
- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing.
|
||||
|
||||
## Basics of how we use Hydra
|
||||
|
||||
Explaining the ins and outs of [Hydra](https://hydra.cc/docs/intro/) is beyond the scope of this document, but here we'll share the main points you need to know.
|
||||
|
||||
First, `lerobot/configs` has a directory structure like this:
|
||||
|
||||
```
|
||||
.
|
||||
├── default.yaml
|
||||
├── env
|
||||
│ ├── aloha.yaml
|
||||
│ ├── pusht.yaml
|
||||
│ └── xarm.yaml
|
||||
└── policy
|
||||
├── act.yaml
|
||||
├── diffusion.yaml
|
||||
└── tdmpc.yaml
|
||||
```
|
||||
|
||||
**_For brevity, in the rest of this document we'll drop the leading `lerobot/configs` path. So `default.yaml` really refers to `lerobot/configs/default.yaml`._**
|
||||
|
||||
When you run the training script with
|
||||
## Overview of the configuration system
|
||||
|
||||
In the training script, the main function `train` expects a `TrainPipelineConfig` object:
|
||||
```python
|
||||
python lerobot/scripts/train.py
|
||||
# train.py
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
```
|
||||
|
||||
Hydra is set up to read `default.yaml` (via the `@hydra.main` decorator). If you take a look at the `@hydra.main`'s arguments you will see `config_path="../configs", config_name="default"`. At the top of `default.yaml`, is a `defaults` section which looks likes this:
|
||||
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
|
||||
|
||||
```yaml
|
||||
defaults:
|
||||
- _self_
|
||||
- env: pusht
|
||||
- policy: diffusion
|
||||
When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
|
||||
|
||||
Let's have a look at a simplified example. Amongst other attributes, the training config has the following attributes:
|
||||
```python
|
||||
@dataclass
|
||||
class TrainPipelineConfig:
|
||||
dataset: DatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
```
|
||||
in which `DatasetConfig` for example is defined as such:
|
||||
```python
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
repo_id: str
|
||||
episodes: list[int] | None = None
|
||||
video_backend: str = "pyav"
|
||||
```
|
||||
|
||||
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overridden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.
|
||||
This creates a hierarchical relationship where, for example assuming we have a `cfg` instance of `TrainPipelineConfig`, we can access the `repo_id` value with `cfg.dataset.repo_id`.
|
||||
From the command line, we can specify this value with using a very similar syntax `--dataset.repo_id=repo/id`.
|
||||
|
||||
Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.
|
||||
By default, every field takes its default value specified in the dataclass. If a field doesn't have a default value, it needs to be specified either from the command line or from a config file – which path is also given in the command line (more in this below). In the example above, the `dataset` field doesn't have a default value which means it must be specified.
|
||||
|
||||
Thanks to this `defaults` section in `default.yaml`, if you want to train Diffusion Policy with PushT, you really only need to run:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py
|
||||
```
|
||||
|
||||
However, you can be more explicit and launch the exact same Diffusion Policy training on PushT with:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=diffusion env=pusht
|
||||
```
|
||||
|
||||
This way of overriding defaults via the CLI is especially useful when you want to change the policy and/or environment. For instance, you can train ACT on the default Aloha environment with:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=aloha
|
||||
```
|
||||
|
||||
There are two things to note here:
|
||||
- Config overrides are passed as `param_name=param_value`.
|
||||
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.
|
||||
|
||||
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._
|
||||
|
||||
## Overriding configuration parameters in the CLI
|
||||
|
||||
Now let's say that we want to train on a different task in the Aloha environment. If you look in `env/aloha.yaml` you will see something like:
|
||||
|
||||
```yaml
|
||||
# lerobot/configs/env/aloha.yaml
|
||||
env:
|
||||
task: AlohaInsertion-v0
|
||||
```
|
||||
|
||||
And if you look in `policy/act.yaml` you will see something like:
|
||||
|
||||
```yaml
|
||||
# lerobot/configs/policy/act.yaml
|
||||
dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||
```
|
||||
|
||||
But our Aloha environment actually supports a cube transfer task as well. To train for this task, you could manually modify the two yaml configuration files respectively.
|
||||
|
||||
First, we'd need to switch to using the cube transfer task for the ALOHA environment.
|
||||
|
||||
```diff
|
||||
# lerobot/configs/env/aloha.yaml
|
||||
env:
|
||||
- task: AlohaInsertion-v0
|
||||
+ task: AlohaTransferCube-v0
|
||||
```
|
||||
|
||||
Then, we'd also need to switch to using the cube transfer dataset.
|
||||
|
||||
```diff
|
||||
# lerobot/configs/policy/act.yaml
|
||||
-dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||
+dataset_repo_id: lerobot/aloha_sim_transfer_cube_human
|
||||
```
|
||||
|
||||
Then, you'd be able to run:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=aloha
|
||||
```
|
||||
|
||||
and you'd be training and evaluating on the cube transfer task.
|
||||
|
||||
An alternative approach to editing the yaml configuration files, would be to override the defaults via the command line:
|
||||
## Specifying values from the CLI
|
||||
|
||||
Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
env=aloha \
|
||||
env.task=AlohaTransferCube-v0
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy.type=diffusion \
|
||||
--env.type=pusht
|
||||
```
|
||||
|
||||
There's something new here. Notice the `.` delimiter used to traverse the configuration hierarchy. _But be aware that the `defaults` section is an exception. As you saw above, we didn't need to write `defaults.policy=act` in the CLI. `policy=act` was enough._
|
||||
|
||||
Putting all that knowledge together, here's the command that was used to train https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human.
|
||||
Let's break this down:
|
||||
- To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`.
|
||||
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies)
|
||||
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py)
|
||||
|
||||
Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human \
|
||||
device=cuda
|
||||
env=aloha \
|
||||
env.task=AlohaTransferCube-v0 \
|
||||
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
policy=act \
|
||||
training.eval_freq=10000 \
|
||||
training.log_freq=250 \
|
||||
training.offline_steps=100000 \
|
||||
training.save_model=true \
|
||||
training.save_freq=25000 \
|
||||
eval.n_episodes=50 \
|
||||
eval.batch_size=50 \
|
||||
wandb.enable=false \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
--output_dir=outputs/train/act_aloha_insertion
|
||||
```
|
||||
> Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.
|
||||
|
||||
There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human`, which specifies where to save the training output.
|
||||
|
||||
## Using a configuration file not in `lerobot/configs`
|
||||
|
||||
Above we discusses the our training script is set up such that Hydra looks for `default.yaml` in `lerobot/configs`. But, if you have a configuration file elsewhere in your filesystem you may use:
|
||||
|
||||
We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task.
|
||||
Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config-dir PARENT/PATH --config-name FILE_NAME_WITHOUT_EXTENSION
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaTransferCube-v0 \
|
||||
--output_dir=outputs/train/act_aloha_transfer
|
||||
```
|
||||
|
||||
Note: here we use regular syntax for providing CLI arguments to a Python script, not Hydra's `param_name=param_value` syntax.
|
||||
## Loading from a config file
|
||||
|
||||
As a concrete example, this becomes particularly handy when you have a folder with training outputs, and would like to re-run the training. For example, say you previously ran the training script with one of the earlier commands and have `outputs/train/my_experiment/checkpoints/pretrained_model/config.yaml`. This `config.yaml` file will have the full set of configuration parameters within it. To run the training with the same configuration again, do:
|
||||
Now, let's assume that we want to reproduce the run just above. That run has produced a `train_config.json` file in its checkpoints, which serializes the `TrainPipelineConfig` instance it used:
|
||||
```json
|
||||
{
|
||||
"dataset": {
|
||||
"repo_id": "lerobot/aloha_sim_transfer_cube_human",
|
||||
"episodes": null,
|
||||
...
|
||||
},
|
||||
"env": {
|
||||
"type": "aloha",
|
||||
"task": "AlohaTransferCube-v0",
|
||||
"fps": 50,
|
||||
...
|
||||
},
|
||||
"policy": {
|
||||
"type": "act",
|
||||
"n_obs_steps": 1,
|
||||
...
|
||||
},
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
We can then simply load the config values from this file using:
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpoints/last/pretrained_model --config-name config
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
```
|
||||
`--config_path` is also a special argument which allows to initialize the config from a local config file. It can point to a directory that contains `train_config.json` or to the config file itself directly.
|
||||
|
||||
Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
--policy.n_action_steps=80
|
||||
```
|
||||
> Note: While `--output_dir` is not required in general, in this case we need to specify it since it will otherwise take the value from the `train_config.json` (which is `outputs/train/act_aloha_transfer`). In order to prevent accidental deletion of previous run checkpoints, we raise an error if you're trying to write in an existing directory. This is not the case when resuming a run, which is what you'll learn next.
|
||||
|
||||
`--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running:
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)
|
||||
|
||||
|
||||
## Resume training
|
||||
|
||||
Being able to resume a training run is important in case it crashed or aborted for any reason. We'll demonstrate how to that here.
|
||||
|
||||
Let's reuse the command from the previous run and add a few more options:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaTransferCube-v0 \
|
||||
--log_freq=25 \
|
||||
--save_freq=100 \
|
||||
--output_dir=outputs/train/run_resumption
|
||||
```
|
||||
|
||||
Note that you may still use the regular syntax for config parameter overrides (eg: by adding `training.offline_steps=200000`).
|
||||
Here we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can showcase resumption. You should be able to see some logging and have a first checkpoint within 1 minute (depending on hardware). Wait for the first checkpoint to happen, you should see a line that looks like this in your terminal:
|
||||
```
|
||||
INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100
|
||||
```
|
||||
Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true
|
||||
```
|
||||
You should see from the logging that your training picks up from where it left off.
|
||||
|
||||
Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default.
|
||||
You could double the number of steps of the previous run with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
## Outputs of a run
|
||||
In the output directory, there will be a folder called `checkpoints` with the following structure:
|
||||
```bash
|
||||
outputs/train/run_resumption/checkpoints
|
||||
├── 000100 # checkpoint_dir for training step 100
|
||||
│ ├── pretrained_model/
|
||||
│ │ ├── config.json # policy config
|
||||
│ │ ├── model.safetensors # policy weights
|
||||
│ │ └── train_config.json # train config
|
||||
│ └── training_state/
|
||||
│ ├── optimizer_param_groups.json # optimizer param groups
|
||||
│ ├── optimizer_state.safetensors # optimizer state
|
||||
│ ├── rng_state.safetensors # rng states
|
||||
│ ├── scheduler_state.json # scheduler state
|
||||
│ └── training_step.json # training step
|
||||
├── 000200
|
||||
└── last -> 000200 # symlink to the last available checkpoint
|
||||
```
|
||||
|
||||
## Fine-tuning a pre-trained policy
|
||||
|
||||
In addition to the features currently in Draccus, we've added a special `.path` argument for the policy, which allows to load a policy as you would with `PreTrainedPolicy.from_pretrained()`. In that case, `path` can be a local directory that contains a checkpoint or a repo_id pointing to a pretrained policy on the hub.
|
||||
|
||||
For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaInsertion-v0
|
||||
```
|
||||
|
||||
When doing so, keep in mind that the features of the fine-tuning dataset would have to match the input/output features of the pretrained policy.
|
||||
|
||||
## Typical logs and metrics
|
||||
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you configured your run correctly. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
```
|
||||
INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774
|
||||
```
|
||||
or evaluation log:
|
||||
```
|
||||
INFO 2024-08-14 13:38:45 ts/train.py:226 step:100 smpl:6K ep:52 epch:0.25 ∑rwrd:20.693 success:0.0% eval_s:120.266
|
||||
```
|
||||
|
||||
These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here are the meaning of some abbreviations:
|
||||
- `smpl`: number of samples seen during training.
|
||||
- `ep`: number of episodes seen during training. An episode contains multiple samples in a complete manipulation task.
|
||||
- `epch`: number of time all unique samples are seen (epoch).
|
||||
- `grdn`: gradient norm.
|
||||
- `∑rwrd`: compute the sum of rewards in every evaluation episode and then take an average of them.
|
||||
- `success`: average success rate of eval episodes. Reward and success are usually different except for the sparsing reward setting, where reward=1 only when the task is completed successfully.
|
||||
- `eval_s`: time to evaluate the policy in the environment, in second.
|
||||
- `updt_s`: time to update the network parameters, in second.
|
||||
- `data_s`: time to load a batch of data, in second.
|
||||
|
||||
Some metrics are useful for initial performance profiling. For example, if you find the current GPU utilization is low via the `nvidia-smi` command and `data_s` sometimes is too high, you may need to modify batch size or number of dataloading workers to accelerate dataloading. We also recommend [pytorch profiler](https://github.com/huggingface/lerobot?tab=readme-ov-file#improve-your-code-with-profiling) for detailed performance probing.
|
||||
|
||||
## In short
|
||||
|
||||
We'll summarize here the main use cases to remember from this tutorial.
|
||||
|
||||
#### Train a policy from scratch – CLI
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \ # <- select 'act' policy
|
||||
--env.type=pusht \ # <- select 'pusht' environment
|
||||
--dataset.repo_id=lerobot/pusht # <- train on this dataset
|
||||
```
|
||||
|
||||
#### Train a policy from scratch - config file + CLI
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=path/to/pretrained_model \ # <- can also be a repo_id
|
||||
--policy.n_action_steps=80 # <- you may still override values
|
||||
```
|
||||
|
||||
#### Resume/continue a training run
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=checkpoint/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000 # <- you can change some training parameters
|
||||
```
|
||||
|
||||
#### Fine-tuning
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaInsertion-v0
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
|
||||
Now that you know the basics of how to train a policy, you might want to know how to apply this knowledge to actual robots, or how to record your own datasets and train policies on your specific task?
|
||||
If that's the case, head over to the next tutorial [`7_get_started_with_real_robot.md`](./7_get_started_with_real_robot.md).
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
|
||||
```
|
||||
|
||||
Please, head on over to our [advanced tutorial on adapting policy configuration to various environments](./advanced/train_act_pusht/train_act_pusht.md) to learn more.
|
||||
|
||||
Or in the meantime, happy coding! 🤗
|
||||
Or in the meantime, happy training! 🤗
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
This tutorial explains how to resume a training run that you've started with the training script. If you don't know how our training script and configuration system works, please read [4_train_policy_with_script.md](./4_train_policy_with_script.md) first.
|
||||
|
||||
## Basic training resumption
|
||||
|
||||
Let's consider the example of training ACT for one of the ALOHA tasks. Here's a command that can achieve that:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=outputs/train/run_resumption \
|
||||
policy=act \
|
||||
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
env=aloha \
|
||||
env.task=AlohaTransferCube-v0 \
|
||||
training.log_freq=25 \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=100
|
||||
```
|
||||
|
||||
Here we're using the default dataset and environment for ACT, and we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can test resumption. You should be able to see some logging and have a first checkpoint within 1 minute. Please interrupt the training after the first checkpoint.
|
||||
|
||||
To resume, all that we have to do is run the training script, providing the run directory, and the resume option:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=outputs/train/run_resumption \
|
||||
resume=true
|
||||
```
|
||||
|
||||
You should see from the logging that your training picks up from where it left off.
|
||||
|
||||
Note that with `resume=true`, the configuration file from the last checkpoint in the training output directory is loaded. So it doesn't matter that we haven't provided all the other configuration parameters from our previous command (although there may be warnings to notify you that your command has a different configuration than than the checkpoint).
|
||||
|
||||
---
|
||||
|
||||
Now you should know how to resume your training run in case it gets interrupted or you want to extend a finished training run.
|
||||
|
||||
Happy coding! 🤗
|
||||
@@ -11,7 +11,7 @@ This tutorial will guide you through the process of setting up and training a ne
|
||||
|
||||
By following these steps, you'll be able to replicate tasks like picking up a Lego block and placing it in a bin with a high success rate, as demonstrated in [this video](https://x.com/RemiCadene/status/1814680760592572934).
|
||||
|
||||
Although this tutorial is general and can be easily adapted to various types of robots by changing the configuration, it is specifically based on the [Koch v1.1](https://github.com/jess-moss/koch-v1-1), an affordable robot. The Koch v1.1 consists of a leader arm and a follower arm, each with 6 motors. It can work with one or several cameras to record the scene, which serve as visual sensors for the robot.
|
||||
This tutorial is specifically made for the affordable [Koch v1.1](https://github.com/jess-moss/koch-v1-1) robot, but it contains additional information to be easily adapted to various types of robots like [Aloha bimanual robot](https://aloha-2.github.io) by changing some configurations. The Koch v1.1 consists of a leader arm and a follower arm, each with 6 motors. It can work with one or several cameras to record the scene, which serve as visual sensors for the robot.
|
||||
|
||||
During the data collection phase, you will control the follower arm by moving the leader arm. This process is known as "teleoperation." This technique is used to collect robot trajectories. Afterward, you'll train a neural network to imitate these trajectories and deploy the network to enable your robot to operate autonomously.
|
||||
|
||||
@@ -29,16 +29,28 @@ For a visual walkthrough of the assembly process, you can refer to [this video t
|
||||
|
||||
## 2. Configure motors, calibrate arms, teleoperate your Koch v1.1
|
||||
|
||||
First, install the additional dependencies required for Koch v1.1 by running one of the following commands.
|
||||
First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands (make sure gcc is installed).
|
||||
|
||||
Using `pip`:
|
||||
```bash
|
||||
pip install -e ".[koch]"
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
Or using `poetry`:
|
||||
Using `poetry`:
|
||||
```bash
|
||||
poetry install --sync --extras "koch"
|
||||
poetry sync --extras "dynamixel"
|
||||
```
|
||||
|
||||
Using `uv`:
|
||||
```bash
|
||||
uv sync --extra "dynamixel"
|
||||
```
|
||||
|
||||
/!\ For Linux only, ffmpeg and opencv requires conda install for now. Run this exact sequence of commands:
|
||||
```bash
|
||||
conda install -c conda-forge ffmpeg
|
||||
pip uninstall opencv-python
|
||||
conda install -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
You are now ready to plug the 5V power supply to the motor bus of the leader arm (the smaller one) since all its motors only require 5V.
|
||||
@@ -47,36 +59,65 @@ Then plug the 12V power supply to the motor bus of the follower arm. It has two
|
||||
|
||||
Finally, connect both arms to your computer via USB. Note that the USB doesn't provide any power, and both arms need to be plugged in with their associated power supply to be detected by your computer.
|
||||
|
||||
*Copy pasting python code*
|
||||
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.
|
||||
|
||||
In the upcoming sections, you'll learn about our classes and functions by running some python code, in an interactive session, or by copy-pasting it in a python file. If this is your first time using the tutorial., we highly recommend going through these steps to get deeper intuition about how things work. Once you're more familiar, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
||||
If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--robot-overrides '~cameras' # do not instantiate the cameras
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
It will automatically:
|
||||
1. Detect and help you correct any motor configuration issues.
|
||||
2. Identify any missing calibrations and initiate the calibration procedure.
|
||||
3. Connect the robot and start teleoperation.
|
||||
1. Identify any missing calibrations and initiate the calibration procedure.
|
||||
2. Connect the robot and start teleoperation.
|
||||
|
||||
### a. Control your motors with DynamixelMotorsBus
|
||||
|
||||
You can use the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) to communicate with the motors connected as a chain to the corresponding USB bus. This class leverages the Python [Dynamixel SDK](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20) to facilitate reading from and writing to the motors.
|
||||
|
||||
**First Configuration of your motors**
|
||||
|
||||
You will need to unplug each motor in turn and run a command the identify the motor. The motor will save its own identification, so you only need to do this once. Start by unplugging all of the motors.
|
||||
|
||||
Do the Leader arm first, as all of its motors are of the same type. Plug in your first motor on your leader arm and run this script to set its ID to 1.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand dynamixel \
|
||||
--model xl330-m288 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Then unplug your first motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand dynamixel \
|
||||
--model xl330-m288 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6.
|
||||
|
||||
The process for the follower arm is almost the same, but the follower arm has two types of motors. For the first two motors, make sure you set the model to `xl430-w250`. _Important: configuring follower motors requires plugging and unplugging power. Make sure you use the 5V power for the XL330s and the 12V power for the XL430s!_
|
||||
|
||||
After all of your motors are configured properly, you're ready to plug them all together in a daisy-chain as shown in the original video.
|
||||
|
||||
**Instantiate the DynamixelMotorsBus**
|
||||
|
||||
To begin, create two instances of the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py), one for each arm, using their corresponding USB ports (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`).
|
||||
|
||||
To find the correct ports for each arm, run the utility script twice:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/motors/dynamixel.py
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the DynamixelMotorsBus.
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
@@ -88,7 +129,7 @@ Reconnect the usb cable.
|
||||
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the DynamixelMotorsBus.
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
@@ -98,10 +139,10 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running this command with your ports:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
sudo chmod 666 /dev/tty.usbmodem575E0032081
|
||||
sudo chmod 666 /dev/tty.usbmodem575E0031751
|
||||
```
|
||||
|
||||
*Listing and Configuring Motors*
|
||||
@@ -110,13 +151,11 @@ Next, you'll need to list the motors for each arm, including their name, index,
|
||||
|
||||
To assign indices to the motors, run this code in an interactive Python session. Replace the `port` values with the ones you identified earlier:
|
||||
```python
|
||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
leader_port = "/dev/tty.usbmodem575E0031751"
|
||||
follower_port = "/dev/tty.usbmodem575E0032081"
|
||||
|
||||
leader_arm = DynamixelMotorsBus(
|
||||
port=leader_port,
|
||||
leader_config = DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl330-m077"),
|
||||
@@ -128,8 +167,8 @@ leader_arm = DynamixelMotorsBus(
|
||||
},
|
||||
)
|
||||
|
||||
follower_arm = DynamixelMotorsBus(
|
||||
port=follower_port,
|
||||
follower_config = DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
@@ -140,42 +179,57 @@ follower_arm = DynamixelMotorsBus(
|
||||
"gripper": (6, "xl330-m288"),
|
||||
},
|
||||
)
|
||||
|
||||
leader_arm = DynamixelMotorsBus(leader_config)
|
||||
follower_arm = DynamixelMotorsBus(follower_config)
|
||||
```
|
||||
|
||||
*Updating the YAML Configuration File*
|
||||
IMPORTANTLY: Now that you have your ports, update [`KochRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("koch")
|
||||
@dataclass
|
||||
class KochRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/koch"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
Next, update the port values in the YAML configuration file for the Koch robot at [`lerobot/configs/robot/koch.yaml`](../lerobot/configs/robot/koch.yaml) with the ports you've identified:
|
||||
```yaml
|
||||
[...]
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751 # <- Update
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
shoulder_lift: [2, "xl330-m077"]
|
||||
elbow_flex: [3, "xl330-m077"]
|
||||
wrist_flex: [4, "xl330-m077"]
|
||||
wrist_roll: [5, "xl330-m077"]
|
||||
gripper: [6, "xl330-m077"]
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081 # <- Update
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
shoulder_lift: [2, "xl430-w250"]
|
||||
elbow_flex: [3, "xl330-m288"]
|
||||
wrist_flex: [4, "xl330-m288"]
|
||||
wrist_roll: [5, "xl330-m288"]
|
||||
gripper: [6, "xl330-m288"]
|
||||
[...]
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0085511", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl330-m077"],
|
||||
"shoulder_lift": [2, "xl330-m077"],
|
||||
"elbow_flex": [3, "xl330-m077"],
|
||||
"wrist_flex": [4, "xl330-m077"],
|
||||
"wrist_roll": [5, "xl330-m077"],
|
||||
"gripper": [6, "xl330-m077"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
"elbow_flex": [3, "xl330-m288"],
|
||||
"wrist_flex": [4, "xl330-m288"],
|
||||
"wrist_roll": [5, "xl330-m288"],
|
||||
"gripper": [6, "xl330-m288"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
This configuration file is used to instantiate your robot across all scripts. We'll cover how this works later on.
|
||||
|
||||
**Connect and Configure your Motors**
|
||||
|
||||
Before you can start using your motors, you'll need to configure them to ensure proper communication. When you first connect the motors, the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) automatically detects any mismatch between the current motor indices (factory set to `1`) and the specified indices (e.g., `1, 2, 3, 4, 5, 6`). This triggers a configuration procedure that requires you to unplug the power cord and motors, then reconnect each motor sequentially, starting from the one closest to the bus.
|
||||
@@ -298,32 +352,37 @@ Alternatively, you can unplug the power cord, which will automatically disable t
|
||||
|
||||
*/!\ Warning*: These motors tend to overheat, especially under torque or if left plugged in for too long. Unplug after use.
|
||||
|
||||
### b. Teleoperate your Koch v1.1 with KochRobot
|
||||
### b. Teleoperate your Koch v1.1 with ManipulatorRobot
|
||||
|
||||
**Instantiate the KochRobot**
|
||||
**Instantiate the ManipulatorRobot**
|
||||
|
||||
Before you can teleoperate your robot, you need to instantiate the [`KochRobot`](../lerobot/common/robot_devices/robots/koch.py) using the previously defined `leader_arm` and `follower_arm`.
|
||||
Before you can teleoperate your robot, you need to instantiate the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) using the previously defined `leader_config` and `follower_config`.
|
||||
|
||||
For the Koch robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_arm}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_arm, "right": right_leader_arm},`. Same thing for the follower arms.
|
||||
For the Koch v1.1 robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_config}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_config, "right": right_leader_config},`. Same thing for the follower arms.
|
||||
|
||||
You also need to provide a path to a calibration file, such as `calibration_path=".cache/calibration/koch.pkl"`. More on this in the next section.
|
||||
|
||||
Run the following code to instantiate your Koch robot:
|
||||
Run the following code to instantiate your manipulator robot:
|
||||
```python
|
||||
from lerobot.common.robot_devices.robots.koch import KochRobot
|
||||
from lerobot.common.robot_devices.robots.configs import KochRobotConfig
|
||||
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||
|
||||
robot = KochRobot(
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_path=".cache/calibration/koch.pkl",
|
||||
robot_config = KochRobotConfig(
|
||||
leader_arms={"main": leader_config},
|
||||
follower_arms={"main": follower_config},
|
||||
cameras={}, # We don't use any camera for now
|
||||
)
|
||||
robot = ManipulatorRobot(robot_config)
|
||||
```
|
||||
|
||||
**Calibrate and Connect the KochRobot**
|
||||
The `KochRobotConfig` is used to set the associated settings and calibration process. For instance, we activate the torque of the gripper of the leader Koch v1.1 arm and position it at a 40 degree angle to use it as a trigger.
|
||||
|
||||
Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one Koch robot to work on another.
|
||||
For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `AlohaRobotConfig` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected for Aloha.
|
||||
|
||||
When you connect your robot for the first time, the [`KochRobot`](../lerobot/common/robot_devices/robots/koch.py) will detect if the calibration file is missing and trigger the calibration procedure. During this process, you will be guided to move each arm to three different positions.
|
||||
**Calibrate and Connect the ManipulatorRobot**
|
||||
|
||||
Next, you'll need to calibrate your Koch robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one Koch robot to work on another.
|
||||
|
||||
When you connect your robot for the first time, the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) will detect if the calibration file is missing and trigger the calibration procedure. During this process, you will be guided to move each arm to three different positions.
|
||||
|
||||
Here are the positions you'll move the follower arm to:
|
||||
|
||||
@@ -339,7 +398,7 @@ And here are the corresponding positions for the leader arm:
|
||||
|
||||
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
||||
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to mesure if the values changed negatively or positively.
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
|
||||
|
||||
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
|
||||
|
||||
@@ -354,27 +413,26 @@ The output will look like this:
|
||||
```
|
||||
Connecting main follower arm
|
||||
Connecting main leader arm
|
||||
Missing calibration file '.cache/calibration/koch.pkl'. Starting calibration procedure.
|
||||
|
||||
Running calibration of main follower...
|
||||
|
||||
Missing calibration file '.cache/calibration/koch/main_follower.json'
|
||||
Running calibration of koch main follower...
|
||||
Move arm to zero position
|
||||
[...]
|
||||
Move arm to rotated position
|
||||
[...]
|
||||
Move arm to rest position
|
||||
[...]
|
||||
Calibration is done! Saving calibration file '.cache/calibration/koch/main_follower.json'
|
||||
|
||||
Running calibration of main leader...
|
||||
|
||||
Missing calibration file '.cache/calibration/koch/main_leader.json'
|
||||
Running calibration of koch main leader...
|
||||
Move arm to zero position
|
||||
[...]
|
||||
Move arm to rotated position
|
||||
[...]
|
||||
Move arm to rest position
|
||||
[...]
|
||||
|
||||
Calibration is done! Saving calibration file '.cache/calibration/koch.pkl'
|
||||
Calibration is done! Saving calibration file '.cache/calibration/koch/main_leader.json'
|
||||
```
|
||||
|
||||
*Verifying Calibration*
|
||||
@@ -414,7 +472,7 @@ for _ in tqdm.tqdm(range(seconds*frequency)):
|
||||
|
||||
*Using `teleop_step` for Teleoperation*
|
||||
|
||||
Alternatively, you can teleoperate the robot using the `teleop_step` method from [`KochRobot`](../lerobot/common/robot_devices/robots/koch.py).
|
||||
Alternatively, you can teleoperate the robot using the `teleop_step` method from [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py).
|
||||
|
||||
Run this code to teleoperate:
|
||||
```python
|
||||
@@ -565,9 +623,11 @@ Note: Some cameras may take a few seconds to warm up, and the first frame might
|
||||
|
||||
Finally, run this code to instantiate and connectyour camera:
|
||||
```python
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
camera = OpenCVCamera(camera_index=0)
|
||||
config = OpenCVCameraConfig(camera_index=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
|
||||
@@ -589,7 +649,7 @@ uint8
|
||||
|
||||
With certain camera, you can also specify additional parameters like frame rate, resolution, and color mode during instantiation. For instance:
|
||||
```python
|
||||
camera = OpenCVCamera(camera_index=0, fps=30, width=640, height=480)
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=30, width=640, height=480)
|
||||
```
|
||||
|
||||
If the provided arguments are not compatible with the camera, an exception will be raised.
|
||||
@@ -603,18 +663,20 @@ camera.disconnect()
|
||||
|
||||
**Instantiate your robot with cameras**
|
||||
|
||||
Additionaly, you can set up your robot to work with your cameras.
|
||||
Additionally, you can set up your robot to work with your cameras.
|
||||
|
||||
Modify the following Python code with the appropriate camera names and configurations:
|
||||
```python
|
||||
robot = KochRobot(
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_path=".cache/calibration/koch.pkl",
|
||||
cameras={
|
||||
"laptop": OpenCVCamera(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCamera(1, fps=30, width=640, height=480),
|
||||
},
|
||||
robot = ManipulatorRobot(
|
||||
KochRobotConfig(
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
cameras={
|
||||
"laptop": OpenCVCameraConfig(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
|
||||
},
|
||||
)
|
||||
)
|
||||
robot.connect()
|
||||
```
|
||||
@@ -638,39 +700,20 @@ torch.Size([3, 480, 640])
|
||||
255
|
||||
```
|
||||
|
||||
Also, update the following lines of the yaml file for Koch robot [`lerobot/configs/robot/koch.yaml`](../lerobot/configs/robot/koch.yaml) with the names and configurations of your cameras:
|
||||
```yaml
|
||||
[...]
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
```
|
||||
### d. Use `control_robot.py` and our `teleoperate` function
|
||||
|
||||
This file is used to instantiate your robot in all our scripts. We will explain how this works in the next section.
|
||||
|
||||
### d. Use `koch.yaml` and our `teleoperate` function
|
||||
|
||||
Instead of manually running the python code in a terminal window, you can use [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to instantiate your robot by providing the path to the robot yaml file (e.g. [`lerobot/configs/robot/koch.yaml`](../lerobot/configs/robot/koch.yaml)) and control your robot with various modes as explained next.
|
||||
Instead of manually running the python code in a terminal window, you can use [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to instantiate your robot by providing the robot configurations via command line and control your robot with various modes as explained next.
|
||||
|
||||
Try running this code to teleoperate your robot (if you dont have a camera, keep reading):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtRfoll: 0.19 (5239.0hz)
|
||||
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtWfoll: 0.19 (5239.0hz)
|
||||
```
|
||||
|
||||
It contains
|
||||
@@ -680,21 +723,12 @@ It contains
|
||||
- `dtRlead: 4.93 (203.0hz)` which is the number of milliseconds it took to read the position of the leader arm using `leader_arm.read("Present_Position")`.
|
||||
- `dtWfoll: 0.22 (4446.9hz)` which is the number of milliseconds it took to set a new goal position for the follower arm using `follower_arm.write("Goal_position", leader_pos)` ; note that writing is done asynchronously so it takes less time than reading.
|
||||
|
||||
Note: you can override any entry in the yaml file using `--robot-overrides` and the [hydra.cc](https://hydra.cc/docs/advanced/override_grammar/basic) syntax. If needed, you can override the ports like this:
|
||||
Importantly: If you don't have any camera, you can remove them dynamically with this [draccus](https://github.com/dlwh/draccus) syntax `--robot.cameras='{}'`:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--robot-overrides \
|
||||
leader_arms.main.port=/dev/tty.usbmodem575E0031751 \
|
||||
follower_arms.main.port=/dev/tty.usbmodem575E0032081
|
||||
```
|
||||
|
||||
Importantly: If you don't have any camera, you can remove them dynamically with this [hydra.cc](https://hydra.cc/docs/advanced/override_grammar/basic) syntax `'~cameras'`:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--robot-overrides \
|
||||
'~cameras'
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
We advise to create a new yaml file when the command becomes too long.
|
||||
@@ -730,55 +764,56 @@ for _ in range(record_time_s * fps):
|
||||
|
||||
Importantly, many utilities are still missing. For instance, if you have cameras, you will need to save the images on disk to not go out of RAM, and to do so in threads to not slow down communication with your robot. Also, you will need to store your data in a format optimized for training and web sharing like [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py). More on this in the next section.
|
||||
|
||||
### a. Use `koch.yaml` and the `record` function
|
||||
### a. Use the `record` function
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to achieve efficient data recording. It encompasses many recording utilities:
|
||||
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of recording.
|
||||
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording.
|
||||
2. Video streams from cameras are displayed in window so that you can verify them.
|
||||
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--push-to-hub 0` is provided).
|
||||
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again. You can also use `--force-override 1` to start recording from scratch.
|
||||
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided).
|
||||
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch.
|
||||
5. Set the flow of data recording using command line arguments:
|
||||
- `--warmup-time-s` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
||||
- `--episode-time-s` defines the number of seconds for data recording for each episode (60 seconds by default).
|
||||
- `--reset-time-s` defines the number of seconds for resetting the environment after each episode (60 seconds by default).
|
||||
- `--num-episodes` defines the number of episodes to record (50 by default).
|
||||
- `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
||||
- `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
|
||||
- `--control.reset_time_s=60` defines the number of seconds for resetting the environment after each episode (60 seconds by default).
|
||||
- `--control.num_episodes=50` defines the number of episodes to record (50 by default).
|
||||
6. Control the flow during data recording using keyboard keys:
|
||||
- Press right arrow `->` at any time during episode recording to early stop and go to resetting. Same during resetting, to early stop and to go to the next episode recording.
|
||||
- Press left arrow `<-` at any time during episode recording or resetting to early stop, cancel the current episode, and re-record it.
|
||||
- Press escape `ESC` at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.
|
||||
7. Similarly to `teleoperate`, you can also use `--robot-path` and `--robot-overrides` to specify your robots.
|
||||
7. Similarly to `teleoperate`, you can also use the command line to override anything.
|
||||
|
||||
Before trying `record`, if you want to push your dataset to the hub, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
Also, store your Hugging Face repositery name in a variable (e.g. `cadene` or `lerobot`). For instance, run this to use your Hugging Face user name as repositery:
|
||||
Also, store your Hugging Face repository name in a variable (e.g. `cadene` or `lerobot`). For instance, run this to use your Hugging Face user name as repository:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
If you don't want to push to hub, use `--push-to-hub 0`.
|
||||
If you don't want to push to hub, use `--control.push_to_hub=false`.
|
||||
|
||||
Now run this to record 2 episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/koch_test \
|
||||
--tags tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 30 \
|
||||
--num-episodes 2
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=record \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/koch_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
This will write your dataset locally to `{root}/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||
|
||||
This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||
|
||||
You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
Remember to add `--robot-overrides '~cameras'` if you don't have any cameras and you still use the default `koch.yaml` configuration.
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 15:02:58 ol_robot.py:219 dt:33.34 (30.0hz) dtRlead: 5.06 (197.5hz) dtWfoll: 0.25 (3963.7hz) dtRfoll: 6.22 (160.7hz) dtRlaptop: 32.57 (30.7hz) dtRphone: 33.84 (29.5hz)
|
||||
@@ -790,8 +825,8 @@ It contains:
|
||||
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
|
||||
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm ; writing is asynchronous so it takes less time than reading.
|
||||
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
|
||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchrously.
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchrously.
|
||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously.
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
|
||||
|
||||
Troubleshooting:
|
||||
- On Linux, if you encounter a hanging issue when using cameras, uninstall opencv and re-install it with conda:
|
||||
@@ -811,7 +846,7 @@ At the end of data recording, your dataset will be uploaded on your Hugging Face
|
||||
echo https://huggingface.co/datasets/${HF_USER}/koch_test
|
||||
```
|
||||
|
||||
### b. Advices for recording dataset
|
||||
### b. Advice for recording dataset
|
||||
|
||||
Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings.
|
||||
|
||||
@@ -826,10 +861,11 @@ In the coming months, we plan to release a foundational model for robotics. We a
|
||||
You can visualize your dataset by running:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/koch_test
|
||||
```
|
||||
|
||||
Note: You might need to add `--local-files-only 1` if your dataset was not uploaded to hugging face hub.
|
||||
|
||||
This will launch a local web server that looks like this:
|
||||
<div style="text-align:center;">
|
||||
<img src="../media/tutorial/visualize_dataset_html.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="100%">
|
||||
@@ -841,12 +877,12 @@ A useful feature of [`lerobot/scripts/control_robot.py`](../lerobot/scripts/cont
|
||||
|
||||
To replay the first episode of the dataset you just recorded, run the following command:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/koch_test \
|
||||
--episode 0
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/koch_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
|
||||
@@ -857,54 +893,20 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/koch_test \
|
||||
policy=act_koch_real \
|
||||
env=koch_real \
|
||||
hydra.run.dir=outputs/train/act_koch_test \
|
||||
hydra.job.name=act_koch_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/koch_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_koch_test \
|
||||
--job_name=act_koch_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/koch_test`.
|
||||
2. We provided the policy with `policy=act_koch_real`. This loads configurations from [`lerobot/configs/policy/act_koch_real.yaml`](../lerobot/configs/policy/act_koch_real.yaml). Importantly, this policy uses 2 cameras as input `laptop` and `phone`. If your dataset has different cameras, update the yaml file to account for it in the following parts:
|
||||
```yaml
|
||||
...
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
...
|
||||
input_shapes:
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
...
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
...
|
||||
```
|
||||
3. We provided an environment as argument with `env=koch_real`. This loads configurations from [`lerobot/configs/env/koch_real.yaml`](../lerobot/configs/env/koch_real.yaml). It looks like
|
||||
```yaml
|
||||
fps: 30
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
```
|
||||
It should match your dataset (e.g. `fps: 30`) and your robot (e.g. `state_dim: 6` and `action_dim: 6`). We are still working on simplifying this in future versions of `lerobot`.
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
|
||||
|
||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||
|
||||
@@ -925,7 +927,7 @@ huggingface-cli upload ${HF_USER}/act_koch_test_${CKPT} \
|
||||
|
||||
## 5. Evaluate your policy
|
||||
|
||||
Now that you have a policy checkpoint, you can easily control your robot with it using methods from [`KochRobot`](../lerobot/common/robot_devices/robots/koch.py) and the policy.
|
||||
Now that you have a policy checkpoint, you can easily control your robot with it using methods from [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) and the policy.
|
||||
|
||||
Try this code for running inference for 60 seconds at 30 fps:
|
||||
```python
|
||||
@@ -968,36 +970,36 @@ for _ in range(inference_time_s * fps):
|
||||
busy_wait(1 / fps - dt_s)
|
||||
```
|
||||
|
||||
### a. Use `koch.yaml` and our `record` function
|
||||
### a. Use our `record` function
|
||||
|
||||
Ideally, when controlling your robot with your neural network, you would want to record evaluation episodes and to be able to visualize them later on, or even train on them like in Reinforcement Learning. This pretty much corresponds to recording a new dataset but with a neural network providing the actions instead of teleoperation.
|
||||
|
||||
To this end, you can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_koch_test \
|
||||
--tags tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 30 \
|
||||
--num-episodes 10 \
|
||||
-p outputs/train/act_koch_test/checkpoints/last/pretrained_model
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/eval_act_koch_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_koch_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_koch_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_koch_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_koch_test`).
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_koch_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_koch_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_koch_test`).
|
||||
|
||||
### b. Visualize evaluation afterwards
|
||||
|
||||
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_koch_test
|
||||
--repo-id ${HF_USER}/eval_act_koch_test
|
||||
```
|
||||
|
||||
## 6. Next step
|
||||
|
||||
161
examples/8_use_stretch.md
Normal file
161
examples/8_use_stretch.md
Normal file
@@ -0,0 +1,161 @@
|
||||
This tutorial explains how to use [Stretch 3](https://hello-robot.com/stretch-3-product) with LeRobot.
|
||||
|
||||
## Setup
|
||||
|
||||
Familiarize yourself with Stretch by following its [tutorials](https://docs.hello-robot.com/0.3/getting_started/hello_robot/) (recommended).
|
||||
|
||||
To use LeRobot on Stretch, 3 options are available:
|
||||
- [tethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#tethered-setup)
|
||||
- [untethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#untethered-setup)
|
||||
- ssh directly into Stretch (you will first need to install and configure openssh-server on stretch using one of the two above setups)
|
||||
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
On Stretch's CLI, follow these steps:
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
mkdir -p ~/miniconda3
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
2. Comment out these lines in `~/.profile` (this can mess up paths used by conda and ~/.local/bin should already be in your PATH)
|
||||
```
|
||||
# set PATH so it includes user's private bin if it exists
|
||||
if [ -d "$HOME/.local/bin" ] ; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
fi
|
||||
```
|
||||
|
||||
3. Restart shell or `source ~/.bashrc`
|
||||
|
||||
4. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
```
|
||||
|
||||
5. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
6. Install LeRobot with stretch dependencies:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[stretch]"
|
||||
```
|
||||
|
||||
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
|
||||
|
||||
For Linux only (not Mac), install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
7. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
|
||||
```bash
|
||||
stretch_system_check.py
|
||||
```
|
||||
|
||||
> **Note:** You may need to free the "robot process" after booting Stretch by running `stretch_free_robot_process.py`. For more info this Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#turning-off-gamepad-teleoperation).
|
||||
|
||||
You should get something like this:
|
||||
```bash
|
||||
For use with S T R E T C H (R) from Hello Robot Inc.
|
||||
---------------------------------------------------------------------
|
||||
|
||||
Model = Stretch 3
|
||||
Tool = DexWrist 3 w/ Gripper
|
||||
Serial Number = stretch-se3-3054
|
||||
|
||||
---- Checking Hardware ----
|
||||
[Pass] Comms are ready
|
||||
[Pass] Actuators are ready
|
||||
[Warn] Sensors not ready (IMU AZ = -10.19 out of range -10.1 to -9.5)
|
||||
[Pass] Battery voltage is 13.6 V
|
||||
|
||||
---- Checking Software ----
|
||||
[Pass] Ubuntu 22.04 is ready
|
||||
[Pass] All APT pkgs are setup correctly
|
||||
[Pass] Firmware is up-to-date
|
||||
[Pass] Python pkgs are up-to-date
|
||||
[Pass] ROS2 Humble is ready
|
||||
```
|
||||
|
||||
## Teleoperate, record a dataset and run a policy
|
||||
|
||||
**Calibrate (Optional)**
|
||||
Before operating Stretch, you need to [home](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#homing) it first. Be mindful about giving Stretch some space as this procedure will move the robot's arm and gripper. Now run this command:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=stretch \
|
||||
--control.type=calibrate
|
||||
```
|
||||
This is equivalent to running `stretch_robot_home.py`
|
||||
|
||||
> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
|
||||
|
||||
**Teleoperate**
|
||||
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
||||
|
||||
Now try out teleoperation (see above documentation to learn about the gamepad controls):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=stretch \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
This is essentially the same as running `stretch_gamepad_teleop.py`
|
||||
|
||||
**Record a dataset**
|
||||
Once you're familiar with the gamepad controls and after a bit of practice, you can try to record your first dataset with Stretch.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record one episode:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=stretch \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/stretch_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
> **Note:** If you're using ssh to connect to Stretch and run this script, you won't be able to visualize its cameras feed (though they will still be recording). To see the cameras stream, use [tethered](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#tethered-setup) or [untethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#untethered-setup).
|
||||
|
||||
**Replay an episode**
|
||||
Now try to replay this episode (make sure the robot's initial position is the same):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=stretch \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/stretch_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
Follow [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) to train a policy on your data and run inference on your robot. You will need to adapt the code for Stretch.
|
||||
|
||||
> TODO(rcadene, aliberts): Add already setup environment and policy yaml configuration files
|
||||
|
||||
If you need help, please reach out on Discord in the channel `#stretch3-mobile-arm`.
|
||||
181
examples/9_use_aloha.md
Normal file
181
examples/9_use_aloha.md
Normal file
@@ -0,0 +1,181 @@
|
||||
This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.trossenrobotics.com/aloha-stationary) with LeRobot.
|
||||
|
||||
## Setup
|
||||
|
||||
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
|
||||
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
On your computer:
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
mkdir -p ~/miniconda3
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
2. Restart shell or `source ~/.bashrc`
|
||||
|
||||
3. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
```
|
||||
|
||||
4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
|
||||
```
|
||||
|
||||
For Linux only (not Mac), install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
|
||||
**/!\ FOR SAFETY, READ THIS /!\**
|
||||
Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly:
|
||||
1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms,
|
||||
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
|
||||
|
||||
By running the following code, you can start your first **SAFE** teleoperation:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
--robot.max_relative_target=5 \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`AlohaRobotConfig`](lerobot/common/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
--robot.max_relative_target=null \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with Aloha.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
--robot.max_relative_target=null \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/aloha_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/aloha_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/aloha_test
|
||||
```
|
||||
|
||||
## Replay an episode
|
||||
|
||||
**/!\ FOR SAFETY, READ THIS /!\**
|
||||
Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot.max_relative_target=5` to your command line as explained above.
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
--robot.max_relative_target=null \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/aloha_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/aloha_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_aloha_test \
|
||||
--job_name=act_aloha_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/eval_act_aloha_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_aloha_test/checkpoints/last/pretrained_model \
|
||||
--control.num_image_writer_processes=1
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
|
||||
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
|
||||
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.
|
||||
@@ -1,7 +1,21 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
||||
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
||||
transforms are applied to the observation images before they are returned in the dataset's __get_item__.
|
||||
transforms are applied to the observation images before they are returned in the dataset's __getitem__.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -10,17 +24,17 @@ from torchvision.transforms import ToPILImage, v2
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset_repo_id = "lerobot/aloha_static_tape"
|
||||
dataset_repo_id = "lerobot/aloha_static_screw_driver"
|
||||
|
||||
# Create a LeRobotDataset with no transformations
|
||||
dataset = LeRobotDataset(dataset_repo_id)
|
||||
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
|
||||
|
||||
# Get the index of the first observation in the first episode
|
||||
first_idx = dataset.episode_data_index["from"][0].item()
|
||||
|
||||
# Get the frame corresponding to the first camera
|
||||
frame = dataset[first_idx][dataset.camera_keys[0]]
|
||||
frame = dataset[first_idx][dataset.meta.camera_keys[0]]
|
||||
|
||||
|
||||
# Define the transformations
|
||||
@@ -28,15 +42,16 @@ transforms = v2.Compose(
|
||||
[
|
||||
v2.ColorJitter(brightness=(0.5, 1.5)),
|
||||
v2.ColorJitter(contrast=(0.5, 1.5)),
|
||||
v2.ColorJitter(hue=(-0.1, 0.1)),
|
||||
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
|
||||
]
|
||||
)
|
||||
|
||||
# Create another LeRobotDataset with the defined transformations
|
||||
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
|
||||
transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)
|
||||
|
||||
# Get a frame from the transformed dataset
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
|
||||
|
||||
# Create a directory to store output images
|
||||
output_dir = Path("outputs/image_transforms")
|
||||
@@ -1,87 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Change the seed to match what PushT eval uses
|
||||
# (to avoid evaluating on seeds used for generating the training data).
|
||||
seed: 100000
|
||||
# Change the dataset repository to the PushT one.
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
override_dataset_stats:
|
||||
observation.image:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
observation.image: [3, 96, 96]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.image: mean_std
|
||||
# Use min_max normalization just because it's more standard.
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
# Use min_max normalization just because it's more standard.
|
||||
action: min_max
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -1,70 +0,0 @@
|
||||
In this tutorial we will learn how to adapt a policy configuration to be compatible with a new environment and dataset. As a concrete example, we will adapt the default configuration for ACT to be compatible with the PushT environment and dataset.
|
||||
|
||||
If you haven't already read our tutorial on the [training script and configuration tooling](../4_train_policy_with_script.md) please do so prior to tackling this tutorial.
|
||||
|
||||
Let's get started!
|
||||
|
||||
Suppose we want to train ACT for PushT. Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
|
||||
```
|
||||
|
||||
We need to adapt the parameters of the ACT policy configuration to the PushT environment. The most important ones are the image keys.
|
||||
|
||||
ALOHA's datasets and environments typically use a variable number of cameras. In `lerobot/configs/policy/act.yaml` you may notice two relevant sections. Here we show you the minimal diff needed to adjust to PushT:
|
||||
|
||||
```diff
|
||||
override_dataset_stats:
|
||||
- observation.images.top:
|
||||
+ observation.image:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
policy:
|
||||
input_shapes:
|
||||
- observation.images.top: [3, 480, 640]
|
||||
+ observation.image: [3, 96, 96]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
input_normalization_modes:
|
||||
- observation.images.top: mean_std
|
||||
+ observation.image: mean_std
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
```
|
||||
|
||||
Here we've accounted for the following:
|
||||
- PushT uses "observation.image" for its image key.
|
||||
- PushT provides smaller images.
|
||||
|
||||
_Side note: technically we could override these via the CLI, but with many changes it gets a bit messy, and we also have a bit of a challenge in that we're using `.` in our observation keys which is treated by Hydra as a hierarchical separator_.
|
||||
|
||||
For your convenience, we provide [`act_pusht.yaml`](./act_pusht.yaml) in this directory. It contains the diff above, plus some other (optional) ones that are explained within. Please copy it into `lerobot/configs/policy` with:
|
||||
|
||||
```bash
|
||||
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/act_pusht.yaml
|
||||
```
|
||||
|
||||
(remember from a [previous tutorial](../4_train_policy_with_script.md) that Hydra will look in the `lerobot/configs` directory). Now try running the following.
|
||||
|
||||
<!-- Note to contributor: are you changing this command? Note that it's tested in `Makefile`, so change it there too! -->
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act_pusht env=pusht
|
||||
```
|
||||
|
||||
Notice that this is much the same as the command that failed at the start of the tutorial, only:
|
||||
- Now we are using `policy=act_pusht` to point to our new configuration file.
|
||||
- We can drop `dataset_repo_id=lerobot/pusht` as the change is incorporated in our new configuration file.
|
||||
|
||||
Hurrah! You're now training ACT for the PushT environment.
|
||||
|
||||
---
|
||||
|
||||
The bottom line of this tutorial is that when training policies for different environments and datasets you will need to understand what parts of the policy configuration are specific to those and make changes accordingly.
|
||||
|
||||
Happy coding! 🤗
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
|
||||
|
||||
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
|
||||
@@ -9,82 +23,82 @@ on the target environment, whether that be in simulation or the real world.
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
def main():
|
||||
device = torch.device("cuda")
|
||||
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = "lerobot/diffusion_pusht"
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
# Set up the dataset.
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to calculate the loss.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load full dataset
|
||||
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
|
||||
# - Calculate train and val subsets
|
||||
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
|
||||
num_val_episodes = full_dataset.num_episodes - num_train_episodes
|
||||
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
|
||||
# - Get first frame index of the validation set
|
||||
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
|
||||
# - Load frames subset belonging to validation set using the `split` argument.
|
||||
# It utilizes the `datasets` library's syntax for slicing datasets.
|
||||
# For more information on the Slice API, please see:
|
||||
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
|
||||
)
|
||||
val_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
|
||||
)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
# Set up the dataset.
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to calculate the loss.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# Create dataloader for evaluation.
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=False,
|
||||
)
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load dataset metadata
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
# - Calculate train and val episodes
|
||||
total_episodes = dataset_metadata.total_episodes
|
||||
episodes = list(range(dataset_metadata.total_episodes))
|
||||
num_train_episodes = math.floor(total_episodes * 90 / 100)
|
||||
train_episodes = episodes[:num_train_episodes]
|
||||
val_episodes = episodes[num_train_episodes:]
|
||||
print(f"Number of episodes in full dataset: {total_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
|
||||
# - Load train an val datasets
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
|
||||
)
|
||||
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
|
||||
# Run validation loop.
|
||||
loss_cumsum = 0
|
||||
n_examples_evaluated = 0
|
||||
for batch in val_dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
# Create dataloader for evaluation.
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
loss_cumsum += output_dict["loss"].item()
|
||||
n_examples_evaluated += batch["index"].shape[0]
|
||||
# Run validation loop.
|
||||
loss_cumsum = 0
|
||||
n_examples_evaluated = 0
|
||||
for batch in val_dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
loss, _ = policy.forward(batch)
|
||||
|
||||
# Calculate the average loss over the validation set.
|
||||
average_loss = loss_cumsum / n_examples_evaluated
|
||||
loss_cumsum += loss.item()
|
||||
n_examples_evaluated += batch["index"].shape[0]
|
||||
|
||||
print(f"Average loss on validation set: {average_loss:.4f}")
|
||||
# Calculate the average loss over the validation set.
|
||||
average_loss = loss_cumsum / n_examples_evaluated
|
||||
|
||||
print(f"Average loss on validation set: {average_loss:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
243
examples/port_datasets/pusht_zarr.py
Normal file
243
examples/port_datasets/pusht_zarr.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
|
||||
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
||||
PUSHT_FEATURES = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": {
|
||||
"axes": ["x", "y"],
|
||||
},
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": {
|
||||
"axes": ["x", "y"],
|
||||
},
|
||||
},
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"next.success": {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"observation.environment_state": {
|
||||
"dtype": "float32",
|
||||
"shape": (16,),
|
||||
"names": [
|
||||
"keypoints",
|
||||
],
|
||||
},
|
||||
"observation.image": {
|
||||
"dtype": None,
|
||||
"shape": (3, 96, 96),
|
||||
"names": [
|
||||
"channels",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def build_features(mode: str) -> dict:
|
||||
features = PUSHT_FEATURES
|
||||
if mode == "keypoints":
|
||||
features.pop("observation.image")
|
||||
else:
|
||||
features.pop("observation.environment_state")
|
||||
features["observation.image"]["dtype"] = mode
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def load_raw_dataset(zarr_path: Path):
|
||||
try:
|
||||
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
|
||||
return zarr_data
|
||||
|
||||
|
||||
def calculate_coverage(zarr_data):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
block_pos = zarr_data["state"][:, 2:4]
|
||||
block_angle = zarr_data["state"][:, 4]
|
||||
|
||||
num_frames = len(block_pos)
|
||||
|
||||
coverage = np.zeros((num_frames,), dtype=np.float32)
|
||||
# 8 keypoints with 2 coords each
|
||||
keypoints = np.zeros((num_frames, 16), dtype=np.float32)
|
||||
|
||||
# Set x, y, theta (in radians)
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4])
|
||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
space.gravity = 0, 0
|
||||
space.damping = 0
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage[i] = intersection_area / goal_area
|
||||
keypoints[i] = PushTEnv.get_keypoints(block_shapes).flatten()
|
||||
|
||||
return coverage, keypoints
|
||||
|
||||
|
||||
def calculate_success(coverage: float, success_threshold: float):
|
||||
return coverage > success_threshold
|
||||
|
||||
|
||||
def calculate_reward(coverage: float, success_threshold: float):
|
||||
return np.clip(coverage / success_threshold, 0, 1)
|
||||
|
||||
|
||||
def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True):
|
||||
if mode not in ["video", "image", "keypoints"]:
|
||||
raise ValueError(mode)
|
||||
|
||||
if (HF_LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
|
||||
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
||||
|
||||
zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr")
|
||||
|
||||
env_state = zarr_data["state"][:]
|
||||
agent_pos = env_state[:, :2]
|
||||
|
||||
action = zarr_data["action"][:]
|
||||
image = zarr_data["img"] # (b, h, w, c)
|
||||
|
||||
if image.dtype == np.float32 and image.max() == np.float32(255):
|
||||
# HACK: images are loaded as float32 but they actually encode uint8 data
|
||||
image = image.astype(np.uint8)
|
||||
|
||||
episode_data_index = {
|
||||
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
|
||||
"to": zarr_data.meta["episode_ends"],
|
||||
}
|
||||
|
||||
# Calculate success and reward based on the overlapping area
|
||||
# of the T-object and the T-area.
|
||||
coverage, keypoints = calculate_coverage(zarr_data)
|
||||
success = calculate_success(coverage, success_threshold=0.95)
|
||||
reward = calculate_reward(coverage, success_threshold=0.95)
|
||||
|
||||
features = build_features(mode)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=10,
|
||||
robot_type="2d pointer",
|
||||
features=features,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
episodes = range(len(episode_data_index["from"]))
|
||||
for ep_idx in episodes:
|
||||
from_idx = episode_data_index["from"][ep_idx]
|
||||
to_idx = episode_data_index["to"][ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
|
||||
for frame_idx in range(num_frames):
|
||||
i = from_idx + frame_idx
|
||||
idx = i + (frame_idx < num_frames - 1)
|
||||
frame = {
|
||||
"action": action[i],
|
||||
# Shift reward and success by +1 until the last item of the episode
|
||||
"next.reward": reward[idx : idx + 1],
|
||||
"next.success": success[idx : idx + 1],
|
||||
"task": PUSHT_TASK,
|
||||
}
|
||||
|
||||
frame["observation.state"] = agent_pos[i]
|
||||
|
||||
if mode == "keypoints":
|
||||
frame["observation.environment_state"] = keypoints[i]
|
||||
else:
|
||||
frame["observation.image"] = image[i]
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
|
||||
repository_id = "lerobot/pusht"
|
||||
|
||||
modes = ["video", "image", "keypoints"]
|
||||
# Uncomment if you want to try with a specific mode
|
||||
# modes = ["video"]
|
||||
# modes = ["image"]
|
||||
# modes = ["keypoints"]
|
||||
|
||||
data_dir = Path("data/lerobot-raw/pusht_raw")
|
||||
for available_mode in modes:
|
||||
if available_mode in ["image", "keypoints"]:
|
||||
repository_id += f"_{available_mode}"
|
||||
|
||||
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
|
||||
main(data_dir, repo_id=repository_id, mode=available_mode)
|
||||
|
||||
# Uncomment if you want to load the local dataset and explore it
|
||||
# dataset = LeRobotDataset(repo_id=repo_id)
|
||||
# breakpoint()
|
||||
@@ -27,6 +27,9 @@ Example:
|
||||
print(lerobot.available_real_world_datasets)
|
||||
print(lerobot.available_policies)
|
||||
print(lerobot.available_policies_per_env)
|
||||
print(lerobot.available_robots)
|
||||
print(lerobot.available_cameras)
|
||||
print(lerobot.available_motors)
|
||||
```
|
||||
|
||||
When implementing a new dataset loadable with LeRobotDataset follow these steps:
|
||||
@@ -55,7 +58,6 @@ available_tasks_per_env = {
|
||||
],
|
||||
"pusht": ["PushT-v0"],
|
||||
"xarm": ["XarmLift-v0"],
|
||||
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||
}
|
||||
available_envs = list(available_tasks_per_env.keys())
|
||||
|
||||
@@ -83,23 +85,6 @@ available_datasets_per_env = {
|
||||
"lerobot/xarm_push_medium_image",
|
||||
"lerobot/xarm_push_medium_replay_image",
|
||||
],
|
||||
"dora_aloha_real": [
|
||||
"lerobot/aloha_static_battery",
|
||||
"lerobot/aloha_static_candy",
|
||||
"lerobot/aloha_static_coffee",
|
||||
"lerobot/aloha_static_coffee_new",
|
||||
"lerobot/aloha_static_cups_open",
|
||||
"lerobot/aloha_static_fork_pick_up",
|
||||
"lerobot/aloha_static_pingpong_test",
|
||||
"lerobot/aloha_static_pro_pencil",
|
||||
"lerobot/aloha_static_screw_driver",
|
||||
"lerobot/aloha_static_tape",
|
||||
"lerobot/aloha_static_thread_velcro",
|
||||
"lerobot/aloha_static_towel",
|
||||
"lerobot/aloha_static_vinh_cup",
|
||||
"lerobot/aloha_static_vinh_cup_left",
|
||||
"lerobot/aloha_static_ziploc_slide",
|
||||
],
|
||||
}
|
||||
|
||||
available_real_world_datasets = [
|
||||
@@ -129,13 +114,60 @@ available_real_world_datasets = [
|
||||
"lerobot/unitreeh1_rearrange_objects",
|
||||
"lerobot/unitreeh1_two_robot_greeting",
|
||||
"lerobot/unitreeh1_warehouse",
|
||||
"lerobot/nyu_rot_dataset",
|
||||
"lerobot/utokyo_saytap",
|
||||
"lerobot/imperialcollege_sawyer_wrist_cam",
|
||||
"lerobot/utokyo_xarm_bimanual",
|
||||
"lerobot/tokyo_u_lsmo",
|
||||
"lerobot/utokyo_pr2_opening_fridge",
|
||||
"lerobot/cmu_franka_exploration_dataset",
|
||||
"lerobot/cmu_stretch",
|
||||
"lerobot/asu_table_top",
|
||||
"lerobot/utokyo_pr2_tabletop_manipulation",
|
||||
"lerobot/utokyo_xarm_pick_and_place",
|
||||
"lerobot/ucsd_kitchen_dataset",
|
||||
"lerobot/austin_buds_dataset",
|
||||
"lerobot/dlr_sara_grid_clamp",
|
||||
"lerobot/conq_hose_manipulation",
|
||||
"lerobot/columbia_cairlab_pusht_real",
|
||||
"lerobot/dlr_sara_pour",
|
||||
"lerobot/dlr_edan_shared_control",
|
||||
"lerobot/ucsd_pick_and_place_dataset",
|
||||
"lerobot/berkeley_cable_routing",
|
||||
"lerobot/nyu_franka_play_dataset",
|
||||
"lerobot/austin_sirius_dataset",
|
||||
"lerobot/cmu_play_fusion",
|
||||
"lerobot/berkeley_gnm_sac_son",
|
||||
"lerobot/nyu_door_opening_surprising_effectiveness",
|
||||
"lerobot/berkeley_fanuc_manipulation",
|
||||
"lerobot/jaco_play",
|
||||
"lerobot/viola",
|
||||
"lerobot/kaist_nonprehensile",
|
||||
"lerobot/berkeley_mvp",
|
||||
"lerobot/uiuc_d3field",
|
||||
"lerobot/berkeley_gnm_recon",
|
||||
"lerobot/austin_sailor_dataset",
|
||||
"lerobot/utaustin_mutex",
|
||||
"lerobot/roboturk",
|
||||
"lerobot/stanford_hydra_dataset",
|
||||
"lerobot/berkeley_autolab_ur5",
|
||||
"lerobot/stanford_robocook",
|
||||
"lerobot/toto",
|
||||
"lerobot/fmb",
|
||||
"lerobot/droid_100",
|
||||
"lerobot/berkeley_rpt",
|
||||
"lerobot/stanford_kuka_multimodal_dataset",
|
||||
"lerobot/iamlab_cmu_pickup_insert",
|
||||
"lerobot/taco_play",
|
||||
"lerobot/berkeley_gnm_cory_hall",
|
||||
"lerobot/usc_cloth_sim",
|
||||
]
|
||||
|
||||
available_datasets = list(
|
||||
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
|
||||
available_datasets = sorted(
|
||||
set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
|
||||
)
|
||||
|
||||
# lists all available policies from `lerobot/common/policies` by their class attribute: `name`.
|
||||
# lists all available policies from `lerobot/common/policies`
|
||||
available_policies = [
|
||||
"act",
|
||||
"diffusion",
|
||||
@@ -143,12 +175,34 @@ available_policies = [
|
||||
"vqbet",
|
||||
]
|
||||
|
||||
# lists all available robots from `lerobot/common/robot_devices/robots`
|
||||
available_robots = [
|
||||
"koch",
|
||||
"koch_bimanual",
|
||||
"aloha",
|
||||
"so100",
|
||||
"moss",
|
||||
]
|
||||
|
||||
# lists all available cameras from `lerobot/common/robot_devices/cameras`
|
||||
available_cameras = [
|
||||
"opencv",
|
||||
"intelrealsense",
|
||||
]
|
||||
|
||||
# lists all available motors from `lerobot/common/robot_devices/motors`
|
||||
available_motors = [
|
||||
"dynamixel",
|
||||
"feetech",
|
||||
]
|
||||
|
||||
# keys and values refer to yaml files
|
||||
available_policies_per_env = {
|
||||
"aloha": ["act"],
|
||||
"pusht": ["diffusion", "vqbet"],
|
||||
"xarm": ["tdmpc"],
|
||||
"dora_aloha_real": ["act_real"],
|
||||
"koch_real": ["act_koch_real"],
|
||||
"aloha_real": ["act_aloha_real"],
|
||||
}
|
||||
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
|
||||
45
lerobot/common/constants.py
Normal file
45
lerobot/common/constants.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# keys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import HF_HOME
|
||||
|
||||
OBS_ENV = "observation.environment_state"
|
||||
OBS_ROBOT = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
|
||||
# files & directories
|
||||
CHECKPOINTS_DIR = "checkpoints"
|
||||
LAST_CHECKPOINT_LINK = "last"
|
||||
PRETRAINED_MODEL_DIR = "pretrained_model"
|
||||
TRAINING_STATE_DIR = "training_state"
|
||||
RNG_STATE = "rng_state.safetensors"
|
||||
TRAINING_STEP = "training_step.json"
|
||||
OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
# cache dir
|
||||
default_cache_path = Path(HF_HOME) / "lerobot"
|
||||
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
|
||||
)
|
||||
68
lerobot/common/datasets/backward_compatibility.py
Normal file
68
lerobot/common/datasets/backward_compatibility.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import packaging.version
|
||||
|
||||
V2_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
|
||||
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
|
||||
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
|
||||
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
|
||||
sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
V21_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
|
||||
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
||||
```
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
FUTURE_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is only available in {version} format.
|
||||
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
|
||||
"""
|
||||
|
||||
|
||||
class CompatibilityError(Exception): ...
|
||||
|
||||
|
||||
class BackwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ForwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
27
lerobot/common/datasets/card_template.md
Normal file
27
lerobot/common/datasets/card_template.md
Normal file
@@ -0,0 +1,27 @@
|
||||
---
|
||||
# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1
|
||||
# Doc / guide: https://huggingface.co/docs/hub/datasets-cards
|
||||
{{ card_data }}
|
||||
---
|
||||
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
## Dataset Description
|
||||
|
||||
{{ dataset_description | default("", true) }}
|
||||
|
||||
- **Homepage:** {{ url | default("[More Information Needed]", true)}}
|
||||
- **Paper:** {{ paper | default("[More Information Needed]", true)}}
|
||||
- **License:** {{ license | default("[More Information Needed]", true)}}
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
{{ dataset_structure | default("[More Information Needed]", true)}}
|
||||
|
||||
## Citation
|
||||
|
||||
**BibTeX:**
|
||||
|
||||
```bibtex
|
||||
{{ citation_bibtex | default("[More Information Needed]", true)}}
|
||||
```
|
||||
@@ -13,197 +13,164 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
import numpy as np
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Image
|
||||
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
from lerobot.common.datasets.utils import load_image_as_numpy
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
||||
def estimate_num_samples(
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
) -> int:
|
||||
"""Heuristic to estimate the number of samples based on dataset size.
|
||||
The power controls the sample growth relative to dataset size.
|
||||
Lower the power for less number of samples.
|
||||
|
||||
Note: We assume the images are in channel first format
|
||||
For default arguments, we have:
|
||||
- from 1 to ~500, num_samples=100
|
||||
- at 1000, num_samples=177
|
||||
- at 2000, num_samples=299
|
||||
- at 5000, num_samples=594
|
||||
- at 10000, num_samples=1000
|
||||
- at 20000, num_samples=1681
|
||||
"""
|
||||
if dataset_len < min_num_samples:
|
||||
min_num_samples = dataset_len
|
||||
return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
stats_patterns = {}
|
||||
for key, feats_type in dataset.features.items():
|
||||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
def sample_indices(data_len: int) -> list[int]:
|
||||
num_samples = estimate_num_samples(data_len)
|
||||
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
|
||||
|
||||
if isinstance(feats_type, (VideoFrame, Image)):
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
|
||||
_, height, width = img.shape
|
||||
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
elif batch[key].ndim == 2:
|
||||
stats_patterns[key] = "b c -> c "
|
||||
elif batch[key].ndim == 1:
|
||||
stats_patterns[key] = "b -> 1"
|
||||
if max(width, height) < max_size_threshold:
|
||||
# no downsampling needed
|
||||
return img
|
||||
|
||||
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
|
||||
return img[:, ::downsample_factor, ::downsample_factor]
|
||||
|
||||
|
||||
def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
sampled_indices = sample_indices(len(image_paths))
|
||||
|
||||
images = None
|
||||
for i, idx in enumerate(sampled_indices):
|
||||
path = image_paths[idx]
|
||||
# we load as uint8 to reduce memory usage
|
||||
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
|
||||
img = auto_downsample_height_width(img)
|
||||
|
||||
if images is None:
|
||||
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
||||
|
||||
images[i] = img
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||
return {
|
||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||
"mean": np.mean(array, axis=axis, keepdims=keepdims),
|
||||
"std": np.std(array, axis=axis, keepdims=keepdims),
|
||||
"count": np.array([len(array)]),
|
||||
}
|
||||
|
||||
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue # HACK: we should receive np.arrays of strings
|
||||
elif features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data) # data is a list of image paths
|
||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||
keepdims = True
|
||||
else:
|
||||
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
|
||||
ep_ft_array = data # data is already a np.ndarray
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
|
||||
return stats_patterns
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
# finally, we normalize and remove batch dim for images
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
|
||||
|
||||
def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
|
||||
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
|
||||
# for more info on why we need to set the same number of workers, see `load_from_videos`
|
||||
stats_patterns = get_stats_einops_patterns(dataset, num_workers)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
for key in stats_patterns:
|
||||
mean[key] = torch.tensor(0.0).float()
|
||||
std[key] = torch.tensor(0.0).float()
|
||||
max[key] = torch.tensor(-float("inf")).float()
|
||||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
def create_seeded_dataloader(dataset, batch_size, seed):
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
generator=generator,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||
# surprises when rerunning the sampler.
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
if first_batch is None:
|
||||
first_batch = deepcopy(batch)
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation.
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
# Sanity check to make sure the batches are still in the same order as before.
|
||||
if first_batch_ is None:
|
||||
first_batch_ = deepcopy(batch)
|
||||
for key in stats_patterns:
|
||||
assert torch.equal(first_batch_[key], first_batch[key])
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
for key in stats_patterns:
|
||||
std[key] = torch.sqrt(std[key])
|
||||
|
||||
stats = {}
|
||||
for key in stats_patterns:
|
||||
stats[key] = {
|
||||
"mean": mean[key],
|
||||
"std": std[key],
|
||||
"max": max[key],
|
||||
"min": min[key],
|
||||
}
|
||||
return stats
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
for i in range(len(stats_list)):
|
||||
for fkey in stats_list[i]:
|
||||
for k, v in stats_list[i][fkey].items():
|
||||
if not isinstance(v, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||
)
|
||||
if v.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
if k == "count" and v.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
|
||||
|
||||
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
||||
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregates stats for a single feature."""
|
||||
means = np.stack([s["mean"] for s in stats_ft_list])
|
||||
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
||||
counts = np.stack([s["count"] for s in stats_ft_list])
|
||||
total_count = counts.sum(axis=0)
|
||||
|
||||
The final stats will have the union of all data keys from each of the datasets.
|
||||
# Prepare weighted mean by matching number of dimensions
|
||||
while counts.ndim < means.ndim:
|
||||
counts = np.expand_dims(counts, axis=-1)
|
||||
|
||||
The final stats will have the union of all data keys from each of the datasets. For instance:
|
||||
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||
# Compute the weighted mean
|
||||
weighted_means = means * counts
|
||||
total_mean = weighted_means.sum(axis=0) / total_count
|
||||
|
||||
# Compute the variance using the parallel algorithm
|
||||
delta_means = means - total_mean
|
||||
weighted_variances = (variances + delta_means**2) * counts
|
||||
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||
|
||||
return {
|
||||
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||
"mean": total_mean,
|
||||
"std": np.sqrt(total_variance),
|
||||
"count": total_count,
|
||||
}
|
||||
|
||||
|
||||
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
The final stats will have the union of all data keys from each of the stats dicts.
|
||||
|
||||
For instance:
|
||||
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
||||
- new_mean = (mean of all data)
|
||||
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||
- new_mean = (mean of all data, weighted by counts)
|
||||
- new_std = (std of all data)
|
||||
"""
|
||||
data_keys = set()
|
||||
for dataset in ls_datasets:
|
||||
data_keys.update(dataset.stats.keys())
|
||||
stats = {k: {} for k in data_keys}
|
||||
for data_key in data_keys:
|
||||
for stat_key in ["min", "max"]:
|
||||
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
||||
stats[data_key][stat_key] = einops.reduce(
|
||||
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
|
||||
"n ... -> ...",
|
||||
stat_key,
|
||||
)
|
||||
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
|
||||
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
||||
# dataset, then divide by total_samples to get the overall "mean".
|
||||
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["mean"] = sum(
|
||||
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.stats
|
||||
)
|
||||
# The derivation for standard deviation is a little more involved but is much in the same spirit as
|
||||
# the computation of the mean.
|
||||
# Given two sets of data where the statistics are known:
|
||||
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
|
||||
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
|
||||
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["std"] = torch.sqrt(
|
||||
sum(
|
||||
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
|
||||
* (d.num_samples / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.stats
|
||||
)
|
||||
)
|
||||
return stats
|
||||
|
||||
_assert_type_and_shape(stats_list)
|
||||
|
||||
data_keys = {key for stats in stats_list for key in stats}
|
||||
aggregated_stats = {key: {} for key in data_keys}
|
||||
|
||||
for key in data_keys:
|
||||
stats_with_key = [stats[key] for stats in stats_list if key in stats]
|
||||
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
|
||||
|
||||
return aggregated_stats
|
||||
|
||||
@@ -13,105 +13,104 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.transforms import ImageTransforms
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
|
||||
}
|
||||
|
||||
|
||||
def resolve_delta_timestamps(cfg):
|
||||
"""Resolves delta_timestamps config key (in-place) by using `eval`.
|
||||
def resolve_delta_timestamps(
|
||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
||||
) -> dict[str, list] | None:
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
|
||||
|
||||
Doesn't do anything if delta_timestamps is not specified or has already been resolve (as evidenced by
|
||||
the data type of its values).
|
||||
"""
|
||||
delta_timestamps = cfg.training.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
# TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
|
||||
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
|
||||
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""
|
||||
Args:
|
||||
cfg: A Hydra config as per the LeRobot config scheme.
|
||||
split: Select the data subset used to create an instance of LeRobotDataset.
|
||||
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
|
||||
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
|
||||
slicer in the hugging face datasets:
|
||||
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
|
||||
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
|
||||
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
|
||||
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
||||
delta_timestamps against.
|
||||
|
||||
Returns:
|
||||
The LeRobotDataset.
|
||||
dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
|
||||
{
|
||||
"observation.state": [-0.04, -0.02, 0]
|
||||
"observation.action": [-0.02, 0, 0.02]
|
||||
}
|
||||
returns `None` if the the resulting dict is empty.
|
||||
"""
|
||||
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
|
||||
raise ValueError(
|
||||
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
|
||||
"strings to load multiple datasets."
|
||||
delta_timestamps = {}
|
||||
for key in ds_meta.features:
|
||||
if key == "next.reward" and cfg.reward_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == "action" and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
delta_timestamps = None
|
||||
|
||||
return delta_timestamps
|
||||
|
||||
|
||||
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset | MultiLeRobotDataset
|
||||
"""
|
||||
image_transforms = (
|
||||
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
|
||||
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
|
||||
if cfg.env.name != "dora":
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
|
||||
else:
|
||||
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
|
||||
|
||||
for dataset_repo_id in dataset_repo_ids:
|
||||
if cfg.env.name not in dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
|
||||
f"environment ({cfg.env.name=})."
|
||||
)
|
||||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
image_transforms = None
|
||||
if cfg.training.image_transforms.enable:
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
image_transforms = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
)
|
||||
else:
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
)
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
# dataset = MultiLeRobotDataset(
|
||||
# cfg.dataset.repo_id,
|
||||
# # TODO(aliberts): add proper support for multi dataset
|
||||
# # delta_timestamps=delta_timestamps,
|
||||
# image_transforms=image_transforms,
|
||||
# video_backend=cfg.dataset.video_backend,
|
||||
# )
|
||||
# logging.info(
|
||||
# "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
# f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||
# )
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
178
lerobot/common/datasets/image_writer.py
Normal file
178
lerobot/common/datasets/image_writer.py
Normal file
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
dataset = kwargs.get("dataset")
|
||||
image_writer = getattr(dataset, "image_writer", None) if dataset else None
|
||||
if image_writer is not None:
|
||||
print("Waiting for image writer to terminate...")
|
||||
image_writer.stop()
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
elif image_array.shape[-1] != 3:
|
||||
raise NotImplementedError(
|
||||
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
|
||||
)
|
||||
|
||||
if image_array.dtype != np.uint8:
|
||||
if range_check:
|
||||
max_ = image_array.max().item()
|
||||
min_ = image_array.min().item()
|
||||
if max_ > 1.0 or min_ < 0.0:
|
||||
raise ValueError(
|
||||
"The image data type is float, which requires values in the range [0.0, 1.0]. "
|
||||
f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
|
||||
"provide a uint8 image with values in the range [0, 255]."
|
||||
)
|
||||
|
||||
image_array = (image_array * 255).astype(np.uint8)
|
||||
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
img = image_array_to_pil_image(image)
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath)
|
||||
except Exception as e:
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
|
||||
|
||||
def worker_thread_loop(task_queue: queue.Queue):
|
||||
while True:
|
||||
item = task_queue.get()
|
||||
if item is None:
|
||||
task_queue.task_done()
|
||||
break
|
||||
image_array, fpath = item
|
||||
write_image(image_array, fpath)
|
||||
task_queue.task_done()
|
||||
|
||||
|
||||
def worker_process(task_queue: queue.Queue, num_threads: int):
|
||||
threads = []
|
||||
for _ in range(num_threads):
|
||||
t = threading.Thread(target=worker_thread_loop, args=(task_queue,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
threads.append(t)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
|
||||
class AsyncImageWriter:
|
||||
"""
|
||||
This class abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
When `num_processes=0`, it creates a threads pool of size `num_threads`.
|
||||
When `num_processes>0`, it creates processes pool of size `num_processes`, where each subprocess starts
|
||||
their own threads pool of size `num_threads`.
|
||||
|
||||
The optimal number of processes and threads depends on your computer capabilities.
|
||||
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
|
||||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
|
||||
def __init__(self, num_processes: int = 0, num_threads: int = 1):
|
||||
self.num_processes = num_processes
|
||||
self.num_threads = num_threads
|
||||
self.queue = None
|
||||
self.threads = []
|
||||
self.processes = []
|
||||
self._stopped = False
|
||||
|
||||
if num_threads <= 0 and num_processes <= 0:
|
||||
raise ValueError("Number of threads and processes must be greater than zero.")
|
||||
|
||||
if self.num_processes == 0:
|
||||
# Use threading
|
||||
self.queue = queue.Queue()
|
||||
for _ in range(self.num_threads):
|
||||
t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
self.threads.append(t)
|
||||
else:
|
||||
# Use multiprocessing
|
||||
self.queue = multiprocessing.JoinableQueue()
|
||||
for _ in range(self.num_processes):
|
||||
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
image = image.cpu().numpy()
|
||||
self.queue.put((image, fpath))
|
||||
|
||||
def wait_until_done(self):
|
||||
self.queue.join()
|
||||
|
||||
def stop(self):
|
||||
if self._stopped:
|
||||
return
|
||||
|
||||
if self.num_processes == 0:
|
||||
for _ in self.threads:
|
||||
self.queue.put(None)
|
||||
for t in self.threads:
|
||||
t.join()
|
||||
else:
|
||||
num_nones = self.num_processes * self.num_threads
|
||||
for _ in range(num_nones):
|
||||
self.queue.put(None)
|
||||
for p in self.processes:
|
||||
p.join()
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
self.queue.close()
|
||||
self.queue.join_thread()
|
||||
|
||||
self._stopped = True
|
||||
File diff suppressed because it is too large
Load Diff
@@ -187,7 +187,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_samples > 0:
|
||||
if self.num_frames > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
@@ -227,11 +227,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
def num_frames(self) -> int:
|
||||
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
return self.num_frames
|
||||
|
||||
def _item_to_tensors(self, item: dict) -> dict:
|
||||
item_ = {}
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
## Using / Updating `CODEBASE_VERSION` (for maintainers)
|
||||
|
||||
Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
|
||||
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
|
||||
lerobot/common/datasets/lerobot_dataset.py) variable.
|
||||
|
||||
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
|
||||
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
|
||||
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
|
||||
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
|
||||
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
|
||||
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
|
||||
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
|
||||
- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
|
||||
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
|
||||
|
||||
Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
|
||||
`info.json` metadata.
|
||||
|
||||
### Uploading a new dataset
|
||||
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
|
||||
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
|
||||
dataset with the current `CODEBASE_VERSION`.
|
||||
|
||||
### Updating an existing dataset
|
||||
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
|
||||
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
|
||||
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
|
||||
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
|
||||
codebase won't be affected by your change and backward compatibility is maintained.
|
||||
|
||||
However, you will need to update the version of ALL the other datasets so that they have the new
|
||||
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
|
||||
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
|
||||
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
|
||||
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
|
||||
api = HfApi()
|
||||
|
||||
for repo_id in available_datasets:
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if CODEBASE_VERSION in branches:
|
||||
print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
|
||||
continue
|
||||
else:
|
||||
# Now create a branch named after the new version by branching out from "main"
|
||||
# which is expected to be the preceding version
|
||||
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
|
||||
print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
|
||||
```
|
||||
@@ -53,7 +53,7 @@ def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compre
|
||||
# rechunk recompress
|
||||
group.move(name, tmp_key)
|
||||
old_arr = group[tmp_key]
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
|
||||
source=old_arr,
|
||||
dest=group,
|
||||
name=name,
|
||||
@@ -192,7 +192,7 @@ class ReplayBuffer:
|
||||
else:
|
||||
root = zarr.group(store=store)
|
||||
# copy without recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
||||
)
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
@@ -205,7 +205,7 @@ class ReplayBuffer:
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||
source=src_store,
|
||||
dest=store,
|
||||
source_path=this_path,
|
||||
@@ -214,7 +214,7 @@ class ReplayBuffer:
|
||||
)
|
||||
else:
|
||||
# copy with recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
|
||||
source=value,
|
||||
dest=data_group,
|
||||
name=key,
|
||||
@@ -275,7 +275,7 @@ class ReplayBuffer:
|
||||
compressors = {}
|
||||
if self.backend == "zarr":
|
||||
# recompression free copy
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||
source=self.root.store,
|
||||
dest=store,
|
||||
source_path="/meta",
|
||||
@@ -297,7 +297,7 @@ class ReplayBuffer:
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||
source=self.root.store,
|
||||
dest=store,
|
||||
source_path=this_path,
|
||||
|
||||
@@ -60,8 +60,8 @@ AVAILABLE_RAW_REPO_IDS = {
|
||||
"lerobot-raw/aloha_static_vinh_cup_left_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_vinh_cup_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_ziploc_slide_raw": "aloha_hdf5",
|
||||
"lerobot-raw/pusht_raw": "pusht_zarr",
|
||||
"lerobot-raw/umi_cup_in_the_wild_raw": "umi_zarr",
|
||||
"lerobot-raw/pusht_raw": "pusht_zarr",
|
||||
"lerobot-raw/unitreeh1_fold_clothes_raw": "aloha_hdf5",
|
||||
"lerobot-raw/unitreeh1_rearrange_objects_raw": "aloha_hdf5",
|
||||
"lerobot-raw/unitreeh1_two_robot_greeting_raw": "aloha_hdf5",
|
||||
@@ -70,6 +70,74 @@ AVAILABLE_RAW_REPO_IDS = {
|
||||
"lerobot-raw/xarm_lift_medium_replay_raw": "xarm_pkl",
|
||||
"lerobot-raw/xarm_push_medium_raw": "xarm_pkl",
|
||||
"lerobot-raw/xarm_push_medium_replay_raw": "xarm_pkl",
|
||||
"lerobot-raw/fractal20220817_data_raw": "openx_rlds.fractal20220817_data",
|
||||
"lerobot-raw/kuka_raw": "openx_rlds.kuka",
|
||||
"lerobot-raw/bridge_openx_raw": "openx_rlds.bridge_openx",
|
||||
"lerobot-raw/taco_play_raw": "openx_rlds.taco_play",
|
||||
"lerobot-raw/jaco_play_raw": "openx_rlds.jaco_play",
|
||||
"lerobot-raw/berkeley_cable_routing_raw": "openx_rlds.berkeley_cable_routing",
|
||||
"lerobot-raw/roboturk_raw": "openx_rlds.roboturk",
|
||||
"lerobot-raw/nyu_door_opening_surprising_effectiveness_raw": "openx_rlds.nyu_door_opening_surprising_effectiveness",
|
||||
"lerobot-raw/viola_raw": "openx_rlds.viola",
|
||||
"lerobot-raw/berkeley_autolab_ur5_raw": "openx_rlds.berkeley_autolab_ur5",
|
||||
"lerobot-raw/toto_raw": "openx_rlds.toto",
|
||||
"lerobot-raw/language_table_raw": "openx_rlds.language_table",
|
||||
"lerobot-raw/columbia_cairlab_pusht_real_raw": "openx_rlds.columbia_cairlab_pusht_real",
|
||||
"lerobot-raw/stanford_kuka_multimodal_dataset_raw": "openx_rlds.stanford_kuka_multimodal_dataset",
|
||||
"lerobot-raw/nyu_rot_dataset_raw": "openx_rlds.nyu_rot_dataset",
|
||||
"lerobot-raw/io_ai_tech_raw": "openx_rlds.io_ai_tech",
|
||||
"lerobot-raw/stanford_hydra_dataset_raw": "openx_rlds.stanford_hydra_dataset",
|
||||
"lerobot-raw/austin_buds_dataset_raw": "openx_rlds.austin_buds_dataset",
|
||||
"lerobot-raw/nyu_franka_play_dataset_raw": "openx_rlds.nyu_franka_play_dataset",
|
||||
"lerobot-raw/maniskill_dataset_raw": "openx_rlds.maniskill_dataset",
|
||||
"lerobot-raw/furniture_bench_dataset_raw": "openx_rlds.furniture_bench_dataset",
|
||||
"lerobot-raw/cmu_franka_exploration_dataset_raw": "openx_rlds.cmu_franka_exploration_dataset",
|
||||
"lerobot-raw/ucsd_kitchen_dataset_raw": "openx_rlds.ucsd_kitchen_dataset",
|
||||
"lerobot-raw/ucsd_pick_and_place_dataset_raw": "openx_rlds.ucsd_pick_and_place_dataset",
|
||||
"lerobot-raw/spoc_raw": "openx_rlds.spoc",
|
||||
"lerobot-raw/austin_sailor_dataset_raw": "openx_rlds.austin_sailor_dataset",
|
||||
"lerobot-raw/austin_sirius_dataset_raw": "openx_rlds.austin_sirius_dataset",
|
||||
"lerobot-raw/bc_z_raw": "openx_rlds.bc_z",
|
||||
"lerobot-raw/utokyo_pr2_opening_fridge_raw": "openx_rlds.utokyo_pr2_opening_fridge",
|
||||
"lerobot-raw/utokyo_pr2_tabletop_manipulation_raw": "openx_rlds.utokyo_pr2_tabletop_manipulation",
|
||||
"lerobot-raw/utokyo_xarm_pick_and_place_raw": "openx_rlds.utokyo_xarm_pick_and_place",
|
||||
"lerobot-raw/utokyo_xarm_bimanual_raw": "openx_rlds.utokyo_xarm_bimanual",
|
||||
"lerobot-raw/utokyo_saytap_raw": "openx_rlds.utokyo_saytap",
|
||||
"lerobot-raw/robo_net_raw": "openx_rlds.robo_net",
|
||||
"lerobot-raw/robo_set_raw": "openx_rlds.robo_set",
|
||||
"lerobot-raw/berkeley_mvp_raw": "openx_rlds.berkeley_mvp",
|
||||
"lerobot-raw/berkeley_rpt_raw": "openx_rlds.berkeley_rpt",
|
||||
"lerobot-raw/kaist_nonprehensile_raw": "openx_rlds.kaist_nonprehensile",
|
||||
"lerobot-raw/stanford_mask_vit_raw": "openx_rlds.stanford_mask_vit",
|
||||
"lerobot-raw/tokyo_u_lsmo_raw": "openx_rlds.tokyo_u_lsmo",
|
||||
"lerobot-raw/dlr_sara_pour_raw": "openx_rlds.dlr_sara_pour",
|
||||
"lerobot-raw/dlr_sara_grid_clamp_raw": "openx_rlds.dlr_sara_grid_clamp",
|
||||
"lerobot-raw/dlr_edan_shared_control_raw": "openx_rlds.dlr_edan_shared_control",
|
||||
"lerobot-raw/asu_table_top_raw": "openx_rlds.asu_table_top",
|
||||
"lerobot-raw/stanford_robocook_raw": "openx_rlds.stanford_robocook",
|
||||
"lerobot-raw/imperialcollege_sawyer_wrist_cam_raw": "openx_rlds.imperialcollege_sawyer_wrist_cam",
|
||||
"lerobot-raw/iamlab_cmu_pickup_insert_raw": "openx_rlds.iamlab_cmu_pickup_insert",
|
||||
"lerobot-raw/uiuc_d3field_raw": "openx_rlds.uiuc_d3field",
|
||||
"lerobot-raw/utaustin_mutex_raw": "openx_rlds.utaustin_mutex",
|
||||
"lerobot-raw/berkeley_fanuc_manipulation_raw": "openx_rlds.berkeley_fanuc_manipulation",
|
||||
"lerobot-raw/cmu_playing_with_food_raw": "openx_rlds.cmu_playing_with_food",
|
||||
"lerobot-raw/cmu_play_fusion_raw": "openx_rlds.cmu_play_fusion",
|
||||
"lerobot-raw/cmu_stretch_raw": "openx_rlds.cmu_stretch",
|
||||
"lerobot-raw/berkeley_gnm_recon_raw": "openx_rlds.berkeley_gnm_recon",
|
||||
"lerobot-raw/berkeley_gnm_cory_hall_raw": "openx_rlds.berkeley_gnm_cory_hall",
|
||||
"lerobot-raw/berkeley_gnm_sac_son_raw": "openx_rlds.berkeley_gnm_sac_son",
|
||||
"lerobot-raw/droid_raw": "openx_rlds.droid",
|
||||
"lerobot-raw/droid_100_raw": "openx_rlds.droid100",
|
||||
"lerobot-raw/fmb_raw": "openx_rlds.fmb",
|
||||
"lerobot-raw/dobbe_raw": "openx_rlds.dobbe",
|
||||
"lerobot-raw/usc_cloth_sim_raw": "openx_rlds.usc_cloth_sim",
|
||||
"lerobot-raw/plex_robosuite_raw": "openx_rlds.plex_robosuite",
|
||||
"lerobot-raw/conq_hose_manipulation_raw": "openx_rlds.conq_hose_manipulation",
|
||||
"lerobot-raw/vima_raw": "openx_rlds.vima",
|
||||
"lerobot-raw/robot_vqa_raw": "openx_rlds.robot_vqa",
|
||||
"lerobot-raw/mimic_play_raw": "openx_rlds.mimic_play",
|
||||
"lerobot-raw/tidybot_raw": "openx_rlds.tidybot",
|
||||
"lerobot-raw/eth_agent_affordances_raw": "openx_rlds.eth_agent_affordances",
|
||||
}
|
||||
|
||||
|
||||
@@ -84,7 +152,7 @@ def download_raw(raw_dir: Path, repo_id: str):
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
# Send warning if raw_dir isn't well formated
|
||||
# Send warning if raw_dir isn't well formatted
|
||||
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that
|
||||
@@ -94,9 +162,9 @@ def download_raw(raw_dir: Path, repo_id: str):
|
||||
)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)
|
||||
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
||||
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
logging.info("Finish downloading from huggingface.co/%s for %s", user_id, dataset_id)
|
||||
|
||||
|
||||
def download_all_raw_datasets(data_dir: Path | None = None):
|
||||
@@ -110,7 +178,7 @@ def download_all_raw_datasets(data_dir: Path | None = None):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=f"""A script to download raw datasets from Hugging Face hub to a local directory. Here is a
|
||||
non exhaustive list of available repositories to use in `--repo-id`: {AVAILABLE_RAW_REPO_IDS}""",
|
||||
non exhaustive list of available repositories to use in `--repo-id`: {list(AVAILABLE_RAW_REPO_IDS.keys())}""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
||||
@@ -30,12 +30,12 @@ from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
@@ -72,7 +72,7 @@ def check_format(raw_dir) -> bool:
|
||||
assert data[f"/observations/images/{camera}"].ndim == 2
|
||||
else:
|
||||
assert data[f"/observations/images/{camera}"].ndim == 4
|
||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||
_, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ def load_from_raw(
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
velocity = None
|
||||
if "/observations/qvel" in ep:
|
||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||
if "/observations/effort" in ep:
|
||||
|
||||
@@ -24,8 +24,11 @@ from datasets import Dataset, Features, Image, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
)
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
|
||||
@@ -93,6 +96,7 @@ def from_raw_to_lerobot_format(
|
||||
if fps is None:
|
||||
fps = 30
|
||||
|
||||
# TODO(Steven): Is this meant to call cam_png_format.load_from_raw?
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
|
||||
@@ -26,8 +26,8 @@ import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
@@ -42,7 +42,9 @@ def check_format(raw_dir) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
def load_from_raw(
|
||||
raw_dir: Path, videos_dir: Path, fps: int, _video: bool, _episodes: list[int] | None = None
|
||||
):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
@@ -68,11 +70,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
modality_df,
|
||||
on="timestamp_utc",
|
||||
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
|
||||
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||
# matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
|
||||
# are too far appart.
|
||||
# are too far apart.
|
||||
direction="nearest",
|
||||
tolerance=pd.Timedelta(f"{1/fps} seconds"),
|
||||
tolerance=pd.Timedelta(f"{1 / fps} seconds"),
|
||||
)
|
||||
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
|
||||
df = df[df["episode_index"] != -1]
|
||||
@@ -126,7 +128,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
videos_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir.symlink_to((raw_dir / "videos").absolute())
|
||||
|
||||
# sanity check the video paths are well formated
|
||||
# sanity check the video paths are well formatted
|
||||
for key in df:
|
||||
if "observation.images." not in key:
|
||||
continue
|
||||
@@ -143,7 +145,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
|
||||
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
|
||||
|
||||
# sanity check the video path is well formated
|
||||
# sanity check the video path is well formatted
|
||||
video_path = videos_dir.parent / data_dict[key][0]["path"]
|
||||
if not video_path.exists():
|
||||
raise ValueError(f"Video file not found in {video_path}")
|
||||
|
||||
312
lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py
Normal file
312
lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py
Normal file
@@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
For all datasets in the RLDS format.
|
||||
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
|
||||
|
||||
NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
|
||||
|
||||
Example:
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir /path/to/data/bridge_dataset/1.0.0/ \
|
||||
--repo-id your_hub/sampled_bridge_data_v2 \
|
||||
--raw-format rlds \
|
||||
--episodes 3 4 5 8 9
|
||||
|
||||
Exact dataset fps defined in openx/config.py, obtained from:
|
||||
https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
np.set_printoptions(precision=2)
|
||||
|
||||
|
||||
def tf_to_torch(data):
|
||||
return torch.from_numpy(data.numpy())
|
||||
|
||||
|
||||
def tf_img_convert(img):
|
||||
if img.dtype == tf.string:
|
||||
img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
|
||||
elif img.dtype != tf.uint8:
|
||||
raise ValueError(f"Unsupported image dtype: found with dtype {img.dtype}")
|
||||
return img.numpy()
|
||||
|
||||
|
||||
def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
|
||||
"""
|
||||
In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
|
||||
entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
|
||||
trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.
|
||||
|
||||
NOTE: adapted from DLimp library https://github.com/kvablack/dlimp/
|
||||
"""
|
||||
steps = traj.pop("steps")
|
||||
|
||||
traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]
|
||||
|
||||
# broadcast metadata to the length of the trajectory
|
||||
metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)
|
||||
|
||||
# put steps back in
|
||||
assert "traj_metadata" not in steps
|
||||
traj = {**steps, "traj_metadata": metadata}
|
||||
|
||||
assert "_len" not in traj
|
||||
assert "_traj_index" not in traj
|
||||
assert "_frame_index" not in traj
|
||||
traj["_len"] = tf.repeat(traj_len, traj_len)
|
||||
traj["_traj_index"] = tf.repeat(i, traj_len)
|
||||
traj["_frame_index"] = tf.range(traj_len)
|
||||
|
||||
return traj
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
raw_dir (Path): _description_
|
||||
videos_dir (Path): _description_
|
||||
fps (int): _description_
|
||||
video (bool): _description_
|
||||
episodes (list[int] | None, optional): _description_. Defaults to None.
|
||||
"""
|
||||
ds_builder = tfds.builder_from_directory(str(raw_dir))
|
||||
dataset = ds_builder.as_dataset(
|
||||
split="all",
|
||||
decoders={"steps": tfds.decode.SkipDecoding()},
|
||||
)
|
||||
|
||||
dataset_info = ds_builder.info
|
||||
print("dataset_info: ", dataset_info)
|
||||
|
||||
ds_length = len(dataset)
|
||||
dataset = dataset.take(ds_length)
|
||||
# "flatten" the dataset as such we can apply trajectory level map() easily
|
||||
# each [obs][key] has a shape of (frame_size, ...)
|
||||
dataset = dataset.enumerate().map(_broadcast_metadata_rlds)
|
||||
|
||||
# we will apply the standardization transform if the dataset_name is provided
|
||||
# if the dataset name is not provided and the goal is to convert any rlds formatted dataset
|
||||
# search for 'image' keys in the observations
|
||||
image_keys = []
|
||||
state_keys = []
|
||||
observation_info = dataset_info.features["steps"]["observation"]
|
||||
for key in observation_info:
|
||||
# check whether the key is for an image or a vector observation
|
||||
if len(observation_info[key].shape) == 3:
|
||||
# only adding uint8 images discards depth images
|
||||
if observation_info[key].dtype == tf.uint8:
|
||||
image_keys.append(key)
|
||||
else:
|
||||
state_keys.append(key)
|
||||
|
||||
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
|
||||
|
||||
print(" - image_keys: ", image_keys)
|
||||
print(" - lang_key: ", lang_key)
|
||||
|
||||
it = iter(dataset)
|
||||
|
||||
ep_dicts = []
|
||||
# Init temp path to save ep_dicts in case of crash
|
||||
tmp_ep_dicts_dir = videos_dir.parent.joinpath("ep_dicts")
|
||||
tmp_ep_dicts_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# check if ep_dicts have already been saved in /tmp
|
||||
starting_ep_idx = 0
|
||||
saved_ep_dicts = [ep.__str__() for ep in tmp_ep_dicts_dir.iterdir()]
|
||||
if len(saved_ep_dicts) > 0:
|
||||
saved_ep_dicts.sort()
|
||||
# get last ep_idx number
|
||||
starting_ep_idx = int(saved_ep_dicts[-1][-13:-3]) + 1
|
||||
for i in range(starting_ep_idx):
|
||||
episode = next(it)
|
||||
ep_dicts.append(torch.load(saved_ep_dicts[i]))
|
||||
|
||||
# if we user specified episodes, skip the ones not in the list
|
||||
if episodes is not None:
|
||||
if ds_length == 0:
|
||||
raise ValueError("No episodes found.")
|
||||
# convert episodes index to sorted list
|
||||
episodes = sorted(episodes)
|
||||
|
||||
for ep_idx in tqdm.tqdm(range(starting_ep_idx, ds_length)):
|
||||
episode = next(it)
|
||||
|
||||
# if user specified episodes, skip the ones not in the list
|
||||
if episodes is not None:
|
||||
if len(episodes) == 0:
|
||||
break
|
||||
if ep_idx == episodes[0]:
|
||||
# process this episode
|
||||
print(" selecting episode idx: ", ep_idx)
|
||||
episodes.pop(0)
|
||||
else:
|
||||
continue # skip
|
||||
|
||||
num_frames = episode["action"].shape[0]
|
||||
|
||||
ep_dict = {}
|
||||
for key in state_keys:
|
||||
ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])
|
||||
|
||||
ep_dict["action"] = tf_to_torch(episode["action"])
|
||||
ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
|
||||
ep_dict["next.done"] = tf_to_torch(episode["is_last"])
|
||||
ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
|
||||
ep_dict["is_first"] = tf_to_torch(episode["is_first"])
|
||||
ep_dict["discount"] = tf_to_torch(episode["discount"])
|
||||
|
||||
# If lang_key is present, convert the entire tensor at once
|
||||
if lang_key is not None:
|
||||
ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]
|
||||
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
|
||||
image_array_dict = {key: [] for key in image_keys}
|
||||
|
||||
for im_key in image_keys:
|
||||
imgs = episode["observation"][im_key]
|
||||
image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
|
||||
|
||||
# loop through all cameras
|
||||
for im_key in image_keys:
|
||||
img_key = f"observation.images.{im_key}"
|
||||
imgs_array = image_array_dict[im_key]
|
||||
imgs_array = np.array(imgs_array)
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
path_ep_dict = tmp_ep_dicts_dir.joinpath(
|
||||
"ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
|
||||
)
|
||||
torch.save(ep_dict, path_ep_dict)
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features = {}
|
||||
|
||||
for key in data_dict:
|
||||
# check if vector state obs
|
||||
if key.startswith("observation.") and "observation.images." not in key:
|
||||
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
|
||||
# check if image obs
|
||||
elif "observation.images." in key:
|
||||
if video:
|
||||
features[key] = VideoFrame()
|
||||
else:
|
||||
features[key] = Image()
|
||||
|
||||
if "language_instruction" in data_dict:
|
||||
features["language_instruction"] = Value(dtype="string", id=None)
|
||||
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
|
||||
features["is_terminal"] = Value(dtype="bool", id=None)
|
||||
features["is_first"] = Value(dtype="bool", id=None)
|
||||
features["discount"] = Value(dtype="float32", id=None)
|
||||
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
features["timestamp"] = Value(dtype="float32", id=None)
|
||||
features["next.reward"] = Value(dtype="float32", id=None)
|
||||
features["next.done"] = Value(dtype="bool", id=None)
|
||||
features["index"] = Value(dtype="int64", id=None)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
||||
@@ -27,12 +27,12 @@ from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
@@ -28,12 +28,12 @@ from PIL import Image as PILImage
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
import inspect
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import numpy
|
||||
import PIL
|
||||
import torch
|
||||
@@ -53,7 +55,7 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
|
||||
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
_ = [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
|
||||
@@ -72,3 +74,57 @@ def check_repo_id(repo_id: str) -> None:
|
||||
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
|
||||
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
|
||||
)
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
Parameters:
|
||||
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
|
||||
|
||||
Returns:
|
||||
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
|
||||
- "from": A tensor containing the starting index of each episode.
|
||||
- "to": A tensor containing the ending index of each episode.
|
||||
"""
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
current_episode = None
|
||||
# The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||
# For instance, the following is a valid episode_index:
|
||||
# [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||
#
|
||||
# Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||
# ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||
# {
|
||||
# "from": [0, 3, 7],
|
||||
# "to": [3, 7, 12]
|
||||
# }
|
||||
if len(hf_dataset) == 0:
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([]),
|
||||
"to": torch.tensor([]),
|
||||
}
|
||||
return episode_data_index
|
||||
idx = None
|
||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||
if episode_idx != current_episode:
|
||||
# We encountered a new episode, so we append its starting location to the "from" list
|
||||
episode_data_index["from"].append(idx)
|
||||
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
|
||||
if current_episode is not None:
|
||||
episode_data_index["to"].append(idx)
|
||||
# Let's keep track of the current episode index
|
||||
current_episode = episode_idx
|
||||
else:
|
||||
# We are still in the same episode, so there is nothing for us to do here
|
||||
pass
|
||||
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
|
||||
episode_data_index["to"].append(idx + 1)
|
||||
|
||||
for k in ["from", "to"]:
|
||||
episode_data_index[k] = torch.tensor(episode_data_index[k])
|
||||
|
||||
return episode_data_index
|
||||
|
||||
@@ -27,12 +27,12 @@ from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
from typing import Any, Callable, Dict, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
@@ -22,6 +23,7 @@ from torchvision.transforms.v2 import Transform
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
|
||||
# TODO(Steven): Missing transform() implementation
|
||||
class RandomSubsetApply(Transform):
|
||||
"""Apply a random subset of N transformations from a list of transformations.
|
||||
|
||||
@@ -65,6 +67,8 @@ class RandomSubsetApply(Transform):
|
||||
self.n_subset = n_subset
|
||||
self.random_order = random_order
|
||||
|
||||
self.selected_transforms = None
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
needs_unpacking = len(inputs) > 1
|
||||
|
||||
@@ -72,9 +76,9 @@ class RandomSubsetApply(Transform):
|
||||
if not self.random_order:
|
||||
selected_indices = selected_indices.sort().values
|
||||
|
||||
selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
self.selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
|
||||
for transform in selected_transforms:
|
||||
for transform in self.selected_transforms:
|
||||
outputs = transform(*inputs)
|
||||
inputs = outputs if needs_unpacking else (outputs,)
|
||||
|
||||
@@ -129,69 +133,119 @@ class SharpnessJitter(Transform):
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
def _generate_value(self, left: float, right: float) -> float:
|
||||
return torch.empty(1).uniform_(left, right).item()
|
||||
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
|
||||
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
|
||||
return {"sharpness_factor": sharpness_factor}
|
||||
|
||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
|
||||
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
|
||||
sharpness_factor = params["sharpness_factor"]
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def get_image_transforms(
|
||||
brightness_weight: float = 1.0,
|
||||
brightness_min_max: tuple[float, float] | None = None,
|
||||
contrast_weight: float = 1.0,
|
||||
contrast_min_max: tuple[float, float] | None = None,
|
||||
saturation_weight: float = 1.0,
|
||||
saturation_min_max: tuple[float, float] | None = None,
|
||||
hue_weight: float = 1.0,
|
||||
hue_min_max: tuple[float, float] | None = None,
|
||||
sharpness_weight: float = 1.0,
|
||||
sharpness_min_max: tuple[float, float] | None = None,
|
||||
max_num_transforms: int | None = None,
|
||||
random_order: bool = False,
|
||||
):
|
||||
def check_value(name, weight, min_max):
|
||||
if min_max is not None:
|
||||
if len(min_max) != 2:
|
||||
raise ValueError(
|
||||
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
|
||||
)
|
||||
if weight < 0.0:
|
||||
raise ValueError(
|
||||
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
|
||||
)
|
||||
@dataclass
|
||||
class ImageTransformConfig:
|
||||
"""
|
||||
For each transform, the following parameters are available:
|
||||
weight: This represents the multinomial probability (with no replacement)
|
||||
used for sampling the transform. If the sum of the weights is not 1,
|
||||
they will be normalized.
|
||||
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
|
||||
custom transform defined here.
|
||||
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
|
||||
(following uniform distribution) when it's applied.
|
||||
"""
|
||||
|
||||
check_value("brightness", brightness_weight, brightness_min_max)
|
||||
check_value("contrast", contrast_weight, contrast_min_max)
|
||||
check_value("saturation", saturation_weight, saturation_min_max)
|
||||
check_value("hue", hue_weight, hue_min_max)
|
||||
check_value("sharpness", sharpness_weight, sharpness_min_max)
|
||||
weight: float = 1.0
|
||||
type: str = "Identity"
|
||||
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
weights = []
|
||||
transforms = []
|
||||
if brightness_min_max is not None and brightness_weight > 0.0:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
if contrast_min_max is not None and contrast_weight > 0.0:
|
||||
weights.append(contrast_weight)
|
||||
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
|
||||
if saturation_min_max is not None and saturation_weight > 0.0:
|
||||
weights.append(saturation_weight)
|
||||
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
|
||||
if hue_min_max is not None and hue_weight > 0.0:
|
||||
weights.append(hue_weight)
|
||||
transforms.append(v2.ColorJitter(hue=hue_min_max))
|
||||
if sharpness_min_max is not None and sharpness_weight > 0.0:
|
||||
weights.append(sharpness_weight)
|
||||
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
|
||||
|
||||
n_subset = len(transforms)
|
||||
if max_num_transforms is not None:
|
||||
n_subset = min(n_subset, max_num_transforms)
|
||||
@dataclass
|
||||
class ImageTransformsConfig:
|
||||
"""
|
||||
These transforms are all using standard torchvision.transforms.v2
|
||||
You can find out how these transformations affect images here:
|
||||
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
|
||||
We use a custom RandomSubsetApply container to sample them.
|
||||
"""
|
||||
|
||||
if n_subset == 0:
|
||||
return v2.Identity()
|
||||
# Set this flag to `true` to enable transforms during training
|
||||
enable: bool = False
|
||||
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
|
||||
# It's an integer in the interval [1, number_of_available_transforms].
|
||||
max_num_transforms: int = 3
|
||||
# By default, transforms are applied in Torchvision's suggested order (shown below).
|
||||
# Set this to True to apply them in a random order.
|
||||
random_order: bool = False
|
||||
tfs: dict[str, ImageTransformConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"brightness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"brightness": (0.8, 1.2)},
|
||||
),
|
||||
"contrast": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"contrast": (0.8, 1.2)},
|
||||
),
|
||||
"saturation": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"saturation": (0.5, 1.5)},
|
||||
),
|
||||
"hue": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"hue": (-0.05, 0.05)},
|
||||
),
|
||||
"sharpness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 1.5)},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
else:
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
# TODO(Steven): Missing transform() implementation
|
||||
class ImageTransforms(Transform):
|
||||
"""A class to compose image transforms based on configuration."""
|
||||
|
||||
def __init__(self, cfg: ImageTransformsConfig) -> None:
|
||||
super().__init__()
|
||||
self._cfg = cfg
|
||||
|
||||
self.weights = []
|
||||
self.transforms = {}
|
||||
for tf_name, tf_cfg in cfg.tfs.items():
|
||||
if tf_cfg.weight <= 0.0:
|
||||
continue
|
||||
|
||||
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
|
||||
self.weights.append(tf_cfg.weight)
|
||||
|
||||
n_subset = min(len(self.transforms), cfg.max_num_transforms)
|
||||
if n_subset == 0 or not cfg.enable:
|
||||
self.tf = v2.Identity()
|
||||
else:
|
||||
self.tf = RandomSubsetApply(
|
||||
transforms=list(self.transforms.values()),
|
||||
p=self.weights,
|
||||
n_subset=n_subset,
|
||||
random_order=cfg.random_order,
|
||||
)
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
return self.tf(*inputs)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
884
lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py
Normal file
884
lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py
Normal file
@@ -0,0 +1,884 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
|
||||
|
||||
Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
|
||||
lerobot/configs/robot/aloha.yaml before running this script.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
|
||||
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
# spellchecker:off
|
||||
ALOHA_MOBILE_INFO = {
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"license": "mit",
|
||||
"url": "https://mobile-aloha.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.02117",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{fu2024mobile,
|
||||
author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
|
||||
title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
|
||||
booktitle = {arXiv},
|
||||
year = {2024},
|
||||
}""").lstrip(),
|
||||
}
|
||||
ALOHA_STATIC_INFO = {
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"license": "mit",
|
||||
"url": "https://tonyzhaozh.github.io/aloha/",
|
||||
"paper": "https://arxiv.org/abs/2304.13705",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{Zhao2023LearningFB,
|
||||
title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
|
||||
author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
|
||||
journal={RSS},
|
||||
year={2023},
|
||||
volume={abs/2304.13705},
|
||||
url={https://arxiv.org/abs/2304.13705}
|
||||
}""").lstrip(),
|
||||
}
|
||||
PUSHT_INFO = {
|
||||
"license": "mit",
|
||||
"url": "https://diffusion-policy.cs.columbia.edu/",
|
||||
"paper": "https://arxiv.org/abs/2303.04137v5",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{chi2024diffusionpolicy,
|
||||
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
|
||||
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
year = {2024},
|
||||
}""").lstrip(),
|
||||
}
|
||||
XARM_INFO = {
|
||||
"license": "mit",
|
||||
"url": "https://www.nicklashansen.com/td-mpc/",
|
||||
"paper": "https://arxiv.org/abs/2203.04955",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{Hansen2022tdmpc,
|
||||
title={Temporal Difference Learning for Model Predictive Control},
|
||||
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
|
||||
booktitle={ICML},
|
||||
year={2022}
|
||||
}
|
||||
"""),
|
||||
}
|
||||
UNITREEH_INFO = {
|
||||
"license": "apache-2.0",
|
||||
}
|
||||
|
||||
DATASETS = {
|
||||
"aloha_mobile_cabinet": {
|
||||
"single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_chair": {
|
||||
"single_task": "Push the chairs in front of the desk to place them against it.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_elevator": {
|
||||
"single_task": "Take the elevator to the 1st floor.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_shrimp": {
|
||||
"single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_wash_pan": {
|
||||
"single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_wipe_wine": {
|
||||
"single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_static_battery": {
|
||||
"single_task": "Place the battery into the slot of the remote controller.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
|
||||
"aloha_static_coffee": {
|
||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_coffee_new": {
|
||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_cups_open": {
|
||||
"single_task": "Pick up the plastic cup and open its lid.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_fork_pick_up": {
|
||||
"single_task": "Pick up the fork and place it on the plate.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_pingpong_test": {
|
||||
"single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_pro_pencil": {
|
||||
"single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_screw_driver": {
|
||||
"single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_tape": {
|
||||
"single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_thread_velcro": {
|
||||
"single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_towel": {
|
||||
"single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_vinh_cup": {
|
||||
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_vinh_cup_left": {
|
||||
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_scripted_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_human_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_scripted": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_scripted_image": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_human": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_human_image": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
||||
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
|
||||
"unitreeh1_two_robot_greeting": {
|
||||
"single_task": "Greet the other robot with a high five.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"unitreeh1_warehouse": {
|
||||
"single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"umi_cup_in_the_wild": {
|
||||
"single_task": "Put the cup on the plate.",
|
||||
"license": "apache-2.0",
|
||||
},
|
||||
"asu_table_top": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{zhou2023modularity,
|
||||
title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
|
||||
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
|
||||
booktitle={Conference on Robot Learning},
|
||||
pages={1684--1695},
|
||||
year={2023},
|
||||
organization={PMLR}
|
||||
}
|
||||
@article{zhou2023learning,
|
||||
title={Learning modular language-conditioned robot policies through attention},
|
||||
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
|
||||
journal={Autonomous Robots},
|
||||
pages={1--21},
|
||||
year={2023},
|
||||
publisher={Springer}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_buds_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/BUDS-website/",
|
||||
"paper": "https://arxiv.org/abs/2109.13841",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{zhu2022bottom,
|
||||
title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
|
||||
author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
|
||||
journal={IEEE Robotics and Automation Letters},
|
||||
volume={7},
|
||||
number={2},
|
||||
pages={4126--4133},
|
||||
year={2022},
|
||||
publisher={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_sailor_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/sailor/",
|
||||
"paper": "https://arxiv.org/abs/2210.11435",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{nasiriany2022sailor,
|
||||
title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
|
||||
author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
|
||||
booktitle={Conference on Robot Learning (CoRL)},
|
||||
year={2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_sirius_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/sirius/",
|
||||
"paper": "https://arxiv.org/abs/2211.08416",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{liu2022robot,
|
||||
title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
|
||||
author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
|
||||
booktitle = {Robotics: Science and Systems (RSS)},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_autolab_ur5": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://sites.google.com/view/berkeley-ur5/home",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{BerkeleyUR5Website,
|
||||
title = {Berkeley {UR5} Demonstration Dataset},
|
||||
author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
|
||||
howpublished = {https://sites.google.com/view/berkeley-ur5/home},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_cable_routing": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://sites.google.com/view/cablerouting/home",
|
||||
"paper": "https://arxiv.org/abs/2307.08927",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{luo2023multistage,
|
||||
author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
|
||||
title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
|
||||
journal = {arXiv pre-print},
|
||||
year = {2023},
|
||||
url = {https://arxiv.org/abs/2307.08927},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_fanuc_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{fanuc_manipulation2023,
|
||||
title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
|
||||
author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_cory_hall": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/1709.10489",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{kahn2018self,
|
||||
title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
|
||||
author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
|
||||
booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
|
||||
pages={5129--5136},
|
||||
year={2018},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_recon": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/recon-robot",
|
||||
"paper": "https://arxiv.org/abs/2104.05859",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{shah2021rapid,
|
||||
title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
|
||||
author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
|
||||
booktitle={5th Annual Conference on Robot Learning },
|
||||
year={2021},
|
||||
url={https://openreview.net/forum?id=d_SWJhyKfVw}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_sac_son": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/SACSoN-review",
|
||||
"paper": "https://arxiv.org/abs/2306.01874",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{hirose2023sacson,
|
||||
title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
|
||||
author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
|
||||
journal={arXiv preprint arXiv:2306.01874},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_mvp": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/2203.06173",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@InProceedings{Radosavovic2022,
|
||||
title = {Real-World Robot Learning with Masked Visual Pre-training},
|
||||
author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
|
||||
booktitle = {CoRL},
|
||||
year = {2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_rpt": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/2306.10007",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{Radosavovic2023,
|
||||
title={Robot Learning with Sensorimotor Pre-training},
|
||||
author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
|
||||
year={2023},
|
||||
journal={arXiv:2306.10007}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_franka_exploration_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://human-world-model.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2308.10901",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{mendonca2023structured,
|
||||
title={Structured World Models from Human Videos},
|
||||
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
|
||||
journal={RSS},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_play_fusion": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://play-fusion.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2312.04549",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{chen2023playfusion,
|
||||
title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
|
||||
author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
|
||||
booktitle={CoRL},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_stretch": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://robo-affordances.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2304.08488",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{bahl2023affordances,
|
||||
title={Affordances from Human Videos as a Versatile Representation for Robotics},
|
||||
author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
|
||||
booktitle={CVPR},
|
||||
year={2023}
|
||||
}
|
||||
@article{mendonca2023structured,
|
||||
title={Structured World Models from Human Videos},
|
||||
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
|
||||
journal={CoRL},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"columbia_cairlab_pusht_real": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://diffusion-policy.cs.columbia.edu/",
|
||||
"paper": "https://arxiv.org/abs/2303.04137v5",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{chi2023diffusionpolicy,
|
||||
title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
|
||||
booktitle={Proceedings of Robotics: Science and Systems (RSS)},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"conq_hose_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{ConqHoseManipData,
|
||||
author={Peter Mitrano and Dmitry Berenson},
|
||||
title={Conq Hose Manipulation Dataset, v1.15.0},
|
||||
year={2024},
|
||||
howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_edan_shared_control": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://ieeexplore.ieee.org/document/9341156",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{vogel_edan_2020,
|
||||
title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
|
||||
language = {en},
|
||||
booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
|
||||
author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
|
||||
year = {2020}
|
||||
}
|
||||
@inproceedings{quere_shared_2020,
|
||||
address = {Paris, France},
|
||||
title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
|
||||
language = {en},
|
||||
booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
|
||||
author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
|
||||
year = {2020},
|
||||
pages = {7},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_sara_grid_clamp": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://www.researchsquare.com/article/rs-3289569/v1",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{padalkar2023guided,
|
||||
title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
|
||||
author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
|
||||
journal={Research square preprint rs-3289569/v1},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_sara_pour": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{padalkar2023guiding,
|
||||
title={Guiding Reinforcement Learning with Shared Control Templates},
|
||||
author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
|
||||
booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
|
||||
year={2023},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"droid_100": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://droid-dataset.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2403.12945",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{khazatsky2024droid,
|
||||
title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
|
||||
author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
|
||||
year = {2024},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"fmb": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://functional-manipulation-benchmark.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.08553",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{luo2024fmb,
|
||||
title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
|
||||
author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
|
||||
journal={arXiv preprint arXiv:2401.08553},
|
||||
year={2024}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"iamlab_cmu_pickup_insert": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://openreview.net/forum?id=WuBv9-IGDUA",
|
||||
"paper": "https://arxiv.org/abs/2401.14502",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{saxena2023multiresolution,
|
||||
title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
|
||||
author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
|
||||
booktitle={7th Annual Conference on Robot Learning},
|
||||
year={2023},
|
||||
url={https://openreview.net/forum?id=WuBv9-IGDUA}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"imperialcollege_sawyer_wrist_cam": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
},
|
||||
"jaco_play": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://github.com/clvrai/clvr_jaco_play_dataset",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@software{dass2023jacoplay,
|
||||
author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
|
||||
and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
|
||||
title = {CLVR Jaco Play Dataset},
|
||||
url = {https://github.com/clvrai/clvr_jaco_play_dataset},
|
||||
version = {1.0.0},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"kaist_nonprehensile": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{kimpre,
|
||||
title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
|
||||
author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
|
||||
booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
||||
year={2023},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_door_opening_surprising_effectiveness": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://jyopari.github.io/VINN/",
|
||||
"paper": "https://arxiv.org/abs/2112.01511",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{pari2021surprising,
|
||||
title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
|
||||
author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
|
||||
year={2021},
|
||||
eprint={2112.01511},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_franka_play_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://play-to-policy.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2210.10047",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{cui2022play,
|
||||
title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
|
||||
author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
|
||||
journal = {arXiv preprint arXiv:2210.10047},
|
||||
year = {2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_rot_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://rot-robot.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2206.15469",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{haldar2023watch,
|
||||
title={Watch and match: Supercharging imitation with regularized optimal transport},
|
||||
author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
|
||||
booktitle={Conference on Robot Learning},
|
||||
pages={32--43},
|
||||
year={2023},
|
||||
organization={PMLR}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"roboturk": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://roboturk.stanford.edu/dataset_real.html",
|
||||
"paper": "PAPER",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{mandlekar2019scaling,
|
||||
title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
|
||||
author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
|
||||
booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
||||
pages={1048--1055},
|
||||
year={2019},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_hydra_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/hydra-il-2023",
|
||||
"paper": "https://arxiv.org/abs/2306.17237",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{belkhale2023hydra,
|
||||
title={HYDRA: Hybrid Robot Actions for Imitation Learning},
|
||||
author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
|
||||
journal={arxiv},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_kuka_multimodal_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/visionandtouch",
|
||||
"paper": "https://arxiv.org/abs/1810.10191",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{lee2019icra,
|
||||
title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
|
||||
author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
|
||||
booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
year={2019},
|
||||
url={https://arxiv.org/abs/1810.10191}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_robocook": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://hshi74.github.io/robocook/",
|
||||
"paper": "https://arxiv.org/abs/2306.14447",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{shi2023robocook,
|
||||
title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
|
||||
author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
|
||||
journal={arXiv preprint arXiv:2306.14447},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"taco_play": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
|
||||
"paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{rosete2022tacorl,
|
||||
author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
|
||||
title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
|
||||
journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
|
||||
year = {2022}
|
||||
}
|
||||
@inproceedings{mees23hulc2,
|
||||
title={Grounding Language with Visual Affordances over Unstructured Data},
|
||||
author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
|
||||
booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
year={2023},
|
||||
address = {London, UK}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"tokyo_u_lsmo": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "URL",
|
||||
"paper": "https://arxiv.org/abs/2107.05842",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@Article{Osa22,
|
||||
author = {Takayuki Osa},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
|
||||
year = {2022},
|
||||
number = {3},
|
||||
pages = {291--311},
|
||||
volume = {41},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"toto": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://toto-benchmark.org/",
|
||||
"paper": "https://arxiv.org/abs/2306.00942",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{zhou2023train,
|
||||
author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
|
||||
booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
title={Train Offline, Test Online: A Real Robot Learning Benchmark},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"ucsd_kitchen_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@ARTICLE{ucsd_kitchens,
|
||||
author = {Ge Yan, Kris Wu, and Xiaolong Wang},
|
||||
title = {{ucsd kitchens Dataset}},
|
||||
year = {2023},
|
||||
month = {August}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"ucsd_pick_and_place_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://owmcorl.github.io/#",
|
||||
"paper": "https://arxiv.org/abs/2310.16029",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@preprint{Feng2023Finetuning,
|
||||
title={Finetuning Offline World Models in the Real World},
|
||||
author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"uiuc_d3field": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://robopil.github.io/d3fields/",
|
||||
"paper": "https://arxiv.org/abs/2309.16118",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{wang2023d3field,
|
||||
title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
|
||||
author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
|
||||
journal={arXiv preprint arXiv:},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"usc_cloth_sim": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://uscresl.github.io/dmfd/",
|
||||
"paper": "https://arxiv.org/abs/2207.10148",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{salhotra2022dmfd,
|
||||
author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
|
||||
journal={IEEE Robotics and Automation Letters},
|
||||
title={Learning Deformable Object Manipulation From Expert Demonstrations},
|
||||
year={2022},
|
||||
volume={7},
|
||||
number={4},
|
||||
pages={8775-8782},
|
||||
doi={10.1109/LRA.2022.3187843}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utaustin_mutex": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/MUTEX/",
|
||||
"paper": "https://arxiv.org/abs/2309.14320",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{shah2023mutex,
|
||||
title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
|
||||
author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
|
||||
booktitle={7th Annual Conference on Robot Learning},
|
||||
year={2023},
|
||||
url={https://openreview.net/forum?id=PwqiqaaEzJ}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_pr2_opening_fridge": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{oh2023pr2utokyodatasets,
|
||||
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
|
||||
title={X-Embodiment U-Tokyo PR2 Datasets},
|
||||
year={2023},
|
||||
url={https://github.com/ojh6404/rlds_dataset_builder},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_pr2_tabletop_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{oh2023pr2utokyodatasets,
|
||||
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
|
||||
title={X-Embodiment U-Tokyo PR2 Datasets},
|
||||
year={2023},
|
||||
url={https://github.com/ojh6404/rlds_dataset_builder},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_saytap": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://saytap.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2306.07580",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{saytap2023,
|
||||
author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
|
||||
Tatsuya Harada},
|
||||
title = {SayTap: Language to Quadrupedal Locomotion},
|
||||
eprint = {arXiv:2306.07580},
|
||||
url = {https://saytap.github.io},
|
||||
note = {https://saytap.github.io},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_xarm_bimanual": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{matsushima2023weblab,
|
||||
title={Weblab xArm Dataset},
|
||||
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_xarm_pick_and_place": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{matsushima2023weblab,
|
||||
title={Weblab xArm Dataset},
|
||||
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"viola": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/VIOLA/",
|
||||
"paper": "https://arxiv.org/abs/2210.11339",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{zhu2022viola,
|
||||
title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
|
||||
author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
|
||||
journal={6th Annual Conference on Robot Learning (CoRL)},
|
||||
year={2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
}
|
||||
# spellchecker:on
|
||||
|
||||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
logfile = LOCAL_DIR / "conversion_log.txt"
|
||||
assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
|
||||
for num, (name, kwargs) in enumerate(DATASETS.items()):
|
||||
repo_id = f"lerobot/{name}"
|
||||
print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
||||
status = f"{repo_id}: success."
|
||||
with open(logfile, "a", encoding="utf-8") as file:
|
||||
file.write(status + "\n")
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
with open(logfile, "a", encoding="utf-8") as file:
|
||||
file.write(status + "\n")
|
||||
continue
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_convert()
|
||||
667
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
Normal file
667
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
Normal file
@@ -0,0 +1,667 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
|
||||
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
|
||||
for each of the task performed in the dataset. This will allow to easily train models with task-conditioning.
|
||||
|
||||
We support 3 different scenarios for these tasks (see instructions below):
|
||||
1. Single task dataset: all episodes of your dataset have the same single task.
|
||||
2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
|
||||
one episode to the next.
|
||||
3. Multi task episodes: episodes of your dataset may each contain several different tasks.
|
||||
|
||||
|
||||
Can you can also provide a robot config .yaml file (not mandatory) to this script via the option
|
||||
'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was
|
||||
recorded with. For now, only Aloha/Koch type robots are supported with this option.
|
||||
|
||||
|
||||
# 1. Single task dataset
|
||||
If your dataset contains a single task, you can simply provide it directly via the CLI with the
|
||||
'--single-task' option.
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id lerobot/aloha_sim_insertion_human_image \
|
||||
--single-task "Insert the peg into the socket." \
|
||||
--robot-config lerobot/configs/robot/aloha.yaml \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id aliberts/koch_tutorial \
|
||||
--single-task "Pick the Lego block and drop it in the box on the right." \
|
||||
--robot-config lerobot/configs/robot/koch.yaml \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
|
||||
# 2. Single task episodes
|
||||
If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
|
||||
|
||||
- If your dataset already contains a language instruction column in its parquet file, you can simply provide
|
||||
this column's name with the '--tasks-col' arg.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id lerobot/stanford_kuka_multimodal_dataset \
|
||||
--tasks-col "language_instruction" \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
|
||||
'--tasks-path' arg. This file should have the following structure where keys correspond to each
|
||||
episode_index in the dataset, and values are the language instruction for that episode.
|
||||
|
||||
Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"0": "Do something",
|
||||
"1": "Do something else",
|
||||
"2": "Do something",
|
||||
"3": "Go there",
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
# 3. Multi task episodes
|
||||
If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
|
||||
parquet file, and you must provide this column's name with the '--tasks-col' arg.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id lerobot/stanford_kuka_multimodal_dataset \
|
||||
--tasks-col "language_instruction" \
|
||||
--local-dir data
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import filecmp
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
get_safe_version,
|
||||
load_json,
|
||||
unflatten_dict,
|
||||
write_json,
|
||||
write_jsonlines,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame, # noqa: F401
|
||||
get_image_pixel_channels,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_config
|
||||
|
||||
V16 = "v1.6"
|
||||
V20 = "v2.0"
|
||||
|
||||
GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
|
||||
V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
|
||||
V1_INFO_PATH = "meta_data/info.json"
|
||||
V1_STATS_PATH = "meta_data/stats.safetensors"
|
||||
|
||||
|
||||
def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
|
||||
if robot_cfg.type in ["aloha", "koch"]:
|
||||
state_names = [
|
||||
f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
|
||||
for arm in robot_cfg.follower_arms
|
||||
for motor in robot_cfg.follower_arms[arm].motors
|
||||
]
|
||||
action_names = [
|
||||
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
|
||||
for arm in robot_cfg.leader_arms
|
||||
for motor in robot_cfg.leader_arms[arm].motors
|
||||
]
|
||||
# elif robot_cfg["robot_type"] == "stretch3": TODO
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
|
||||
)
|
||||
|
||||
return {
|
||||
"robot_type": robot_cfg.type,
|
||||
"names": {
|
||||
"observation.state": state_names,
|
||||
"observation.effort": state_names,
|
||||
"action": action_names,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
safetensor_path = v1_dir / V1_STATS_PATH
|
||||
stats = load_file(safetensor_path)
|
||||
serialized_stats = {key: value.tolist() for key, value in stats.items()}
|
||||
serialized_stats = unflatten_dict(serialized_stats)
|
||||
|
||||
json_path = v2_dir / STATS_PATH
|
||||
json_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(serialized_stats, f, indent=4)
|
||||
|
||||
# Sanity check
|
||||
with open(json_path, encoding="utf-8") as f:
|
||||
stats_json = json.load(f)
|
||||
|
||||
stats_json = flatten_dict(stats_json)
|
||||
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
|
||||
for key in stats:
|
||||
torch.testing.assert_close(stats_json[key], stats[key])
|
||||
|
||||
|
||||
def get_features_from_hf_dataset(
|
||||
dataset: Dataset, robot_config: RobotConfig | None = None
|
||||
) -> dict[str, list]:
|
||||
robot_config = parse_robot_config(robot_config)
|
||||
features = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if isinstance(ft, datasets.Value):
|
||||
dtype = ft.dtype
|
||||
shape = (1,)
|
||||
names = None
|
||||
elif isinstance(ft, datasets.Sequence):
|
||||
assert isinstance(ft.feature, datasets.Value)
|
||||
dtype = ft.feature.dtype
|
||||
shape = (ft.length,)
|
||||
motor_names = (
|
||||
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
|
||||
)
|
||||
assert len(motor_names) == shape[0]
|
||||
names = {"motors": motor_names}
|
||||
elif isinstance(ft, datasets.Image):
|
||||
dtype = "image"
|
||||
image = dataset[0][key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
shape = (image.height, image.width, channels)
|
||||
names = ["height", "width", "channels"]
|
||||
elif ft._type == "VideoFrame":
|
||||
dtype = "video"
|
||||
shape = None # Add shape later
|
||||
names = ["height", "width", "channels"]
|
||||
else:
|
||||
raise NotImplementedError(f"Feature type {ft._type} not supported.")
|
||||
|
||||
features[key] = {
|
||||
"dtype": dtype,
|
||||
"shape": shape,
|
||||
"names": names,
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
|
||||
df = dataset.to_pandas()
|
||||
tasks = list(set(tasks_by_episodes.values()))
|
||||
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
|
||||
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
|
||||
|
||||
features = dataset.features
|
||||
features["task_index"] = datasets.Value(dtype="int64")
|
||||
dataset = Dataset.from_pandas(df, features=features, split="train")
|
||||
return dataset, tasks
|
||||
|
||||
|
||||
def add_task_index_from_tasks_col(
|
||||
dataset: Dataset, tasks_col: str
|
||||
) -> tuple[Dataset, dict[str, list[str]], list[str]]:
|
||||
df = dataset.to_pandas()
|
||||
|
||||
# HACK: This is to clean some of the instructions in our version of Open X datasets
|
||||
prefix_to_clean = "tf.Tensor(b'"
|
||||
suffix_to_clean = "', shape=(), dtype=string)"
|
||||
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
|
||||
|
||||
# Create task_index col
|
||||
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
|
||||
tasks = df[tasks_col].unique().tolist()
|
||||
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
|
||||
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
|
||||
|
||||
# Build the dataset back from df
|
||||
features = dataset.features
|
||||
features["task_index"] = datasets.Value(dtype="int64")
|
||||
dataset = Dataset.from_pandas(df, features=features, split="train")
|
||||
dataset = dataset.remove_columns(tasks_col)
|
||||
|
||||
return dataset, tasks, tasks_by_episode
|
||||
|
||||
|
||||
def split_parquet_by_episodes(
|
||||
dataset: Dataset,
|
||||
total_episodes: int,
|
||||
total_chunks: int,
|
||||
output_dir: Path,
|
||||
) -> list:
|
||||
table = dataset.data.table
|
||||
episode_lengths = []
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
|
||||
episode_chunk=ep_chunk, episode_index=ep_idx
|
||||
)
|
||||
pq.write_table(ep_table, output_file)
|
||||
|
||||
return episode_lengths
|
||||
|
||||
|
||||
def move_videos(
|
||||
repo_id: str,
|
||||
video_keys: list[str],
|
||||
total_episodes: int,
|
||||
total_chunks: int,
|
||||
work_dir: Path,
|
||||
clean_gittatributes: Path,
|
||||
branch: str = "main",
|
||||
) -> None:
|
||||
"""
|
||||
HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
|
||||
commands to fetch git lfs video files references to move them into subdirectories without having to
|
||||
actually download them.
|
||||
"""
|
||||
_lfs_clone(repo_id, work_dir, branch)
|
||||
|
||||
videos_moved = False
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
|
||||
if len(video_files) == 0:
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
|
||||
videos_moved = True # Videos have already been moved
|
||||
|
||||
assert len(video_files) == total_episodes * len(video_keys)
|
||||
|
||||
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
|
||||
|
||||
current_gittatributes = work_dir / ".gitattributes"
|
||||
if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
|
||||
fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
|
||||
|
||||
if lfs_untracked_videos:
|
||||
fix_lfs_video_files_tracking(work_dir, video_files)
|
||||
|
||||
if videos_moved:
|
||||
return
|
||||
|
||||
video_dirs = sorted(work_dir.glob("videos*/"))
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
for vid_key in video_keys:
|
||||
chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key
|
||||
)
|
||||
(work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
target_path = DEFAULT_VIDEO_PATH.format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
|
||||
if len(video_dirs) == 1:
|
||||
video_path = video_dirs[0] / video_file
|
||||
else:
|
||||
for v_dir in video_dirs:
|
||||
if (v_dir / video_file).is_file():
|
||||
video_path = v_dir / video_file
|
||||
break
|
||||
|
||||
video_path.rename(work_dir / target_path)
|
||||
|
||||
commit_message = "Move video files into chunk subdirectories"
|
||||
subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
|
||||
"""
|
||||
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
|
||||
there's no other option than to download the actual files and reupload them with lfs tracking.
|
||||
"""
|
||||
for i in range(0, len(lfs_untracked_videos), 100):
|
||||
files = lfs_untracked_videos[i : i + 100]
|
||||
try:
|
||||
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("git rm --cached ERROR:")
|
||||
print(e.stderr)
|
||||
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
|
||||
|
||||
commit_message = "Track video files with git lfs"
|
||||
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
|
||||
shutil.copyfile(clean_gittatributes, current_gittatributes)
|
||||
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
|
||||
repo_url = f"https://huggingface.co/datasets/{repo_id}"
|
||||
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
|
||||
subprocess.run(
|
||||
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
|
||||
check=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
|
||||
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
|
||||
lfs_tracked_files = subprocess.run(
|
||||
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
|
||||
)
|
||||
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
|
||||
return [f for f in video_files if f not in lfs_tracked_files]
|
||||
|
||||
|
||||
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
|
||||
# Assumes first episode
|
||||
video_files = [
|
||||
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
|
||||
for vid_key in video_keys
|
||||
]
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
|
||||
)
|
||||
videos_info_dict = {}
|
||||
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
||||
videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
|
||||
|
||||
return videos_info_dict
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
local_dir: Path,
|
||||
single_task: str | None = None,
|
||||
tasks_path: Path | None = None,
|
||||
tasks_col: Path | None = None,
|
||||
robot_config: RobotConfig | None = None,
|
||||
test_branch: str | None = None,
|
||||
**card_kwargs,
|
||||
):
|
||||
v1 = get_safe_version(repo_id, V16)
|
||||
v1x_dir = local_dir / V16 / repo_id
|
||||
v20_dir = local_dir / V20 / repo_id
|
||||
v1x_dir.mkdir(parents=True, exist_ok=True)
|
||||
v20_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
|
||||
)
|
||||
branch = "main"
|
||||
if test_branch:
|
||||
branch = test_branch
|
||||
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
|
||||
|
||||
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
|
||||
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
|
||||
features = get_features_from_hf_dataset(dataset, robot_config)
|
||||
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
|
||||
|
||||
if single_task and "language_instruction" in dataset.column_names:
|
||||
logging.warning(
|
||||
"'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
|
||||
)
|
||||
single_task = None
|
||||
tasks_col = "language_instruction"
|
||||
|
||||
# Episodes & chunks
|
||||
episode_indices = sorted(dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
assert episode_indices == list(range(total_episodes))
|
||||
total_videos = total_episodes * len(video_keys)
|
||||
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
|
||||
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
|
||||
total_chunks += 1
|
||||
|
||||
# Tasks
|
||||
if single_task:
|
||||
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_path:
|
||||
tasks_by_episodes = load_json(tasks_path)
|
||||
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_col:
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
||||
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
||||
features["task_index"] = {
|
||||
"dtype": "int64",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
# Videos
|
||||
if video_keys:
|
||||
assert metadata_v1.get("video", False)
|
||||
dataset = dataset.remove_columns(video_keys)
|
||||
clean_gitattr = Path(
|
||||
hub_api.hf_hub_download(
|
||||
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
|
||||
)
|
||||
).absolute()
|
||||
with tempfile.TemporaryDirectory() as tmp_video_dir:
|
||||
move_videos(
|
||||
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
|
||||
)
|
||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
||||
for key in video_keys:
|
||||
features[key]["shape"] = (
|
||||
videos_info[key].pop("video.height"),
|
||||
videos_info[key].pop("video.width"),
|
||||
videos_info[key].pop("video.channels"),
|
||||
)
|
||||
features[key]["video_info"] = videos_info[key]
|
||||
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
|
||||
if "encoding" in metadata_v1:
|
||||
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
|
||||
else:
|
||||
assert metadata_v1.get("video", 0) == 0
|
||||
videos_info = None
|
||||
|
||||
# Split data into 1 parquet file by episode
|
||||
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
|
||||
|
||||
if robot_config is not None:
|
||||
robot_type = robot_config.type
|
||||
repo_tags = [robot_type]
|
||||
else:
|
||||
robot_type = "unknown"
|
||||
repo_tags = None
|
||||
|
||||
# Episodes
|
||||
episodes = [
|
||||
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||
for ep_idx in episode_indices
|
||||
]
|
||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
||||
|
||||
# Assemble metadata v2.0
|
||||
metadata_v2_0 = {
|
||||
"codebase_version": V20,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": total_episodes,
|
||||
"total_frames": len(dataset),
|
||||
"total_tasks": len(tasks),
|
||||
"total_videos": total_videos,
|
||||
"total_chunks": total_chunks,
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"fps": metadata_v1["fps"],
|
||||
"splits": {"train": f"0:{total_episodes}"},
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
|
||||
"features": features,
|
||||
}
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
||||
|
||||
hub_api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="data",
|
||||
folder_path=v20_dir / "data",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
hub_api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="meta",
|
||||
folder_path=v20_dir / "meta",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
if not test_branch:
|
||||
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
task_args = parser.add_mutually_exclusive_group(required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--single-task",
|
||||
type=str,
|
||||
help="A short but accurate description of the single task performed in the dataset.",
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--tasks-col",
|
||||
type=str,
|
||||
help="The name of the column containing language instructions",
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--tasks-path",
|
||||
type=Path,
|
||||
help="The path to a .json file containing one language instruction for each episode_index",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--license",
|
||||
type=str,
|
||||
default="apache-2.0",
|
||||
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.local_dir:
|
||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||
|
||||
robot_config = None
|
||||
if args.robot is not None:
|
||||
robot_config = make_robot_config(args.robot)
|
||||
|
||||
del args.robot
|
||||
|
||||
convert_dataset(**vars(args), robot_config=robot_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
87
lerobot/common/datasets/v21/_remove_language_instruction.py
Normal file
87
lerobot/common/datasets/v21/_remove_language_instruction.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import get_dataset_config_info
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import INFO_PATH, write_info
|
||||
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
hub_api = HfApi()
|
||||
|
||||
|
||||
def fix_dataset(repo_id: str) -> str:
|
||||
if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
|
||||
return f"{repo_id}: skipped (not in {V20})."
|
||||
|
||||
dataset_info = get_dataset_config_info(repo_id, "default")
|
||||
with SuppressWarnings():
|
||||
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
|
||||
parquet_features = set(dataset_info.features)
|
||||
|
||||
diff_parquet_meta = parquet_features - meta_features
|
||||
diff_meta_parquet = meta_features - parquet_features
|
||||
|
||||
if diff_parquet_meta:
|
||||
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
|
||||
|
||||
if not diff_meta_parquet:
|
||||
return f"{repo_id}: skipped (no diff)"
|
||||
|
||||
if diff_meta_parquet:
|
||||
logging.warning("In info.json not in parquet: %s", meta_features - parquet_features)
|
||||
assert diff_meta_parquet == {"language_instruction"}
|
||||
lerobot_metadata.features.pop("language_instruction")
|
||||
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
||||
commit_info = hub_api.upload_file(
|
||||
path_or_fileobj=lerobot_metadata.root / INFO_PATH,
|
||||
path_in_repo=INFO_PATH,
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
revision=V20,
|
||||
commit_message="Remove 'language_instruction'",
|
||||
create_pr=True,
|
||||
)
|
||||
return f"{repo_id}: success - PR: {commit_info.pr_url}"
|
||||
|
||||
|
||||
def batch_fix():
|
||||
status = {}
|
||||
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logfile = LOCAL_DIR / "fix_features_v20.txt"
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
status = fix_dataset(repo_id)
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
|
||||
logging.info(status)
|
||||
with open(logfile, "a", encoding="utf-8") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_fix()
|
||||
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logfile = LOCAL_DIR / "conversion_log_v21.txt"
|
||||
hub_api = HfApi()
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
|
||||
status = f"{repo_id}: success (already in {V21})."
|
||||
else:
|
||||
convert_dataset(repo_id)
|
||||
status = f"{repo_id}: success."
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
|
||||
with open(logfile, "a", encoding="utf-8") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_convert()
|
||||
117
lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
Normal file
117
lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
|
||||
--repo-id=aliberts/koch_tutorial
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
|
||||
|
||||
class SuppressWarnings:
|
||||
def __init__(self):
|
||||
self.previous_level = None
|
||||
|
||||
def __enter__(self):
|
||||
self.previous_level = logging.getLogger().getEffectiveLevel()
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
logging.getLogger().setLevel(self.previous_level)
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to push your dataset. Defaults to the main branch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
99
lerobot/common/datasets/v21/convert_stats.py
Normal file
99
lerobot/common/datasets/v21/convert_stats.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import write_episode_stats
|
||||
|
||||
|
||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||
ep_len = dataset.meta.episodes[episode_index]["length"]
|
||||
sampled_indices = sample_indices(ep_len)
|
||||
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
|
||||
video_frames = dataset._query_videos(query_timestamps, episode_index)
|
||||
return video_frames[ft_key].numpy()
|
||||
|
||||
|
||||
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
|
||||
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
|
||||
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
|
||||
|
||||
ep_stats = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if ft["dtype"] == "video":
|
||||
# We sample only for videos
|
||||
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
|
||||
else:
|
||||
ep_ft_data = np.array(ep_data[key])
|
||||
|
||||
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
||||
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
||||
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||
|
||||
|
||||
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
||||
assert dataset.episodes is None
|
||||
print("Computing episodes stats")
|
||||
total_episodes = dataset.meta.total_episodes
|
||||
if num_workers > 0:
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
|
||||
for ep_idx in range(total_episodes)
|
||||
}
|
||||
for future in tqdm(as_completed(futures), total=total_episodes):
|
||||
future.result()
|
||||
else:
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
convert_episode_stats(dataset, ep_idx)
|
||||
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||
|
||||
|
||||
def check_aggregate_stats(
|
||||
dataset: LeRobotDataset,
|
||||
reference_stats: dict[str, dict[str, np.ndarray]],
|
||||
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
|
||||
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
|
||||
):
|
||||
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
|
||||
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
|
||||
for key, ft in dataset.features.items():
|
||||
# These values might need some fine-tuning
|
||||
if ft["dtype"] == "video":
|
||||
# to account for image sub-sampling
|
||||
rtol, atol = video_rtol_atol
|
||||
else:
|
||||
rtol, atol = default_rtol_atol
|
||||
|
||||
for stat, val in agg_stats[key].items():
|
||||
if key in reference_stats and stat in reference_stats[key]:
|
||||
err_msg = f"feature='{key}' stats='{stat}'"
|
||||
np.testing.assert_allclose(
|
||||
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
|
||||
)
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import warnings
|
||||
@@ -25,47 +26,11 @@ import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
|
||||
|
||||
def load_from_videos(
|
||||
item: dict[str, torch.Tensor],
|
||||
video_frame_keys: list[str],
|
||||
videos_dir: Path,
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
):
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
|
||||
This probably happens because a memory reference to the video loader is created in the main process and a
|
||||
subprocess fails to access it.
|
||||
"""
|
||||
# since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
|
||||
data_dir = videos_dir.parent
|
||||
|
||||
for key in video_frame_keys:
|
||||
if isinstance(item[key], list):
|
||||
# load multiple frames at once (expected when delta_timestamps is not None)
|
||||
timestamps = [frame["timestamp"] for frame in item[key]]
|
||||
paths = [frame["path"] for frame in item[key]]
|
||||
if len(set(paths)) > 1:
|
||||
raise NotImplementedError("All video paths are expected to be the same for now.")
|
||||
video_path = data_dir / paths[0]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
item[key] = frames
|
||||
else:
|
||||
# load one frame
|
||||
timestamps = [item[key]["timestamp"]]
|
||||
video_path = data_dir / item[key]["path"]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
item[key] = frames[0]
|
||||
|
||||
return item
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
video_path: str,
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
@@ -104,11 +69,11 @@ def decode_video_frames_torchvision(
|
||||
|
||||
# set the first and last requested timestamps
|
||||
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
||||
first_ts = timestamps[0]
|
||||
last_ts = timestamps[-1]
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
|
||||
# access closest key frame of the first requested frame
|
||||
# Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
||||
reader.seek(first_ts, keyframes_only=keyframes_only)
|
||||
|
||||
@@ -118,7 +83,7 @@ def decode_video_frames_torchvision(
|
||||
for frame in reader:
|
||||
current_ts = frame["pts"]
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
logging.info("frame loaded at timestamp=%.4f", current_ts)
|
||||
loaded_frames.append(frame["data"])
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
@@ -153,7 +118,7 @@ def decode_video_frames_torchvision(
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
logging.info("closest_ts=%s", closest_ts)
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
@@ -163,8 +128,8 @@ def decode_video_frames_torchvision(
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
@@ -247,3 +212,108 @@ with warnings.catch_warnings():
|
||||
)
|
||||
# to make VideoFrame available in HuggingFace `datasets`
|
||||
register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
ffprobe_audio_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"a:0",
|
||||
"-show_entries",
|
||||
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(
|
||||
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
||||
if audio_stream_info is None:
|
||||
return {"has_audio": False}
|
||||
|
||||
# Return the information, defaulting to None if no audio stream is present
|
||||
return {
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
||||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||
}
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
ffprobe_video_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"v:0",
|
||||
"-show_entries",
|
||||
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(
|
||||
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
video_stream_info = info["streams"][0]
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
r_frame_rate = video_stream_info["r_frame_rate"]
|
||||
num, denom = map(int, r_frame_rate.split("/"))
|
||||
fps = num / denom
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||
|
||||
video_info = {
|
||||
"video.fps": fps,
|
||||
"video.height": video_stream_info["height"],
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
"video.is_depth_map": False,
|
||||
**get_audio_info(video_path),
|
||||
}
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_image_pixel_channels(image: Image):
|
||||
if image.mode == "L":
|
||||
return 1 # Grayscale
|
||||
elif image.mode == "LA":
|
||||
return 2 # Grayscale + Alpha
|
||||
elif image.mode == "RGB":
|
||||
return 3 # RGB
|
||||
elif image.mode == "RGBA":
|
||||
return 4 # RGBA
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,23 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
def test_visualize_dataset_html(tmpdir, repo_id):
|
||||
tmpdir = Path(tmpdir)
|
||||
visualize_dataset_html(
|
||||
repo_id,
|
||||
episodes=[0],
|
||||
output_dir=tmpdir,
|
||||
serve=False,
|
||||
)
|
||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
||||
157
lerobot/common/envs/configs.py
Normal file
157
lerobot/common/envs/configs.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
task: str | None = None
|
||||
fps: int = 30
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
features_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def gym_kwargs(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaEnv(EnvConfig):
|
||||
task: str = "AlohaInsertion-v0"
|
||||
fps: int = 50
|
||||
episode_length: int = 400
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
||||
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("pusht")
|
||||
@dataclass
|
||||
class PushtEnv(EnvConfig):
|
||||
task: str = "PushT-v0"
|
||||
fps: int = 10
|
||||
episode_length: int = 300
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"environment_state": OBS_ENV,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
|
||||
elif self.obs_type == "environment_state_agent_pos":
|
||||
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("xarm")
|
||||
@dataclass
|
||||
class XarmEnv(EnvConfig):
|
||||
task: str = "XarmLift-v0"
|
||||
fps: int = 15
|
||||
episode_length: int = 200
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
@@ -16,43 +16,54 @@
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the evaluation config.
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
if env_type == "aloha":
|
||||
return AlohaEnv(**kwargs)
|
||||
elif env_type == "pusht":
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
||||
|
||||
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the config.
|
||||
|
||||
Args:
|
||||
cfg (EnvConfig): the config of the environment to instantiate.
|
||||
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
|
||||
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
||||
False.
|
||||
|
||||
Raises:
|
||||
ValueError: if n_envs < 1
|
||||
ModuleNotFoundError: If the requested env package is not installed
|
||||
|
||||
Returns:
|
||||
gym.vector.VectorEnv: The parallelized gym.env instance.
|
||||
"""
|
||||
if n_envs is not None and n_envs < 1:
|
||||
if n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
if cfg.env.name == "real_world":
|
||||
return
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
package_name = f"gym_{cfg.type}"
|
||||
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
|
||||
)
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
raise e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||
gym_kwgs = dict(cfg.env.get("gym", {}))
|
||||
|
||||
if cfg.env.get("episode_length"):
|
||||
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
]
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
@@ -18,8 +18,13 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.utils.utils import get_channel_first_image_shape
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
@@ -35,11 +40,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
@@ -60,3 +66,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return return_observations
|
||||
|
||||
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||
policy_features = {}
|
||||
for key, ft in env_cfg.features.items():
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
if len(ft.shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
|
||||
|
||||
shape = get_channel_first_image_shape(ft.shape)
|
||||
feature = PolicyFeature(type=ft.type, shape=shape)
|
||||
else:
|
||||
feature = ft
|
||||
|
||||
policy_key = env_cfg.features_map[key]
|
||||
policy_features[policy_key] = feature
|
||||
|
||||
return policy_features
|
||||
|
||||
@@ -1,246 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
|
||||
|
||||
# TODO(rcadene, alexander-soare): clean this file
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
||||
|
||||
def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.name}",
|
||||
f"dataset:{cfg.dataset_repo_id}",
|
||||
f"env:{cfg.env.name}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
|
||||
# Get the WandB run ID.
|
||||
paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
|
||||
if len(paths) != 1:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
|
||||
if match is None:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
wandb_run_id = match.groups(0)[0]
|
||||
return wandb_run_id
|
||||
|
||||
|
||||
class Logger:
|
||||
"""Primary logger object. Logs either locally or using wandb.
|
||||
|
||||
The logger creates the following directory structure:
|
||||
|
||||
provided_log_dir
|
||||
├── .hydra # hydra's configuration cache
|
||||
├── checkpoints
|
||||
│ ├── specific_checkpoint_name
|
||||
│ │ ├── pretrained_model # Hugging Face pretrained model directory
|
||||
│ │ │ ├── ...
|
||||
│ │ └── training_state.pth # optimizer, scheduler, and random states + training step
|
||||
| ├── another_specific_checkpoint_name
|
||||
│ │ ├── ...
|
||||
| ├── ...
|
||||
│ └── last # a softlink to the last logged checkpoint
|
||||
"""
|
||||
|
||||
pretrained_model_dir_name = "pretrained_model"
|
||||
training_state_file_name = "training_state.pth"
|
||||
|
||||
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
log_dir: The directory to save all logs and training outputs to.
|
||||
job_name: The WandB job name.
|
||||
"""
|
||||
self._cfg = cfg
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
|
||||
self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
|
||||
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
|
||||
|
||||
# Set up WandB.
|
||||
self._group = cfg_to_group(cfg)
|
||||
project = cfg.get("wandb", {}).get("project")
|
||||
entity = cfg.get("wandb", {}).get("entity")
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project
|
||||
if run_offline:
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = None
|
||||
if cfg.resume:
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
|
||||
|
||||
wandb.init(
|
||||
id=wandb_run_id,
|
||||
project=project,
|
||||
entity=entity,
|
||||
name=wandb_job_name,
|
||||
notes=cfg.get("wandb", {}).get("notes"),
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
dir=log_dir,
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
# TODO(rcadene): try set to True
|
||||
save_code=False,
|
||||
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
|
||||
job_type="train_eval",
|
||||
resume="must" if cfg.resume else None,
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
@classmethod
|
||||
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""Given the log directory, get the sub-directory in which checkpoints will be saved."""
|
||||
return Path(log_dir) / "checkpoints"
|
||||
|
||||
@classmethod
|
||||
def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
|
||||
return cls.get_checkpoints_dir(log_dir) / "last"
|
||||
|
||||
@classmethod
|
||||
def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""
|
||||
Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
|
||||
be saved.
|
||||
"""
|
||||
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
|
||||
|
||||
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
|
||||
"""Save the weights of the Policy model using PyTorchModelHubMixin.
|
||||
|
||||
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
|
||||
|
||||
Optionally also upload the model to WandB.
|
||||
"""
|
||||
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
|
||||
policy.save_pretrained(save_dir)
|
||||
# Also save the full Hydra config for the env configuration.
|
||||
OmegaConf.save(self._cfg, save_dir / "config.yaml")
|
||||
if self._wandb and not self._cfg.wandb.disable_artifact:
|
||||
# note wandb artifact does not accept ":" or "/" in its name
|
||||
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
|
||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
if self.last_checkpoint_dir.exists():
|
||||
os.remove(self.last_checkpoint_dir)
|
||||
|
||||
def save_training_state(
|
||||
self,
|
||||
save_dir: Path,
|
||||
train_step: int,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None,
|
||||
):
|
||||
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
|
||||
|
||||
All of these are saved as "training_state.pth" under the checkpoint directory.
|
||||
"""
|
||||
training_state = {
|
||||
"step": train_step,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
**get_global_random_state(),
|
||||
}
|
||||
if scheduler is not None:
|
||||
training_state["scheduler"] = scheduler.state_dict()
|
||||
torch.save(training_state, save_dir / self.training_state_file_name)
|
||||
|
||||
def save_checkpont(
|
||||
self,
|
||||
train_step: int,
|
||||
policy: Policy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None,
|
||||
identifier: str,
|
||||
):
|
||||
"""Checkpoint the model weights and the training state."""
|
||||
checkpoint_dir = self.checkpoints_dir / str(identifier)
|
||||
wandb_artifact_name = (
|
||||
None
|
||||
if self._wandb is None
|
||||
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
|
||||
)
|
||||
self.save_model(
|
||||
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
|
||||
)
|
||||
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
|
||||
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
|
||||
|
||||
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
|
||||
"""
|
||||
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
|
||||
random state, and return the global training step.
|
||||
"""
|
||||
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
|
||||
optimizer.load_state_dict(training_state["optimizer"])
|
||||
if scheduler is not None:
|
||||
scheduler.load_state_dict(training_state["scheduler"])
|
||||
elif "scheduler" in training_state:
|
||||
raise ValueError(
|
||||
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
|
||||
)
|
||||
# Small hack to get the expected keys: use `get_global_random_state`.
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"]
|
||||
|
||||
def log_dict(self, d, step, mode="train"):
|
||||
assert mode in {"train", "eval"}
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
assert self._wandb is not None
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
15
lerobot/common/optim/__init__.py
Normal file
15
lerobot/common/optim/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .optimizers import OptimizerConfig as OptimizerConfig
|
||||
40
lerobot/common/optim/factory.py
Normal file
40
lerobot/common/optim/factory.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def make_optimizer_and_scheduler(
|
||||
cfg: TrainPipelineConfig, policy: PreTrainedPolicy
|
||||
) -> tuple[Optimizer, LRScheduler | None]:
|
||||
"""Generates the optimizer and scheduler based on configs.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs
|
||||
policy (PreTrainedPolicy): The policy config from which parameters and presets must be taken from.
|
||||
|
||||
Returns:
|
||||
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
||||
"""
|
||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||
optimizer = cfg.optimizer.build(params)
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
return optimizer, lr_scheduler
|
||||
118
lerobot/common/optim/optimizers.py
Normal file
118
lerobot/common/optim/optimizers.py
Normal file
@@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.common.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
)
|
||||
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
lr: float
|
||||
weight_decay: float
|
||||
grad_clip_norm: float
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@classmethod
|
||||
def default_choice_name(cls) -> str | None:
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("adam")
|
||||
@dataclass
|
||||
class AdamConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
betas: tuple[float, float] = (0.9, 0.999)
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.Adam(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("adamw")
|
||||
@dataclass
|
||||
class AdamWConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
betas: tuple[float, float] = (0.9, 0.999)
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.AdamW(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("sgd")
|
||||
@dataclass
|
||||
class SGDConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
momentum: float = 0.0
|
||||
dampening: float = 0.0
|
||||
nesterov: bool = False
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
state = optimizer.state_dict()
|
||||
param_groups = state.pop("param_groups")
|
||||
flat_state = flatten_dict(state)
|
||||
save_file(flat_state, save_dir / OPTIMIZER_STATE)
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
|
||||
if "param_groups" in current_state_dict:
|
||||
param_groups = deserialize_json_into_object(
|
||||
save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
|
||||
)
|
||||
loaded_state_dict["param_groups"] = param_groups
|
||||
|
||||
optimizer.load_state_dict(loaded_state_dict)
|
||||
return optimizer
|
||||
122
lerobot/common/optim/schedulers.py
Normal file
122
lerobot/common/optim/schedulers.py
Normal file
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from lerobot.common.constants import SCHEDULER_STATE
|
||||
from lerobot.common.datasets.utils import write_json
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@dataclass
|
||||
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
num_warmup_steps: int
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("diffuser")
|
||||
@dataclass
|
||||
class DiffuserSchedulerConfig(LRSchedulerConfig):
|
||||
name: str = "cosine"
|
||||
num_warmup_steps: int | None = None
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
|
||||
return get_scheduler(**kwargs)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("vqbet")
|
||||
@dataclass
|
||||
class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
num_warmup_steps: int
|
||||
num_vqvae_training_steps: int
|
||||
num_cycles: float = 0.5
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
def lr_lambda(current_step):
|
||||
if current_step < self.num_vqvae_training_steps:
|
||||
return float(1)
|
||||
else:
|
||||
adjusted_step = current_step - self.num_vqvae_training_steps
|
||||
if adjusted_step < self.num_warmup_steps:
|
||||
return float(adjusted_step) / float(max(1, self.num_warmup_steps))
|
||||
progress = float(adjusted_step - self.num_warmup_steps) / float(
|
||||
max(1, num_training_steps - self.num_warmup_steps)
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Used by Physical Intelligence to train Pi0"""
|
||||
|
||||
num_warmup_steps: int
|
||||
num_decay_steps: int
|
||||
peak_lr: float
|
||||
decay_lr: float
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
del num_training_steps
|
||||
|
||||
def lr_lambda(current_step):
|
||||
def linear_warmup_schedule(current_step):
|
||||
if current_step <= 0:
|
||||
return 1 / (self.num_warmup_steps + 1)
|
||||
frac = 1 - current_step / self.num_warmup_steps
|
||||
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
|
||||
|
||||
def cosine_decay_schedule(current_step):
|
||||
step = min(current_step, self.num_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||
alpha = self.decay_lr / self.peak_lr
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return decayed
|
||||
|
||||
if current_step < self.num_warmup_steps:
|
||||
return linear_warmup_schedule(current_step)
|
||||
|
||||
return cosine_decay_schedule(current_step)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
|
||||
state_dict = scheduler.state_dict()
|
||||
write_json(state_dict, save_dir / SCHEDULER_STATE)
|
||||
|
||||
|
||||
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
|
||||
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
|
||||
scheduler.load_state_dict(state_dict)
|
||||
return scheduler
|
||||
19
lerobot/common/policies/__init__.py
Normal file
19
lerobot/common/policies/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
@@ -15,9 +15,14 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("act")
|
||||
@dataclass
|
||||
class ACTConfig:
|
||||
class ACTConfig(PreTrainedConfig):
|
||||
"""Configuration class for the Action Chunking Transformers policy.
|
||||
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
@@ -59,7 +64,7 @@ class ACTConfig:
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
||||
convolution.
|
||||
@@ -90,28 +95,11 @@ class ACTConfig:
|
||||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": [3, 480, 640],
|
||||
"observation.state": [14],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [14],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "mean_std",
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -144,8 +132,15 @@ class ACTConfig:
|
||||
dropout: float = 0.1
|
||||
kl_weight: float = 10.0
|
||||
|
||||
# Training preset
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_weight_decay: float = 1e-4
|
||||
optimizer_lr_backbone: float = 1e-5
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
super().__post_init__()
|
||||
|
||||
# Input validation (not exhaustive).
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
@@ -164,8 +159,28 @@ class ACTConfig:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if (
|
||||
not any(k.startswith("observation.image") for k in self.input_shapes)
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features and not self.env_state_feature:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -29,32 +29,27 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class ACTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "act"],
|
||||
):
|
||||
class ACTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
"""
|
||||
|
||||
config_class = ACTConfig
|
||||
name = "act"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig | None = None,
|
||||
config: ACTConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -64,30 +59,46 @@ class ACTPolicy(
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = ACTConfig()
|
||||
self.config: ACTConfig = config
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
# TODO(aliberts, rcadene): As of now, lr_backbone == lr
|
||||
# Should we remove this and just `return self.parameters()`?
|
||||
return [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": self.config.optimizer_lr_backbone,
|
||||
},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
@@ -106,9 +117,11 @@ class ACTPolicy(
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
@@ -131,12 +144,14 @@ class ACTPolicy(
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
@@ -154,11 +169,11 @@ class ACTPolicy(
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
loss = l1_loss + mean_kld * self.config.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
loss = l1_loss
|
||||
|
||||
return loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
@@ -207,6 +222,8 @@ class ACTTemporalEnsembler:
|
||||
self.chunk_size = chunk_size
|
||||
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||
self.ensembled_actions = None
|
||||
self.ensembled_actions_count = None
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
@@ -288,31 +305,30 @@ class ACT(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
self.use_robot_state = "observation.state" in config.input_shapes
|
||||
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if self.config.use_vae:
|
||||
self.vae_encoder = ACTEncoder(config)
|
||||
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
self.config.robot_state_feature.shape[0], config.dim_model
|
||||
)
|
||||
# Projection layer for action (joint-space target) to hidden dimension.
|
||||
self.vae_encoder_action_input_proj = nn.Linear(
|
||||
config.output_shapes["action"][0], config.dim_model
|
||||
self.config.action_feature.shape[0],
|
||||
config.dim_model,
|
||||
)
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
||||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
num_input_token_encoder = 1 + config.chunk_size
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
num_input_token_encoder += 1
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
@@ -320,7 +336,7 @@ class ACT(nn.Module):
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
@@ -337,27 +353,27 @@ class ACT(nn.Module):
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
self.config.robot_state_feature.shape[0], config.dim_model
|
||||
)
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
self.encoder_env_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.dim_model
|
||||
self.config.env_state_feature.shape[0], config.dim_model
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
n_1d_tokens = 1 # for the latent
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
n_1d_tokens += 1
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
n_1d_tokens += 1
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
@@ -365,7 +381,7 @@ class ACT(nn.Module):
|
||||
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
|
||||
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
@@ -380,13 +396,13 @@ class ACT(nn.Module):
|
||||
|
||||
`batch` should have the following structure:
|
||||
{
|
||||
"observation.state" (optional): (B, state_dim) batch of robot states.
|
||||
[robot_state_feature] (optional): (B, state_dim) batch of robot states.
|
||||
|
||||
"observation.images": (B, n_cameras, C, H, W) batch of images.
|
||||
[image_features]: (B, n_cameras, C, H, W) batch of images.
|
||||
AND/OR
|
||||
"observation.environment_state": (B, env_dim) batch of environment states.
|
||||
[env_state_feature]: (B, env_dim) batch of environment states.
|
||||
|
||||
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
[action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
}
|
||||
|
||||
Returns:
|
||||
@@ -395,9 +411,9 @@ class ACT(nn.Module):
|
||||
latent dimension.
|
||||
"""
|
||||
if self.config.use_vae and self.training:
|
||||
assert (
|
||||
"action" in batch
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
assert "action" in batch, (
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
|
||||
batch_size = (
|
||||
batch["observation.images"]
|
||||
@@ -411,12 +427,12 @@ class ACT(nn.Module):
|
||||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
@@ -430,7 +446,7 @@ class ACT(nn.Module):
|
||||
# sequence depending whether we use the input states or not (cls and robot state)
|
||||
# False means not a padding token.
|
||||
cls_joint_is_pad = torch.full(
|
||||
(batch_size, 2 if self.use_robot_state else 1),
|
||||
(batch_size, 2 if self.config.robot_state_feature else 1),
|
||||
False,
|
||||
device=batch["observation.state"].device,
|
||||
)
|
||||
@@ -463,16 +479,16 @@ class ACT(nn.Module):
|
||||
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
# Robot state token.
|
||||
if self.use_robot_state:
|
||||
if self.config.robot_state_feature:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
# Environment state token.
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
)
|
||||
|
||||
# Camera observation features and positional embeddings.
|
||||
if self.use_images:
|
||||
if self.config.image_features:
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
|
||||
@@ -521,9 +537,11 @@ class ACT(nn.Module):
|
||||
class ACTEncoder(nn.Module):
|
||||
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
||||
|
||||
def __init__(self, config: ACTConfig):
|
||||
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
|
||||
self.is_vae_encoder = is_vae_encoder
|
||||
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
|
||||
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
|
||||
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -16,9 +16,15 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("diffusion")
|
||||
@dataclass
|
||||
class DiffusionConfig:
|
||||
class DiffusionConfig(PreTrainedConfig):
|
||||
"""Configuration class for DiffusionPolicy.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
@@ -62,11 +68,12 @@ class DiffusionConfig:
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
@@ -92,7 +99,7 @@ class DiffusionConfig:
|
||||
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
||||
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
||||
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
|
||||
`LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
|
||||
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
|
||||
to False as the original Diffusion Policy implementation does the same.
|
||||
"""
|
||||
|
||||
@@ -101,26 +108,17 @@ class DiffusionConfig:
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -130,6 +128,7 @@ class DiffusionConfig:
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
use_separate_rgb_encoder_per_camera: bool = False
|
||||
# Unet.
|
||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||
kernel_size: int = 5
|
||||
@@ -152,39 +151,23 @@ class DiffusionConfig:
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-6
|
||||
scheduler_name: str = "cosine"
|
||||
scheduler_warmup_steps: int = 500
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
super().__post_init__()
|
||||
|
||||
# Input validation (not exhaustive).
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
|
||||
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if len(image_keys) > 0:
|
||||
if self.crop_shape is not None:
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
raise ValueError(
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
)
|
||||
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
raise ValueError(
|
||||
@@ -196,3 +179,59 @@ class DiffusionConfig:
|
||||
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
||||
# Check that the horizon size and U-Net downsampling is compatible.
|
||||
# U-Net downsamples by 2 with each stage.
|
||||
downsampling_factor = 2 ** len(self.down_dims)
|
||||
if self.horizon % downsampling_factor != 0:
|
||||
raise ValueError(
|
||||
"The horizon should be an integer multiple of the downsampling factor (which is determined "
|
||||
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||
return DiffuserSchedulerConfig(
|
||||
name=self.scheduler_name,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -31,35 +31,32 @@ import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
|
||||
|
||||
class DiffusionPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "diffusion-policy"],
|
||||
):
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
"""
|
||||
|
||||
config_class = DiffusionConfig
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig | None = None,
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -69,18 +66,16 @@ class DiffusionPolicy(
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = DiffusionConfig()
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
@@ -88,20 +83,20 @@ class DiffusionPolicy(
|
||||
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.diffusion.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad
|
||||
@@ -127,9 +122,11 @@ class DiffusionPolicy(
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -146,15 +143,18 @@ class DiffusionPolicy(
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
@@ -170,23 +170,25 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
|
||||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||
|
||||
|
||||
# TODO(Steven): Missing forward() implementation
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# Build observation encoders (depending on which observations are provided).
|
||||
global_cond_dim = config.input_shapes["observation.state"][0]
|
||||
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
self._use_images = False
|
||||
self._use_env_state = False
|
||||
if num_images > 0:
|
||||
self._use_images = True
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||
global_cond_dim = self.config.robot_state_feature.shape[0]
|
||||
if self.config.image_features:
|
||||
num_images = len(self.config.image_features)
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
|
||||
self.rgb_encoder = nn.ModuleList(encoders)
|
||||
global_cond_dim += encoders[0].feature_dim * num_images
|
||||
else:
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if self.config.env_state_feature:
|
||||
global_cond_dim += self.config.env_state_feature.shape[0]
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
@@ -202,6 +204,7 @@ class DiffusionModel(nn.Module):
|
||||
)
|
||||
|
||||
if config.num_inference_steps is None:
|
||||
# TODO(Steven): Consider type check?
|
||||
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
||||
else:
|
||||
self.num_inference_steps = config.num_inference_steps
|
||||
@@ -215,7 +218,7 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
|
||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
@@ -237,22 +240,38 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
global_cond_feats = [batch["observation.state"]]
|
||||
# Extract image feature (first combine batch, sequence, and camera index dims).
|
||||
if self._use_images:
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
|
||||
global_cond_feats = [batch[OBS_ROBOT]]
|
||||
# Extract image features.
|
||||
if self.config.image_features:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
img_features_list = torch.cat(
|
||||
[
|
||||
encoder(images)
|
||||
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
|
||||
]
|
||||
)
|
||||
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
if self._use_env_state:
|
||||
global_cond_feats.append(batch["observation.environment_state"])
|
||||
if self.config.env_state_feature:
|
||||
global_cond_feats.append(batch[OBS_ENV])
|
||||
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
@@ -316,7 +335,7 @@ class DiffusionModel(nn.Module):
|
||||
# Sample a random noising timestep for each item in the batch.
|
||||
timesteps = torch.randint(
|
||||
low=0,
|
||||
high=self.noise_scheduler.config.num_train_timesteps,
|
||||
high=self.noise_scheduler.config.num_train_timesteps, # TODO(Steven): Consider type check?
|
||||
size=(trajectory.shape[0],),
|
||||
device=trajectory.device,
|
||||
).long()
|
||||
@@ -422,7 +441,7 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
|
||||
class DiffusionRgbEncoder(nn.Module):
|
||||
"""Encoder an RGB image into a 1D feature vector.
|
||||
"""Encodes an RGB image into a 1D feature vector.
|
||||
|
||||
Includes the ability to normalize and crop the image first.
|
||||
"""
|
||||
@@ -461,19 +480,16 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
@@ -590,7 +606,7 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
|
||||
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||
# just reverse these.
|
||||
in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
|
||||
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
|
||||
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
|
||||
)
|
||||
|
||||
@@ -645,7 +661,7 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
|
||||
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
||||
|
||||
@@ -13,99 +13,138 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
|
||||
import logging
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from torch import nn
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
|
||||
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
||||
if not set(hydra_cfg.policy).issuperset(expected_kwargs):
|
||||
logging.warning(
|
||||
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||
)
|
||||
|
||||
# OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid
|
||||
# issues with mutable defaults. This filter changes all lists to tuples.
|
||||
def list_to_tuple(item):
|
||||
return tuple(item) if isinstance(item, list) else item
|
||||
|
||||
policy_cfg = policy_cfg_class(
|
||||
**{
|
||||
k: list_to_tuple(v)
|
||||
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
|
||||
if k in expected_kwargs
|
||||
}
|
||||
)
|
||||
return policy_cfg
|
||||
|
||||
|
||||
def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
return TDMPCPolicy, TDMPCConfig
|
||||
return TDMPCPolicy
|
||||
elif name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
return DiffusionPolicy, DiffusionConfig
|
||||
return DiffusionPolicy
|
||||
elif name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
return ACTPolicy, ACTConfig
|
||||
return ACTPolicy
|
||||
elif name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy, VQBeTConfig
|
||||
return VQBeTPolicy
|
||||
elif name == "pi0":
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
if policy_type == "tdmpc":
|
||||
return TDMPCConfig(**kwargs)
|
||||
elif policy_type == "diffusion":
|
||||
return DiffusionConfig(**kwargs)
|
||||
elif policy_type == "act":
|
||||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
def make_policy(
|
||||
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
|
||||
) -> Policy:
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
"""Make an instance of a policy class.
|
||||
|
||||
This function exists because (for now) we need to parse features from either a dataset or an environment
|
||||
in order to properly dimension and instantiate a policy for that dataset or environment.
|
||||
|
||||
Args:
|
||||
hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is
|
||||
provided, only `hydra_cfg.policy.name` is used while everything else is ignored.
|
||||
pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
|
||||
directory containing weights saved using `Policy.save_pretrained`. Note that providing this
|
||||
argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`.
|
||||
dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must
|
||||
be provided when initializing a new policy, and must not be provided when loading a pretrained
|
||||
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
|
||||
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
||||
be loaded with the weights from that path.
|
||||
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
||||
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
||||
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
||||
provided if ds_meta is not. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: Either ds_meta or env and env_cfg must be provided.
|
||||
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
|
||||
|
||||
Returns:
|
||||
PreTrainedPolicy: _description_
|
||||
"""
|
||||
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||
raise ValueError(
|
||||
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
|
||||
if bool(ds_meta) == bool(env_cfg):
|
||||
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
|
||||
|
||||
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
|
||||
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
|
||||
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
|
||||
# you want this op to be added in priority during the prototype phase of this feature, please comment on
|
||||
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
||||
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
||||
# slower than running natively on MPS.
|
||||
if cfg.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||
policy_cls = get_policy_class(cfg.type)
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
if pretrained_policy_name_or_path is None:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(policy_cfg, dataset_stats)
|
||||
kwargs = {}
|
||||
if ds_meta is not None:
|
||||
features = dataset_to_policy_features(ds_meta.features)
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
else:
|
||||
if not cfg.pretrained_path:
|
||||
logging.warning(
|
||||
"You are instantiating a policy from scratch and its features are parsed from an environment "
|
||||
"rather than a dataset. Normalization modules inside the policy will have infinite values "
|
||||
"by default without stats from a dataset."
|
||||
)
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
|
||||
# pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
|
||||
# huggingface_hub should make it possible to avoid the hack:
|
||||
# https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(**kwargs)
|
||||
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
policy.to(cfg.device)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
return policy
|
||||
|
||||
@@ -13,13 +13,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def create_stats_buffers(
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||
"""
|
||||
@@ -34,12 +37,16 @@ def create_stats_buffers(
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, mode in modes.items():
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
shape = tuple(shapes[key])
|
||||
assert isinstance(norm_mode, NormalizationMode)
|
||||
|
||||
if "image" in key:
|
||||
shape = tuple(ft.shape)
|
||||
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
@@ -52,7 +59,7 @@ def create_stats_buffers(
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if mode == "mean_std":
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
@@ -61,27 +68,39 @@ def create_stats_buffers(
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif mode == "min_max":
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
"max": nn.Parameter(max, requires_grad=False),
|
||||
"min": nn.Parameter(min_norm, requires_grad=False),
|
||||
"max": nn.Parameter(max_norm, requires_grad=False),
|
||||
}
|
||||
)
|
||||
|
||||
if stats is not None:
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if mode == "mean_std":
|
||||
buffer["mean"].data = stats[key]["mean"].clone()
|
||||
buffer["std"].data = stats[key]["std"].clone()
|
||||
elif mode == "min_max":
|
||||
buffer["min"].data = stats[key]["min"].clone()
|
||||
buffer["max"].data = stats[key]["max"].clone()
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
@@ -99,8 +118,8 @@ class Normalize(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -122,10 +141,10 @@ class Normalize(nn.Module):
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@@ -133,26 +152,34 @@ class Normalize(nn.Module):
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_norm = buffer["min"]
|
||||
max_norm = buffer["max"]
|
||||
assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
batch[key] = (batch[key] - min_norm) / (max_norm - min_norm + 1e-8)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
@@ -164,8 +191,8 @@ class Unnormalize(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -187,11 +214,11 @@ class Unnormalize(nn.Module):
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@@ -199,22 +226,29 @@ class Unnormalize(nn.Module):
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_norm = buffer["min"]
|
||||
max_norm = buffer["max"]
|
||||
assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
batch[key] = batch[key] * (max_norm - min_norm) + min_norm
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
149
lerobot/common/policies/pi0/configuration_pi0.py
Normal file
149
lerobot/common/policies/pi0/configuration_pi0.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@dataclass
|
||||
class PI0Config(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (224, 224)
|
||||
|
||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
|
||||
# Projector
|
||||
proj_width: int = 1024
|
||||
|
||||
# Decoding
|
||||
num_steps: int = 10
|
||||
|
||||
# Attention utils
|
||||
use_cache: bool = True
|
||||
attention_implementation: str = "eager" # or fa2, flex
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = True
|
||||
train_expert_only: bool = False
|
||||
train_state_proj: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 2.5e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
# TODO: Add EMA
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# TODO(Steven): Validate device and amp? in all policy configs?
|
||||
# Input validation (not exhaustive).
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
|
||||
if self.use_delta_joint_actions_aloha:
|
||||
raise NotImplementedError(
|
||||
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# TODO: implement value error
|
||||
# if not self.image_features and not self.env_state_feature:
|
||||
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
82
lerobot/common/policies/pi0/conversion_scripts/benchmark.py
Normal file
82
lerobot/common/policies/pi0/conversion_scripts/benchmark.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def main():
|
||||
device = "cuda"
|
||||
dataset_repo_id = "danaaubakirova/koch_test"
|
||||
# model_name = "pi0_base"
|
||||
# ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_torch_dir = "lerobot/pi0"
|
||||
|
||||
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
# To device
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, ds_meta=dataset.meta)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
warmup_iters = 10
|
||||
benchmark_iters = 30
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iters):
|
||||
torch.cuda.synchronize()
|
||||
policy.select_action(batch)
|
||||
policy.reset()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
for _ in range(benchmark_iters):
|
||||
policy.select_action(batch)
|
||||
policy.reset()
|
||||
end_event.record()
|
||||
|
||||
# Synchronize and measure time
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time_ms = start_event.elapsed_time(end_event)
|
||||
|
||||
avg_time_per_iter = elapsed_time_ms / benchmark_iters
|
||||
print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.inference_mode():
|
||||
main()
|
||||
@@ -0,0 +1,131 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
def display(tensor: torch.Tensor):
|
||||
if tensor.dtype == torch.bool:
|
||||
tensor = tensor.float()
|
||||
print(f"Shape: {tensor.shape}")
|
||||
print(f"Mean: {tensor.mean().item()}")
|
||||
print(f"Std: {tensor.std().item()}")
|
||||
print(f"Min: {tensor.min().item()}")
|
||||
print(f"Max: {tensor.max().item()}")
|
||||
|
||||
|
||||
def main():
|
||||
num_motors = 14
|
||||
device = "cuda"
|
||||
# model_name = "pi0_aloha_towel"
|
||||
model_name = "pi0_aloha_sim"
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
dataset_repo_id = "lerobot/aloha_static_towel"
|
||||
else:
|
||||
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
|
||||
|
||||
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
|
||||
save_dir = Path(f"../openpi/data/{model_name}/save")
|
||||
|
||||
with open(save_dir / "example.pkl", "rb") as f:
|
||||
example = pickle.load(f)
|
||||
with open(save_dir / "outputs.pkl", "rb") as f:
|
||||
outputs = pickle.load(f)
|
||||
with open(save_dir / "noise.pkl", "rb") as f:
|
||||
noise = pickle.load(f)
|
||||
|
||||
with open(ckpt_jax_dir / "assets/norm_stats.json", encoding="utf-8") as f:
|
||||
norm_stats = json.load(f)
|
||||
|
||||
# Override stats
|
||||
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
||||
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Create LeRobot batch from Jax
|
||||
batch = {}
|
||||
for cam_key, uint_chw_array in example["images"].items():
|
||||
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
||||
batch["observation.state"] = torch.from_numpy(example["state"])
|
||||
batch["action"] = torch.from_numpy(outputs["actions"])
|
||||
batch["task"] = example["prompt"]
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
del batch["observation.images.cam_low"]
|
||||
elif model_name == "pi0_aloha_sim":
|
||||
batch["observation.images.top"] = batch["observation.images.cam_high"]
|
||||
del batch["observation.images.cam_high"]
|
||||
|
||||
# Batchify
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].unsqueeze(0)
|
||||
elif isinstance(batch[key], str):
|
||||
batch[key] = [batch[key]]
|
||||
else:
|
||||
raise ValueError(f"{key}, {batch[key]}")
|
||||
|
||||
# To device
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
||||
|
||||
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
|
||||
|
||||
from lerobot.common import policies # noqa
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, dataset_meta)
|
||||
|
||||
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
||||
# loss_dict["loss"].backward()
|
||||
# print("losses")
|
||||
# display(loss_dict["losses_after_forward"])
|
||||
# print("pi_losses")
|
||||
# display(pi_losses)
|
||||
|
||||
actions = []
|
||||
for _ in range(50):
|
||||
action = policy.select_action(batch, noise=noise)
|
||||
actions.append(action)
|
||||
|
||||
actions = torch.stack(actions, dim=1)
|
||||
pi_actions = batch["action"]
|
||||
print("actions")
|
||||
display(actions)
|
||||
print()
|
||||
print("pi_actions")
|
||||
display(pi_actions)
|
||||
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
|
||||
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
|
||||
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,84 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import GemmaConfig, PaliGemmaConfig
|
||||
|
||||
|
||||
def get_paligemma_config(precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
}
|
||||
|
||||
# image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
|
||||
|
||||
image_size = 224 # image_sizes[variant]
|
||||
patch_size = 14
|
||||
num_image_tokens = (image_size**2) // (patch_size**2)
|
||||
|
||||
config["image_token_index"] = 257152
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"torch_dtype": precision,
|
||||
"hidden_size": 2048,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"num_attention_heads": 8,
|
||||
"intermediate_size": 16384,
|
||||
"is_encoder_decoder": False,
|
||||
}
|
||||
vision_config = {
|
||||
"torch_dtype": precision,
|
||||
"image_size": image_size,
|
||||
"patch_size": patch_size,
|
||||
"num_image_tokens": num_image_tokens,
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"num_hidden_layers": 27,
|
||||
"num_attention_heads": 16,
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"vision_use_head": False,
|
||||
}
|
||||
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
|
||||
return final_config
|
||||
|
||||
|
||||
def get_gemma_config(precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
}
|
||||
|
||||
config["image_token_index"] = 257152
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"torch_dtype": precision,
|
||||
"hidden_size": 1024,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"num_attention_heads": 8,
|
||||
"intermediate_size": 4096,
|
||||
"is_encoder_decoder": False,
|
||||
}
|
||||
final_config = GemmaConfig()
|
||||
final_config.update(text_config)
|
||||
return final_config
|
||||
@@ -0,0 +1,437 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Convert pi0 parameters from Jax to Pytorch
|
||||
|
||||
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
|
||||
and install the required libraries.
|
||||
|
||||
```bash
|
||||
cd ~/code/openpi
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
Example downloading parameters:
|
||||
```bash
|
||||
python
|
||||
>>> import openpi.shared.download as download
|
||||
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
|
||||
>>> download.maybe_download(path)
|
||||
```
|
||||
|
||||
Converting pi0_base:
|
||||
```python
|
||||
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
|
||||
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
|
||||
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
|
||||
```
|
||||
|
||||
```python
|
||||
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
|
||||
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
|
||||
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import orbax.checkpoint as ocp
|
||||
import torch
|
||||
from jax.sharding import SingleDeviceSharding
|
||||
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
|
||||
get_gemma_config,
|
||||
get_paligemma_config,
|
||||
)
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||
|
||||
|
||||
def slice_paligemma_state_dict(state_dict, config):
|
||||
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
||||
|
||||
# fmt: off
|
||||
# patch embeddings
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
|
||||
3, 2, 0, 1
|
||||
)
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
|
||||
# positional embeddings
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
|
||||
-1, config.vision_config.hidden_size
|
||||
)
|
||||
|
||||
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
|
||||
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
|
||||
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
|
||||
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
||||
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
||||
|
||||
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
||||
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
||||
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
||||
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
||||
|
||||
encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
|
||||
encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
|
||||
encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
|
||||
encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
|
||||
encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
|
||||
encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
|
||||
encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
|
||||
encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
|
||||
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
|
||||
state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
|
||||
state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
|
||||
|
||||
# multimodal projector
|
||||
|
||||
state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
|
||||
state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
|
||||
|
||||
# text decoder (gemma)
|
||||
embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
|
||||
state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
|
||||
|
||||
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
|
||||
# TODO verify correctness of layer norm loading
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
||||
|
||||
for i in range(config.text_config.num_hidden_layers):
|
||||
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
|
||||
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
||||
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
||||
|
||||
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
||||
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
||||
|
||||
# output projection.
|
||||
|
||||
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
|
||||
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
||||
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
||||
# mlp layers
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
||||
|
||||
state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
|
||||
state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
|
||||
|
||||
# fmt: on
|
||||
expert_dict = {}
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key not in [
|
||||
f"llm/final_norm_1/scale{suffix}",
|
||||
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/kv_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/q_einsum_1/w{suffix}",
|
||||
f"llm/layers/mlp_1/gating_einsum{suffix}",
|
||||
f"llm/layers/mlp_1/linear{suffix}",
|
||||
f"llm/layers/pre_attention_norm_1/scale{suffix}",
|
||||
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
|
||||
]:
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
expert_dict[key] = value
|
||||
|
||||
return final_state_dict, expert_dict
|
||||
|
||||
|
||||
def slice_gemma_state_dict(state_dict, config, num_expert=1):
|
||||
# fmt: off
|
||||
# text decoder (gemma)
|
||||
# no embedding vector, the expert just has the decoder layers
|
||||
|
||||
embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
|
||||
state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
|
||||
|
||||
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
||||
|
||||
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
|
||||
# TODO verify correctness of layer norm loading
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
||||
|
||||
for i in range(config.num_hidden_layers):
|
||||
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
||||
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
||||
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
||||
|
||||
# output projection.
|
||||
|
||||
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
|
||||
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
|
||||
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
||||
# mlp layers
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
||||
|
||||
state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
|
||||
state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
|
||||
|
||||
# fmt: on
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if not isinstance(value, torch.Tensor):
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
final_state_dict[key] = value
|
||||
return final_state_dict
|
||||
|
||||
|
||||
def flatten_for_memory(tree, parent_key=""):
|
||||
out = {}
|
||||
for k, v in tree.items():
|
||||
new_key = f"{parent_key}/{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
out.update(flatten_for_memory(v, new_key))
|
||||
else:
|
||||
out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
|
||||
return out
|
||||
|
||||
|
||||
def flatten_for_npz(tree, parent_key=""):
|
||||
out = {}
|
||||
for k, v in tree.items():
|
||||
new_key = f"{parent_key}/{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
out.update(flatten_for_npz(v, new_key))
|
||||
else:
|
||||
# bf16/f32 here?
|
||||
out[new_key] = np.array(v)
|
||||
return out
|
||||
|
||||
|
||||
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
|
||||
params_path = pathlib.Path(checkpoint_dir).resolve()
|
||||
checkpointer = ocp.PyTreeCheckpointer()
|
||||
|
||||
metadata = checkpointer.metadata(params_path)
|
||||
print("Metadata keys:", list(metadata.keys()))
|
||||
|
||||
params_name = "params"
|
||||
|
||||
item = {params_name: metadata[params_name]}
|
||||
device = jax.local_devices()[0] # Use the first local device
|
||||
sharding = SingleDeviceSharding(device)
|
||||
restored = checkpointer.restore(
|
||||
params_path,
|
||||
ocp.args.PyTreeRestore(
|
||||
item=item,
|
||||
restore_args=jax.tree_util.tree_map(
|
||||
lambda _: ocp.ArrayRestoreArgs(
|
||||
restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
|
||||
sharding=sharding,
|
||||
),
|
||||
item,
|
||||
),
|
||||
transforms={},
|
||||
),
|
||||
)
|
||||
params = restored[params_name]
|
||||
|
||||
# get params for PaliGemma
|
||||
pali_params = params["PaliGemma"]
|
||||
del params["PaliGemma"]
|
||||
pali_params_flat = flatten_for_npz(pali_params)
|
||||
return {"paligemma_params": pali_params_flat, "projection_params": params}
|
||||
|
||||
|
||||
def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
||||
"""Update dictionary keys by adding a prefix."""
|
||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||
|
||||
|
||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, _tokenizer_id: str, output_path: str):
|
||||
# Break down orbax ckpts - they are in OCDBT
|
||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||
# process projection params
|
||||
keys = [
|
||||
"state_proj",
|
||||
"action_in_proj",
|
||||
"action_out_proj",
|
||||
"action_time_mlp_in",
|
||||
"action_time_mlp_out",
|
||||
]
|
||||
|
||||
projection_params = {}
|
||||
for key in keys:
|
||||
kernel_params = initial_params["projection_params"][key]["kernel"]
|
||||
bias_params = initial_params["projection_params"][key]["bias"]
|
||||
if isinstance(kernel_params, dict):
|
||||
weight = kernel_params["value"]
|
||||
bias = bias_params["value"]
|
||||
else:
|
||||
weight = kernel_params
|
||||
bias = bias_params
|
||||
projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
|
||||
projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
|
||||
|
||||
# Process PaliGemma weights
|
||||
paligemma_config = get_paligemma_config(precision)
|
||||
paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
|
||||
initial_params["paligemma_params"], paligemma_config
|
||||
)
|
||||
|
||||
# Process Gemma weights (at this stage they are unused)
|
||||
gemma_config = get_gemma_config(precision)
|
||||
gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
|
||||
|
||||
# Instantiate model from configs
|
||||
|
||||
if "pi0_aloha_sim" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
empty_cameras=2,
|
||||
adapt_to_pi_aloha=True,
|
||||
use_delta_joint_actions_aloha=False,
|
||||
)
|
||||
elif "pi0_aloha_towel" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
adapt_to_pi_aloha=True,
|
||||
use_delta_joint_actions_aloha=True,
|
||||
)
|
||||
elif "pi0_base" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
empty_cameras=0,
|
||||
adapt_to_pi_aloha=False,
|
||||
use_delta_joint_actions_aloha=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
# gemma_config=gemma_config, paligemma_config=paligemma_config)
|
||||
pi0_model = PI0Policy(pi0_config)
|
||||
|
||||
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
|
||||
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
|
||||
projection_params = update_keys_with_prefix(projection_params, "model.")
|
||||
|
||||
# load state dict
|
||||
torch_dtype = PRECISIONS[precision]
|
||||
pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
|
||||
pi0_model = pi0_model.to(torch_dtype)
|
||||
# pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
||||
|
||||
pi0_model.save_pretrained(output_path, safe_serialization=True)
|
||||
# pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
|
||||
|
||||
# assert that model loads properly
|
||||
del pi0_model
|
||||
PI0Policy.from_pretrained(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--checkpoint_dir",
|
||||
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
|
||||
type=str,
|
||||
help="Path to the ocdbt checkpoint",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
default="float32",
|
||||
type=str,
|
||||
help="Precision identifier for model conversion - should match the base checkpoint precision.",
|
||||
)
|
||||
# tokenizer is identical to paligemma, it appears
|
||||
|
||||
parser.add_argument(
|
||||
"--tokenizer_hub_id",
|
||||
default="google/paligemma-3b-pt-224",
|
||||
type=str,
|
||||
help="Hub path to the tokenizer to save",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to save converted weights to",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_pi0_checkpoint(
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
precision=args.precision,
|
||||
_tokenizer_id=args.tokenizer_hub_id,
|
||||
output_path=args.output_path,
|
||||
)
|
||||
142
lerobot/common/policies/pi0/flex_attention.py
Normal file
142
lerobot/common/policies/pi0/flex_attention.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from packaging.version import Version
|
||||
|
||||
# TODO(Steven): Consider settings this a dependency constraint
|
||||
if Version(torch.__version__) > Version("2.5.0"):
|
||||
# Ffex attention is only available from torch 2.5 onwards
|
||||
from torch.nn.attention.flex_attention import (
|
||||
_mask_mod_signature,
|
||||
_round_up_to_multiple,
|
||||
create_block_mask,
|
||||
create_mask,
|
||||
flex_attention,
|
||||
)
|
||||
|
||||
|
||||
# @torch.compile(dynamic=False)
|
||||
def flex_attention_forward(
|
||||
attention_mask: torch.Tensor,
|
||||
batch_size: int,
|
||||
head_dim: int,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
scaling=None,
|
||||
):
|
||||
"""
|
||||
This is defined out of classes to make compile happy.
|
||||
"""
|
||||
|
||||
original_dtype = query_states.dtype
|
||||
num_att_heads = 8
|
||||
num_key_value_heads = 1
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
key_states = key_states[:, :, :, None, :]
|
||||
key_states = key_states.expand(
|
||||
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :]
|
||||
value_states = value_states.expand(
|
||||
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
query_states = query_states.to(torch.float32)
|
||||
key_states = key_states.to(torch.float32)
|
||||
value_states = value_states.to(torch.float32)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if causal_mask is not None:
|
||||
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
|
||||
|
||||
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
|
||||
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
|
||||
|
||||
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
|
||||
def mask_mod(b, h, q_idx, kv_idx):
|
||||
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
|
||||
return precomputed_mask[b][h][q_idx][kv_idx]
|
||||
|
||||
return mask_mod
|
||||
|
||||
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
|
||||
|
||||
block_size = 128
|
||||
q_len_rounded = _round_up_to_multiple(q_len, block_size)
|
||||
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
|
||||
|
||||
# *CRITICAL* we do need to expand here, else we get a CUDA index error
|
||||
|
||||
pad_q = q_len_rounded - q_len
|
||||
pad_k = kv_len_rounded - kv_len
|
||||
|
||||
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
|
||||
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
|
||||
|
||||
mask_4d = create_mask(
|
||||
mod_fn=mask_mod_fn_orig,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
|
||||
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
|
||||
block_mask = create_block_mask(
|
||||
mask_mod=mask_mod_fn_padded,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
BLOCK_SIZE=block_size,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
|
||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||
attn_output, _attention_weights = flex_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
block_mask=block_mask,
|
||||
enable_gqa=True, # because we shaped query/key states for GQA
|
||||
scale=head_dim**-0.5 if scaling is None else scaling,
|
||||
return_lse=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.to(dtype=original_dtype)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
|
||||
attn_output = attn_output.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
|
||||
)
|
||||
return attn_output
|
||||
732
lerobot/common/policies/pi0/modeling_pi0.py
Normal file
732
lerobot/common/policies/pi0/modeling_pi0.py
Normal file
@@ -0,0 +1,732 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||||
|
||||
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
|
||||
Install pi0 extra dependencies:
|
||||
```bash
|
||||
pip install -e ".[pi0]"
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
|
||||
pretrained with VLM default parameters before pi0 finetuning:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of using the pi0 pretrained model outside LeRobot training framework:
|
||||
```python
|
||||
policy = Pi0Policy.from_pretrained("lerobot/pi0")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertConfig,
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
||||
return gamma1 / (gamma1 + gamma2)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
return att_2d_masks
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=-1):
|
||||
# assume no-op when width height fits already
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
cur_height, cur_width = img.shape[2:]
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, int(height - resized_height))
|
||||
pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
||||
new_vector[..., :current_dim] = vector
|
||||
return new_vector
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
|
||||
# This ensures that the input stays within
|
||||
# [−1,1] to avoid invalid values for arcsin
|
||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with pi0 which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class PI0Policy(PreTrainedPolicy):
|
||||
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = PI0Config
|
||||
name = "pi0"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FlowMatching(config)
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
)
|
||||
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_is_pad")
|
||||
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
# For backward pass
|
||||
loss = losses.mean()
|
||||
# For logging
|
||||
loss_dict["l2_loss"] = loss.item()
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
||||
"""
|
||||
images = []
|
||||
img_masks = []
|
||||
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
)
|
||||
|
||||
# Preprocess image features present in the batch
|
||||
for key in present_img_keys:
|
||||
img = batch[key]
|
||||
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
# Create image features not present in the batch
|
||||
# as fully 0 padded images.
|
||||
for num_empty_cameras in range(len(missing_img_keys)):
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
break
|
||||
img = torch.ones_like(img) * -1
|
||||
mask = torch.zeros_like(mask)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_ROBOT].device
|
||||
tasks = batch["task"]
|
||||
|
||||
# PaliGemma prompt has to end with a new line
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
state[:, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
return state
|
||||
|
||||
def _pi_aloha_encode_actions(self, actions):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# Flip the joints again.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def prepare_state(self, batch):
|
||||
"""Pad state"""
|
||||
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def prepare_action(self, batch):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
|
||||
class PI0FlowMatching(nn.Module):
|
||||
"""
|
||||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||||
|
||||
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
┌──────────────────────────────┐
|
||||
│ actions │
|
||||
│ ▲ │
|
||||
│ ┌┴─────┐ │
|
||||
│ kv cache │Gemma │ │
|
||||
│ ┌──────────►│Expert│ │
|
||||
│ │ │ │ │
|
||||
│ ┌┴────────┐ │x 10 │ │
|
||||
│ │ │ └▲──▲──┘ │
|
||||
│ │PaliGemma│ │ │ │
|
||||
│ │ │ │ robot state │
|
||||
│ │ │ noise │
|
||||
│ └▲──▲─────┘ │
|
||||
│ │ │ │
|
||||
│ │ image(s) │
|
||||
│ language tokens │
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
paligemma_with_export_config = PaliGemmaWithExpertConfig(
|
||||
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
attention_implementation=self.config.attention_implementation,
|
||||
)
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
|
||||
|
||||
# Projections are float32
|
||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
|
||||
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
|
||||
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
|
||||
|
||||
self.set_requires_grad()
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for PaliGemma transformer processing.
|
||||
"""
|
||||
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# TODO: remove for loop
|
||||
for (
|
||||
img,
|
||||
img_mask,
|
||||
) in zip(images, img_masks, strict=False):
|
||||
img_emb = self.paligemma_with_expert.embed_image(img)
|
||||
img_emb = img_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask)
|
||||
|
||||
# Create attention masks so that image tokens attend to each other
|
||||
att_masks += [0] * num_img_embs
|
||||
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
|
||||
# Normalize language embeddings
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
# full attention between image and language inputs
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, state, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Embed state
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb.to(dtype=torch.bfloat16)
|
||||
embs.append(state_emb[:, None, :])
|
||||
bsize = state_emb.shape[0]
|
||||
dtype = state_emb.dtype
|
||||
device = state_emb.device
|
||||
|
||||
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1]
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
action_emb = self.action_in_proj(noisy_actions)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
action_time_emb = F.silu(action_time_emb) # swish == silu
|
||||
action_time_emb = self.action_time_mlp_out(action_time_emb)
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
||||
# Original openpi code, upcast attention output
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
# Compute image and language key value cache
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
return v_t
|
||||
417
lerobot/common/policies/pi0/paligemma_with_expert.py
Normal file
417
lerobot/common/policies/pi0/paligemma_with_expert.py
Normal file
@@ -0,0 +1,417 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.version
|
||||
from pytest import Cache
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
GemmaForCausalLM,
|
||||
PaliGemmaForConditionalGeneration,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.common.policies.pi0.flex_attention import flex_attention_forward
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
"""
|
||||
Applies RoPE positions [B, L] to x [B, L, H, D].
|
||||
"""
|
||||
d_half = x.shape[-1] // 2
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
sin = torch.sin(radians) # .to(dtype=dtype)
|
||||
cos = torch.cos(radians) # .to(dtype=dtype)
|
||||
|
||||
x1, x2 = x.split(d_half, dim=-1)
|
||||
res = torch.empty_like(x)
|
||||
res[..., :d_half] = x1 * cos - x2 * sin
|
||||
res[..., d_half:] = x2 * cos + x1 * sin
|
||||
|
||||
return res.to(dtype)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertConfig(PretrainedConfig):
|
||||
model_type = "PaliGemmaWithExpertModel"
|
||||
sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
paligemma_config: dict | None = None,
|
||||
gemma_expert_config: dict | None = None,
|
||||
freeze_vision_encoder: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
attention_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_implementation = attention_implementation
|
||||
|
||||
if paligemma_config is None:
|
||||
# Default config from Pi0
|
||||
self.paligemma_config = CONFIG_MAPPING["paligemma"](
|
||||
transformers_version="4.48.1",
|
||||
_vocab_size=257152,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
hidden_size=2048,
|
||||
image_token_index=257152,
|
||||
model_type="paligemma",
|
||||
pad_token_id=0,
|
||||
projection_dim=2048,
|
||||
text_config={
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 16384,
|
||||
"model_type": "gemma",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 18,
|
||||
"num_image_tokens": 256,
|
||||
"num_key_value_heads": 1,
|
||||
"torch_dtype": "float32",
|
||||
"vocab_size": 257152,
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"num_image_tokens": 256,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 2048,
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"torch_dtype": "float32",
|
||||
"vision_use_head": False,
|
||||
},
|
||||
)
|
||||
elif isinstance(self.paligemma_config, dict):
|
||||
# Override Pi0 default config for PaliGemma
|
||||
if "model_type" not in gemma_expert_config:
|
||||
paligemma_config["model_type"] = "paligemma"
|
||||
|
||||
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
self.paligemma_config = cfg_cls(**paligemma_config)
|
||||
|
||||
if gemma_expert_config is None:
|
||||
# Default config from Pi0
|
||||
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
head_dim=256,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
hidden_size=1024,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=4096,
|
||||
max_position_embeddings=8192,
|
||||
model_type="gemma",
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=18,
|
||||
num_key_value_heads=1,
|
||||
pad_token_id=0,
|
||||
rms_norm_eps=1e-06,
|
||||
rope_theta=10000.0,
|
||||
torch_dtype="float32",
|
||||
transformers_version="4.48.1",
|
||||
use_cache=True,
|
||||
vocab_size=257152,
|
||||
)
|
||||
elif isinstance(self.gemma_expert_config, dict):
|
||||
# Override Pi0 default config for Gemma Expert
|
||||
if "model_type" not in gemma_expert_config:
|
||||
gemma_expert_config["model_type"] = "gemma"
|
||||
|
||||
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.train_expert_only and not self.freeze_vision_encoder:
|
||||
raise ValueError(
|
||||
"You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
|
||||
)
|
||||
|
||||
if self.attention_implementation not in ["eager", "fa2", "flex"]:
|
||||
raise ValueError(
|
||||
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
config_class = PaliGemmaWithExpertConfig
|
||||
|
||||
def __init__(self, config: PaliGemmaWithExpertConfig):
|
||||
super().__init__(config=config)
|
||||
self.config = config
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
|
||||
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
|
||||
# Remove unused embed_tokens
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_like_physical_intelligence()
|
||||
self.set_requires_grad()
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
for params in self.paligemma.vision_tower.parameters():
|
||||
params.requires_grad = False
|
||||
|
||||
if self.config.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
for params in self.paligemma.parameters():
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
|
||||
if self.config.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def to_bfloat16_like_physical_intelligence(self):
|
||||
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
|
||||
|
||||
params_to_change_dtype = [
|
||||
"language_model.model.layers",
|
||||
"gemma_expert.model.layers",
|
||||
"vision_tower",
|
||||
"multi_modal",
|
||||
]
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch.bfloat16)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.paligemma.get_image_features(image)
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.model.embed_tokens(tokens)
|
||||
|
||||
# TODO: break down this huge forward into modules or functions
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||
inputs_embeds: List[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
fill_kv_cache: Optional[bool] = None,
|
||||
):
|
||||
models = [self.paligemma.language_model.model, self.gemma_expert.model]
|
||||
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
# dtype is always the same, batch size too (if > 1 len)
|
||||
# device could be trickier in multi gpu edge cases but that's it
|
||||
if hidden_states is None:
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
head_dim = self.paligemma.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is None:
|
||||
continue
|
||||
layer = models[i].layers[layer_idx]
|
||||
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.bfloat16)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# concatenate on the number of embeddings/tokens
|
||||
query_states = torch.cat(query_states, dim=1)
|
||||
key_states = torch.cat(key_states, dim=1)
|
||||
value_states = torch.cat(value_states, dim=1)
|
||||
|
||||
query_states = apply_rope(query_states, position_ids)
|
||||
key_states = apply_rope(key_states, position_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat(
|
||||
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||
)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
att_output = attention_interface(
|
||||
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_output = att_output.to(dtype=torch.bfloat16)
|
||||
|
||||
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
||||
outputs_embeds = []
|
||||
start = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
|
||||
if hidden_states is not None:
|
||||
end = start + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start:end])
|
||||
|
||||
# TODO: first dropout (by default 0.0)
|
||||
|
||||
# first residual
|
||||
out_emb += hidden_states
|
||||
after_first_residual = out_emb.clone()
|
||||
|
||||
out_emb = layer.post_attention_layernorm(out_emb)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
|
||||
# TODO: second dropout (by default 0.0)
|
||||
|
||||
# second residual
|
||||
out_emb += after_first_residual
|
||||
|
||||
outputs_embeds.append(out_emb)
|
||||
|
||||
start = end
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
inputs_embeds = outputs_embeds
|
||||
|
||||
# final norm
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is not None:
|
||||
out_emb = models[i].norm(hidden_states)
|
||||
outputs_embeds.append(out_emb)
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
|
||||
if self.config.attention_implementation == "fa2":
|
||||
attention_interface = self.flash_attention_forward
|
||||
elif self.config.attention_implementation == "flex":
|
||||
attention_interface = flex_attention_forward
|
||||
else:
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
|
||||
def flash_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
raise NotImplementedError("FA2 is not implemented (yet)")
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
|
||||
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
# query_states: batch_size, sequence_length, num_att_head, head_dim
|
||||
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
|
||||
query_states = query_states.to(dtype=torch.float32)
|
||||
key_states = key_states.to(dtype=torch.float32)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
att_weights *= head_dim**-0.5
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
# probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
|
||||
# value_states: batch_size, sequence_length, num_att_heads, head_dim
|
||||
|
||||
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
|
||||
@@ -1,75 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A protocol that all policies should follow.
|
||||
|
||||
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
|
||||
subclass a base class.
|
||||
|
||||
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
|
||||
how to implement new policies.
|
||||
"""
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Policy(Protocol):
|
||||
"""The required interface for implementing a policy.
|
||||
|
||||
We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin.
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization.
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
"""To be called whenever the environment is reset.
|
||||
|
||||
Does things like clearing caches.
|
||||
"""
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
|
||||
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class PolicyWithUpdate(Policy, Protocol):
|
||||
def update(self):
|
||||
"""An update method that is to be called after a training optimization step.
|
||||
|
||||
Implements an additional updates the model parameters may need (for example, doing an EMA step for a
|
||||
target model, or incrementing an internal buffer).
|
||||
"""
|
||||
199
lerobot/common/policies/pretrained.py
Normal file
199
lerobot/common/policies/pretrained.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor
|
||||
from safetensors.torch import save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
|
||||
DEFAULT_POLICY_CARD = """
|
||||
---
|
||||
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
||||
# Doc / guide: https://huggingface.co/docs/hub/model-cards
|
||||
{{ card_data }}
|
||||
---
|
||||
|
||||
This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot):
|
||||
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
|
||||
"""
|
||||
|
||||
|
||||
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
"""
|
||||
Base class for policy models.
|
||||
"""
|
||||
|
||||
config_class: None
|
||||
name: None
|
||||
|
||||
def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, PreTrainedConfig):
|
||||
raise ValueError(
|
||||
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
|
||||
"`PreTrainedConfig`. To create a model from a pretrained model use "
|
||||
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not getattr(cls, "config_class", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: Type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""
|
||||
The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are
|
||||
deactivated). To train it, you should first set it back in training mode with `policy.train()`.
|
||||
"""
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
model_id = str(pretrained_name_or_path)
|
||||
instance = cls(config, **kwargs)
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
policy.to(config.device)
|
||||
policy.eval()
|
||||
return policy
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
|
||||
load_model_as_safetensor(model, model_file, strict=strict)
|
||||
if map_location != "cpu":
|
||||
logging.warning(
|
||||
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
|
||||
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
||||
" This leads to a slower loading time."
|
||||
" Please update safetensors to version 0.4.3 or above for improved performance."
|
||||
)
|
||||
model.to(map_location)
|
||||
else:
|
||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
return model
|
||||
|
||||
# def generate_model_card(self, *args, **kwargs) -> ModelCard:
|
||||
# card = ModelCard.from_template(
|
||||
# card_data=self._hub_mixin_info.model_card_data,
|
||||
# template_str=self._hub_mixin_info.model_card_template,
|
||||
# repo_url=self._hub_mixin_info.repo_url,
|
||||
# docs_url=self._hub_mixin_info.docs_url,
|
||||
# **kwargs,
|
||||
# )
|
||||
# return card
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optim_params(self) -> dict:
|
||||
"""
|
||||
Returns the policy-specific parameters dict to be passed on to the optimizer.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self):
|
||||
"""To be called whenever the environment is reset.
|
||||
|
||||
Does things like clearing caches.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
|
||||
@abc.abstractmethod
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
batch (dict[str, Tensor]): _description_
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
|
||||
is a Tensor, all other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -16,9 +16,14 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("tdmpc")
|
||||
@dataclass
|
||||
class TDMPCConfig:
|
||||
class TDMPCConfig(PreTrainedConfig):
|
||||
"""Configuration class for TDMPCPolicy.
|
||||
|
||||
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
|
||||
@@ -71,7 +76,7 @@ class TDMPCConfig:
|
||||
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
|
||||
be zero.
|
||||
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
|
||||
trajectory values (this is the λ coeffiecient in eqn 4 of FOWM).
|
||||
trajectory values (this is the λ coefficient in eqn 4 of FOWM).
|
||||
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
|
||||
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
||||
elites, when updating the gaussian parameters for CEM.
|
||||
@@ -102,27 +107,19 @@ class TDMPCConfig:
|
||||
"""
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
n_action_repeats: int = 2
|
||||
horizon: int = 5
|
||||
n_action_steps: int = 1
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ENV": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [4],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] | None = None
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"},
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
@@ -159,32 +156,27 @@ class TDMPCConfig:
|
||||
# Target model.
|
||||
target_model_momentum: float = 0.995
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 3e-4
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) > 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
if len(image_keys) > 0:
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
super().__post_init__()
|
||||
|
||||
# Input validation (not exhaustive).
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
)
|
||||
if self.output_normalization_modes != {"action": "min_max"}:
|
||||
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
|
||||
raise ValueError(
|
||||
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
|
||||
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
|
||||
"information."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if self.n_action_steps > 1:
|
||||
if self.n_action_repeats != 1:
|
||||
raise ValueError(
|
||||
@@ -194,3 +186,35 @@ class TDMPCConfig:
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(lr=self.optimizer_lr)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# There should only be one image key.
|
||||
if len(self.image_features) > 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {self.image_features}."
|
||||
)
|
||||
|
||||
if len(self.image_features) > 0:
|
||||
image_ft = next(iter(self.image_features.values()))
|
||||
if image_ft.shape[-2] != image_ft.shape[-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(self.horizon + 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return list(range(self.horizon))
|
||||
|
||||
@@ -33,21 +33,16 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
|
||||
|
||||
class TDMPCPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "tdmpc"],
|
||||
):
|
||||
class TDMPCPolicy(PreTrainedPolicy):
|
||||
"""Implementation of TD-MPC learning + inference.
|
||||
|
||||
Please note several warnings for this policy.
|
||||
@@ -65,11 +60,10 @@ class TDMPCPolicy(
|
||||
match our xarm environment.
|
||||
"""
|
||||
|
||||
config_class = TDMPCConfig
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(
|
||||
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
@@ -77,42 +71,31 @@ class TDMPCPolicy(
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = TDMPCConfig()
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = TDMPCTOLD(config)
|
||||
self.model_target = deepcopy(self.model)
|
||||
for param in self.model_target.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
self._use_image = False
|
||||
self._use_env_state = False
|
||||
if len(image_keys) > 0:
|
||||
assert len(image_keys) == 1
|
||||
self._use_image = True
|
||||
self.input_image_key = image_keys[0]
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
self._queues = None
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
|
||||
@@ -122,21 +105,21 @@ class TDMPCPolicy(
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
}
|
||||
if self._use_image:
|
||||
if self.config.image_features:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
self._prev_mean = None
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch["observation.image"] = batch[next(iter(self.config.image_features))]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -151,9 +134,9 @@ class TDMPCPolicy(
|
||||
|
||||
# NOTE: Order of observations matters here.
|
||||
encode_keys = []
|
||||
if self._use_image:
|
||||
if self.config.image_features:
|
||||
encode_keys.append("observation.image")
|
||||
if self._use_env_state:
|
||||
if self.config.env_state_feature:
|
||||
encode_keys.append("observation.environment_state")
|
||||
encode_keys.append("observation.state")
|
||||
z = self.model.encode({k: batch[k] for k in encode_keys})
|
||||
@@ -196,7 +179,7 @@ class TDMPCPolicy(
|
||||
self.config.horizon,
|
||||
self.config.n_pi_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
self.config.action_feature.shape[0],
|
||||
device=device,
|
||||
)
|
||||
if self.config.n_pi_samples > 0:
|
||||
@@ -215,7 +198,7 @@ class TDMPCPolicy(
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
|
||||
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
@@ -228,7 +211,7 @@ class TDMPCPolicy(
|
||||
self.config.horizon,
|
||||
self.config.n_gaussian_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
self.config.action_feature.shape[0],
|
||||
device=std.device,
|
||||
)
|
||||
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
|
||||
@@ -322,7 +305,7 @@ class TDMPCPolicy(
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
@@ -330,16 +313,16 @@ class TDMPCPolicy(
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch["observation.image"] = batch[next(iter(self.config.image_features))]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"] # (t, b, action_dim)
|
||||
@@ -347,7 +330,7 @@ class TDMPCPolicy(
|
||||
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
|
||||
# Apply random image augmentations.
|
||||
if self._use_image and self.config.max_random_shift_ratio > 0:
|
||||
if self.config.image_features and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
@@ -360,7 +343,7 @@ class TDMPCPolicy(
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon, batch_size = next_observations[
|
||||
"observation.image" if self._use_image else "observation.environment_state"
|
||||
"observation.image" if self.config.image_features else "observation.environment_state"
|
||||
].shape[:2]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
@@ -515,17 +498,16 @@ class TDMPCPolicy(
|
||||
"Q_value_loss": q_value_loss.item(),
|
||||
"V_value_loss": v_value_loss.item(),
|
||||
"pi_loss": pi_loss.item(),
|
||||
"loss": loss,
|
||||
"sum_loss": loss.item() * self.config.horizon,
|
||||
}
|
||||
)
|
||||
|
||||
# Undo (b, t) -> (t, b).
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
return info
|
||||
return loss, info
|
||||
|
||||
def update(self):
|
||||
"""Update the target model's parameters with an EMA step."""
|
||||
@@ -535,6 +517,7 @@ class TDMPCPolicy(
|
||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||
|
||||
|
||||
# TODO(Steven): forward implementation missing
|
||||
class TDMPCTOLD(nn.Module):
|
||||
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
||||
|
||||
@@ -543,7 +526,7 @@ class TDMPCTOLD(nn.Module):
|
||||
self.config = config
|
||||
self._encoder = TDMPCObservationEncoder(config)
|
||||
self._dynamics = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -554,7 +537,7 @@ class TDMPCTOLD(nn.Module):
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self._reward = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -569,12 +552,12 @@ class TDMPCTOLD(nn.Module):
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.output_shapes["action"][0]),
|
||||
nn.Linear(config.mlp_dim, config.action_feature.shape[0]),
|
||||
)
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -615,9 +598,9 @@ class TDMPCTOLD(nn.Module):
|
||||
|
||||
self.apply(_apply_fn)
|
||||
for m in [self._reward, *self._Qs]:
|
||||
assert isinstance(
|
||||
m[-1], nn.Linear
|
||||
), "Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
assert isinstance(m[-1], nn.Linear), (
|
||||
"Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
)
|
||||
nn.init.zeros_(m[-1].weight)
|
||||
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
|
||||
|
||||
@@ -714,10 +697,13 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if "observation.image" in config.input_shapes:
|
||||
if config.image_features:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
|
||||
next(iter(config.image_features.values())).shape[0],
|
||||
config.image_encoder_hidden_dim,
|
||||
7,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
@@ -727,9 +713,8 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
with torch.inference_mode():
|
||||
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
|
||||
out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
@@ -738,19 +723,19 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
)
|
||||
if "observation.state" in config.input_shapes:
|
||||
|
||||
if config.robot_state_feature:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
|
||||
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
|
||||
if config.env_state_feature:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -765,12 +750,16 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
"""
|
||||
feat = []
|
||||
# NOTE: Order of observations matters here.
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
if self.config.image_features:
|
||||
feat.append(
|
||||
flatten_forward_unflatten(
|
||||
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
|
||||
)
|
||||
)
|
||||
if self.config.env_state_feature:
|
||||
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
|
||||
if self.config.robot_state_feature:
|
||||
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -47,3 +48,20 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
||||
Note: assumes that all parameters have the same dtype.
|
||||
"""
|
||||
return next(iter(module.parameters())).dtype
|
||||
|
||||
|
||||
def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
|
||||
"""
|
||||
Calculates the output shape of a PyTorch module given an input shape.
|
||||
|
||||
Args:
|
||||
module (nn.Module): a PyTorch module
|
||||
input_shape (tuple): A tuple representing the input shape, e.g., (batch_size, channels, height, width)
|
||||
|
||||
Returns:
|
||||
tuple: The output shape of the module.
|
||||
"""
|
||||
dummy_input = torch.zeros(size=input_shape)
|
||||
with torch.inference_mode():
|
||||
output = module(dummy_input)
|
||||
return tuple(output.shape)
|
||||
|
||||
@@ -18,9 +18,15 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vqbet")
|
||||
@dataclass
|
||||
class VQBeTConfig:
|
||||
class VQBeTConfig(PreTrainedConfig):
|
||||
"""Configuration class for VQ-BeT.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
@@ -60,7 +66,7 @@ class VQBeTConfig:
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
@@ -90,26 +96,13 @@ class VQBeTConfig:
|
||||
n_action_pred_token: int = 3
|
||||
action_chunk_size: int = 5
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -139,29 +132,69 @@ class VQBeTConfig:
|
||||
bet_softmax_temperature: float = 0.1
|
||||
sequentially_select: bool = False
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-6
|
||||
optimizer_vqvae_lr: float = 1e-3
|
||||
optimizer_vqvae_weight_decay: float = 1e-4
|
||||
scheduler_warmup_steps: int = 500
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
super().__post_init__()
|
||||
|
||||
# Input validation (not exhaustive).
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> VQBeTSchedulerConfig:
|
||||
return VQBeTSchedulerConfig(
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_vqvae_training_steps=self.n_vqvae_training_steps,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# Note: this check was previously performed inside VQBeTRgbEncoder in the form of
|
||||
# assert len(image_keys) == 1
|
||||
if not len(self.image_features) == 1:
|
||||
raise ValueError("You must provide only one image among the inputs.")
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from collections import deque
|
||||
from typing import Callable, List
|
||||
@@ -26,29 +25,23 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
|
||||
class VQBeTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vqbet"],
|
||||
):
|
||||
class VQBeTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
|
||||
"""
|
||||
|
||||
config_class = VQBeTConfig
|
||||
name = "vqbet"
|
||||
|
||||
def __init__(
|
||||
@@ -63,26 +56,64 @@ class VQBeTPolicy(
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = VQBeTConfig()
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self._queues = None
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
vqvae_params = (
|
||||
list(self.vqbet.action_head.vqvae_model.encoder.parameters())
|
||||
+ list(self.vqbet.action_head.vqvae_model.decoder.parameters())
|
||||
+ list(self.vqbet.action_head.vqvae_model.vq_layer.parameters())
|
||||
)
|
||||
decay_params, no_decay_params = self.vqbet.policy.configure_parameters()
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(self.vqbet.rgb_encoder.parameters())
|
||||
+ list(self.vqbet.state_projector.parameters())
|
||||
+ list(self.vqbet.rgb_feature_projector.parameters())
|
||||
+ [self.vqbet.action_token]
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
||||
)
|
||||
|
||||
if self.config.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
|
||||
return [
|
||||
{
|
||||
"params": decay_params,
|
||||
},
|
||||
{
|
||||
"params": vqvae_params,
|
||||
"weight_decay": self.config.optimizer_vqvae_weight_decay,
|
||||
"lr": self.config.optimizer_vqvae_lr,
|
||||
},
|
||||
{
|
||||
"params": no_decay_params,
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
@@ -105,7 +136,7 @@ class VQBeTPolicy(
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -127,11 +158,11 @@ class VQBeTPolicy(
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
@@ -141,16 +172,16 @@ class VQBeTPolicy(
|
||||
loss, n_different_codes, n_different_combinations, recon_l1_error = (
|
||||
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
|
||||
)
|
||||
return {
|
||||
"loss": loss,
|
||||
return loss, {
|
||||
"n_different_codes": n_different_codes,
|
||||
"n_different_combinations": n_different_combinations,
|
||||
"recon_l1_error": recon_l1_error,
|
||||
}
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
|
||||
_, loss_dict = self.vqbet(batch, rollout=False)
|
||||
loss = loss_dict.pop("loss")
|
||||
|
||||
return loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
@@ -288,14 +319,14 @@ class VQBeTModel(nn.Module):
|
||||
self.config = config
|
||||
|
||||
self.rgb_encoder = VQBeTRgbEncoder(config)
|
||||
self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
self.num_images = len(self.config.image_features)
|
||||
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
|
||||
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
|
||||
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
|
||||
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
@@ -313,7 +344,7 @@ class VQBeTModel(nn.Module):
|
||||
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
|
||||
)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.images"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
@@ -350,17 +381,22 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
# get action features (pass through GPT)
|
||||
features = self.policy(input_tokens)
|
||||
# len(self.config.input_shapes) is the number of different observation modes. this line gets the index of action prompt tokens.
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
|
||||
self.config.input_shapes
|
||||
# len(self.config.input_features) is the number of different observation modes.
|
||||
# this line gets the index of action prompt tokens.
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
|
||||
self.config.input_features
|
||||
)
|
||||
|
||||
# only extract the output tokens at the position of action query:
|
||||
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
|
||||
# Thus, it predict historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
||||
features = torch.cat(
|
||||
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
|
||||
)
|
||||
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
|
||||
# mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
|
||||
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
||||
if len_additional_action_token > 0:
|
||||
features = torch.cat(
|
||||
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
|
||||
)
|
||||
else:
|
||||
features = features[:, historical_act_pred_index]
|
||||
# pass through action head
|
||||
action_head_output = self.action_head(features)
|
||||
# if rollout, VQ-BeT don't calculate loss
|
||||
@@ -387,7 +423,7 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
|
||||
The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT,
|
||||
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`.
|
||||
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.action_feature.shape[0]`.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
@@ -414,7 +450,7 @@ class VQBeTHead(nn.Module):
|
||||
self.vqvae_model.vqvae_num_layers
|
||||
* self.config.vqvae_n_embed
|
||||
* config.action_chunk_size
|
||||
* config.output_shapes["action"][0],
|
||||
* config.action_feature.shape[0],
|
||||
],
|
||||
)
|
||||
# loss
|
||||
@@ -448,10 +484,10 @@ class VQBeTHead(nn.Module):
|
||||
param.requires_grad = False
|
||||
return loss, n_different_codes, n_different_combinations, recon_l1_error
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
def forward(self, x, **kwargs) -> dict:
|
||||
# N is the batch size, and T is number of action query tokens, which are process through same GPT
|
||||
N, T, _ = x.shape
|
||||
# we calculate N and T side parallely. Thus, the dimensions would be
|
||||
# we calculate N and T side parallelly. Thus, the dimensions would be
|
||||
# (batch size * number of action query tokens, action chunk size, action dimension)
|
||||
x = einops.rearrange(x, "N T WA -> (N T) WA")
|
||||
|
||||
@@ -501,7 +537,7 @@ class VQBeTHead(nn.Module):
|
||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||
)
|
||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||
NT, G, choices = cbet_probs.shape
|
||||
NT, _G, choices = cbet_probs.shape
|
||||
sampled_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||
"(NT G) 1 -> NT G",
|
||||
@@ -544,7 +580,7 @@ class VQBeTHead(nn.Module):
|
||||
"decoded_action": decoded_action,
|
||||
}
|
||||
|
||||
def loss_fn(self, pred, target, **kwargs):
|
||||
def loss_fn(self, pred, target, **_kwargs):
|
||||
"""
|
||||
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
|
||||
|
||||
@@ -571,7 +607,7 @@ class VQBeTHead(nn.Module):
|
||||
# Figure out the loss for the actions.
|
||||
# First, we need to find the closest cluster center for each ground truth action.
|
||||
with torch.no_grad():
|
||||
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||
_state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||
|
||||
# Now we can compute the loss.
|
||||
|
||||
@@ -618,84 +654,6 @@ class VQBeTHead(nn.Module):
|
||||
return loss_dict
|
||||
|
||||
|
||||
class VQBeTOptimizer(torch.optim.Adam):
|
||||
def __init__(self, policy, cfg):
|
||||
vqvae_params = (
|
||||
list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
|
||||
)
|
||||
decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.rgb_encoder.parameters())
|
||||
+ list(policy.vqbet.state_projector.parameters())
|
||||
+ list(policy.vqbet.rgb_feature_projector.parameters())
|
||||
+ [policy.vqbet.action_token]
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
||||
)
|
||||
|
||||
if cfg.policy.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
|
||||
optim_groups = [
|
||||
{
|
||||
"params": decay_params,
|
||||
"weight_decay": cfg.training.adam_weight_decay,
|
||||
"lr": cfg.training.lr,
|
||||
},
|
||||
{
|
||||
"params": vqvae_params,
|
||||
"weight_decay": 0.0001,
|
||||
"lr": cfg.training.vqvae_lr,
|
||||
},
|
||||
{
|
||||
"params": no_decay_params,
|
||||
"weight_decay": 0.0,
|
||||
"lr": cfg.training.lr,
|
||||
},
|
||||
]
|
||||
super().__init__(
|
||||
optim_groups,
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
)
|
||||
|
||||
|
||||
class VQBeTScheduler(nn.Module):
|
||||
def __init__(self, optimizer, cfg):
|
||||
super().__init__()
|
||||
n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
|
||||
|
||||
num_warmup_steps = cfg.training.lr_warmup_steps
|
||||
num_training_steps = cfg.training.offline_steps
|
||||
num_cycles = 0.5
|
||||
|
||||
def lr_lambda(current_step):
|
||||
if current_step < n_vqvae_training_steps:
|
||||
return float(1)
|
||||
else:
|
||||
current_step = current_step - n_vqvae_training_steps
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1, num_warmup_steps))
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps)
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
||||
|
||||
self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
def step(self):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
|
||||
class VQBeTRgbEncoder(nn.Module):
|
||||
"""Encode an RGB image into a 1D feature vector.
|
||||
|
||||
@@ -738,19 +696,15 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
assert len(image_keys) == 1
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
@@ -810,6 +764,7 @@ def _replace_submodules(
|
||||
return root_module
|
||||
|
||||
|
||||
# TODO(Steven): Missing implementation of forward, is it maybe vqvae_forward?
|
||||
class VqVae(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -820,7 +775,7 @@ class VqVae(nn.Module):
|
||||
Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
|
||||
The vq_layer uses residual VQs.
|
||||
|
||||
This class contains functions for training the encoder and decoder along with the residual VQ layer (for trainign phase 1),
|
||||
This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
|
||||
as well as functions to help BeT training part in training phase 2.
|
||||
"""
|
||||
|
||||
@@ -839,7 +794,7 @@ class VqVae(nn.Module):
|
||||
)
|
||||
|
||||
self.encoder = MLP(
|
||||
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
@@ -851,7 +806,7 @@ class VqVae(nn.Module):
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
self.config.action_feature.shape[0] * self.config.action_chunk_size,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -867,9 +822,9 @@ class VqVae(nn.Module):
|
||||
# given latent vector, this function outputs the decoded action.
|
||||
output = self.decoder(latent)
|
||||
if self.config.action_chunk_size == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
else:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
|
||||
@@ -924,13 +879,13 @@ class FocalLoss(nn.Module):
|
||||
self.gamma = gamma
|
||||
self.size_average = size_average
|
||||
|
||||
def forward(self, input, target):
|
||||
if len(input.shape) == 3:
|
||||
N, T, _ = input.shape
|
||||
logpt = F.log_softmax(input, dim=-1)
|
||||
def forward(self, forward_input, target):
|
||||
if len(forward_input.shape) == 3:
|
||||
N, T, _ = forward_input.shape
|
||||
logpt = F.log_softmax(forward_input, dim=-1)
|
||||
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
|
||||
elif len(input.shape) == 2:
|
||||
logpt = F.log_softmax(input, dim=-1)
|
||||
elif len(forward_input.shape) == 2:
|
||||
logpt = F.log_softmax(forward_input, dim=-1)
|
||||
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
|
||||
pt = logpt.exp()
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user