Compare commits

..

31 Commits

Author SHA1 Message Date
ef45ea9649 single arm refactory
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-26 21:08:22 +08:00
bc351a0134 change pose control api to canfd
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-19 14:55:54 +08:00
68986f6fc0 update some readme
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-19 11:49:19 +08:00
2f124e34de redefine init joint pos
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-17 14:56:23 +08:00
c28e774234 optimaze the speed of end pose control
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-13 20:17:31 +08:00
80b1a97e4c change opencv to realsense camera
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-12 17:56:21 +08:00
f4fec8f51c change pose control api
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-11 16:17:39 +08:00
f4f82c916f some bug still
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-11 15:20:14 +08:00
ecbe154709 no change
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-09 16:24:00 +08:00
d00c154db9 update state
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-09 16:23:09 +08:00
55f284b306 mix control fix bug
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-09 10:58:28 +08:00
cf8df17d3a add realman shadow src
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-07 11:29:43 +08:00
e079566597 xbox controller demo
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-07 11:22:05 +08:00
83d6419d70 手柄控制第一次提交 2025-06-05 21:56:52 +08:00
a0ec9e1cb1 single arm test
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2025-06-05 15:50:26 +08:00
3eede4447d dual arm test 2025-06-05 15:50:18 +08:00
9c6a7d9701 new md 2025-06-05 15:50:11 +08:00
7b201773f3 single arm test 2025-06-05 15:49:57 +08:00
mshukor
bfd26eef5a Add SmolVLA (#1175)
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: fracapuano <francesco.capuano@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com>
Co-authored-by: Remi <remi.cadene@huggingface.co>
2025-06-03 17:11:50 +02:00
pre-commit-ci[bot]
1537d0ab90 [pre-commit.ci] pre-commit autoupdate (#1048)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
2025-06-02 19:30:39 +02:00
Adil Zouitine
2be7f3a3ff (hotfix): nightly CI by clipping pymunk version below 7.0.0 (#1182) 2025-06-02 13:18:02 +02:00
Adil Zouitine
0cf864870c [Fix] Unpin torch beyond 2.6.0 & torchcodec beyond 0.2.1 (#1127) 2025-05-28 16:54:20 +02:00
mshukor
1786916a16 Update README.md (#1163) 2025-05-27 11:50:43 +02:00
mshukor
0507ad4f68 Update README.md (#1160) 2025-05-27 11:45:07 +02:00
Ragnar
bed90e3a41 fix: typos and grammar (#1148) 2025-05-25 17:20:45 +02:00
Francesco Capuano
6163daaaa4 Fix: emptying action queue between resets (#1117) 2025-05-22 21:37:21 +02:00
Pepijn
8e2a394442 Add editable -e for feetech install command (#1133) 2025-05-20 18:51:21 +02:00
masato-ka
a445d9c9da bug fix for #1071 When --display_data=true, Failed running control_robot. (#1073) 2025-05-09 16:53:40 +02:00
CharlesCNorton
f24030d4d8 Update 12_use_so101.md (#1081)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2025-05-09 11:04:25 +02:00
Mishig
7598aeaad7 Update 10_use_so100.md; use diff syntax (#944)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2025-05-09 11:01:12 +02:00
Pepijn
4485cc0b5b docs: minor corrections and clean-up (#1089) 2025-05-09 11:00:25 +02:00
193 changed files with 18097 additions and 11673 deletions

View File

@@ -40,24 +40,24 @@ jobs:
git lfs install
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
lfs: true
persist-credentials: false
- name: Login to DockerHub
uses: docker/login-action@v3
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push CPU
uses: docker/build-push-action@v5
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
with:
context: .
file: ./docker/lerobot-cpu/Dockerfile
@@ -78,24 +78,24 @@ jobs:
git lfs install
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
lfs: true
persist-credentials: false
- name: Login to DockerHub
uses: docker/login-action@v3
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
uses: docker/build-push-action@v5
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
with:
context: .
file: ./docker/lerobot-gpu/Dockerfile
@@ -110,23 +110,23 @@ jobs:
group: aws-general-8-plus
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Login to DockerHub
uses: docker/login-action@v3
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU dev
uses: docker/build-push-action@v5
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
with:
context: .
file: ./docker/lerobot-gpu-dev/Dockerfile

View File

@@ -33,7 +33,7 @@ jobs:
runs-on:
group: aws-general-8-plus
container:
image: huggingface/lerobot-cpu:latest
image: huggingface/lerobot-cpu:latest # zizmor: ignore[unpinned-images]
options: --shm-size "16gb"
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -60,7 +60,7 @@ jobs:
CUDA_VISIBLE_DEVICES: "0"
TEST_TYPE: "single_gpu"
container:
image: huggingface/lerobot-gpu:latest
image: huggingface/lerobot-gpu:latest # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}

View File

@@ -33,12 +33,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1
with:
python-version: ${{ env.PYTHON_VERSION }}
@@ -64,9 +64,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: typos-action
uses: crate-ci/typos@v1.29.10
uses: crate-ci/typos@db35ee91e80fbb447f33b0e5fbddb24d2a1a884f # v1.29.10

View File

@@ -35,7 +35,7 @@ jobs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
@@ -64,17 +64,17 @@ jobs:
docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Build Docker image
uses: docker/build-push-action@v5
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
with:
file: ${{ matrix.docker-file }}
context: .

View File

@@ -50,7 +50,7 @@ jobs:
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
lfs: true # Ensure LFS files are pulled
persist-credentials: false
@@ -62,7 +62,7 @@ jobs:
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
- name: Install uv and python
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
with:
enable-cache: true
version: ${{ env.UV_VERSION }}
@@ -85,7 +85,7 @@ jobs:
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
lfs: true # Ensure LFS files are pulled
persist-credentials: false
@@ -94,7 +94,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y ffmpeg
- name: Install uv and python
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
with:
enable-cache: true
version: ${{ env.UV_VERSION }}
@@ -117,7 +117,7 @@ jobs:
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
lfs: true # Ensure LFS files are pulled
persist-credentials: false
@@ -129,7 +129,7 @@ jobs:
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
- name: Install uv and python
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
with:
enable-cache: true
version: ${{ env.UV_VERSION }}

View File

@@ -24,12 +24,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
persist-credentials: false
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main
uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35
with:
extra_args: --only-verified

1
.gitignore vendored
View File

@@ -26,7 +26,6 @@ outputs
# VS Code
.vscode
.devcontainer
# HPC
nautilus/*.yaml

View File

@@ -37,19 +37,18 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/adhtruong/mirrors-typos
rev: v1.31.1
rev: v1.32.0
hooks:
- id: typos
args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
rev: v3.20.0
hooks:
- id: pyupgrade
# Exclude generated protobuf files
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.5
rev: v0.11.11
hooks:
- id: ruff
args: [--fix]
@@ -58,12 +57,12 @@ repos:
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.3
rev: v8.26.0
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.5.2
rev: v1.8.0
hooks:
- id: zizmor

View File

@@ -360,7 +360,7 @@ with profile(
If you want, you can cite this work with:
```bibtex
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}
@@ -408,19 +408,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
year={2024}
}
```
- [HIL-SERL](https://hil-serl.github.io/)
```bibtex
@Article{luo2024hilserl,
title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine},
year={2024},
eprint={2410.21845},
archivePrefix={arXiv},
primaryClass={cs.RO}
}
```
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline)

View File

@@ -96,8 +96,8 @@ Reconnect the usb cable.
#### Update config file
Now that you have your ports, update the **port** default values of [`SO101RobotConfig`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robot_devices/robots/configs.py).
You will find something like, update the `port` values with your actual motor ports:
```python
You will find a class called `so101` where you can update the `port` values with your actual motor ports:
```diff
@RobotConfig.register_subclass("so101")
@dataclass
class So101RobotConfig(ManipulatorRobotConfig):
@@ -110,7 +110,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
- port="/dev/tty.usbmodem58760431091",
+ port="{ADD YOUR LEADER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -127,7 +128,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
- port="/dev/tty.usbmodem585A0076891",
+ port="{ADD YOUR FOLLOWER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -297,7 +299,7 @@ Remove all support material from the 3D-printed parts, the easiest way to do thi
##### Wiring
- Attach the motor controller on the back.
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themself and stay in place.
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themselves and stay in place.
<div class="video-container">
<video controls width="600">

View File

@@ -83,7 +83,7 @@ camera_01_frame_000047.png
Note: Some cameras may take a few seconds to warm up, and the first frame might be black or green.
Now that you have the camera indexes, you should specify the camera's in the config. TODO(pepijn): add more info about setting camera config, rotate etc..
Now that you have the camera indexes, you should specify the camera's in the config.
### Use your phone
<hfoptions id="use phone">
@@ -152,7 +152,7 @@ If everything is set up correctly, you can proceed with the rest of the tutorial
## Teleoperate with cameras
We can now teleoperate again while at the same time visualizing the camera's and joint positions with `rerun`.
We can now teleoperate again while at the same time visualizing the cameras and joint positions with `rerun`.
```bash
python lerobot/scripts/control_robot.py \

View File

@@ -55,7 +55,7 @@ conda install ffmpeg -c conda-forge
Install 🤗 LeRobot:
```bash
cd lerobot && pip install ".[feetech]"
cd lerobot && pip install -e ".[feetech]"
```
## Troubleshooting

View File

@@ -128,7 +128,7 @@ sudo chmod 666 /dev/ttyACM1
#### d. Update config file
IMPORTANTLY: Now that you have your ports, update the **port** default values of [`SO100RobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
```python
```diff
@RobotConfig.register_subclass("so100")
@dataclass
class So100RobotConfig(ManipulatorRobotConfig):
@@ -141,7 +141,8 @@ class So100RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
- port="/dev/tty.usbmodem58760431091",
+ port="{ADD YOUR LEADER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -158,7 +159,8 @@ class So100RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
- port="/dev/tty.usbmodem585A0076891",
+ port="{ADD YOUR FOLLOWER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],

View File

@@ -141,7 +141,7 @@ python lerobot/scripts/configure_motor.py \
--ID 1
```
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
Note: These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
Then unplug your motor and plug the second motor and set its ID to 2.
```bash

View File

@@ -61,7 +61,7 @@ conda install ffmpeg -c conda-forge
Install 🤗 LeRobot:
```bash
cd lerobot && pip install ".[feetech]"
cd lerobot && pip install -e ".[feetech]"
```
> [!NOTE]
@@ -145,8 +145,8 @@ Reconnect the usb cable.
#### Update config file
Now that you have your ports, update the **port** default values of [`SO101RobotConfig`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robot_devices/robots/configs.py).
You will find something a class called `so101` where you can update the `port` values with your actual motor ports:
```python
You will find a class called `so101` where you can update the `port` values with your actual motor ports:
```diff
@RobotConfig.register_subclass("so101")
@dataclass
class So101RobotConfig(ManipulatorRobotConfig):
@@ -159,7 +159,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
- port="/dev/tty.usbmodem58760431091",
+ port="{ADD YOUR LEADER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -176,7 +177,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
- port="/dev/tty.usbmodem585A0076891",
+ port="{ADD YOUR FOLLOWER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -308,7 +310,7 @@ Remove all support material from the 3D-printed parts.
##### Wiring
- Attach the motor controller on the back.
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themself and stay in place.
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themselves and stay in place.
<video controls width="640" src="https://github.com/user-attachments/assets/4c2cacfd-9276-4ee4-8bf2-ba2492667b78" type="video/mp4"></video>
@@ -351,7 +353,7 @@ python lerobot/scripts/control_robot.py \
```
## Control your robot
Congrats 🎉, your robot is all set to learn a task on its own. Next we will explain you how to train a neural network to autonomously control a real robot.
Congrats 🎉, your robot is all set to learn a task on its own. Next we will explain to you how to train a neural network to autonomously control a real robot.
**You'll learn to:**
1. How to record and visualize your dataset.
@@ -515,7 +517,7 @@ If you have an additional camera you can add a wrist camera to the SO101. There
## Teleoperate with cameras
We can now teleoperate again while at the same time visualizing the camera's and joint positions with `rerun`.
We can now teleoperate again while at the same time visualizing the cameras and joint positions with `rerun`.
```bash
python lerobot/scripts/control_robot.py \
@@ -526,7 +528,7 @@ python lerobot/scripts/control_robot.py \
## Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
Once you're familiar with teleoperation, you can record your first dataset with SO-101.
We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens).

View File

@@ -168,12 +168,7 @@ available_datasets = sorted(
)
# lists all available policies from `lerobot/common/policies`
available_policies = [
"act",
"diffusion",
"tdmpc",
"vqbet",
]
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
# lists all available robots from `lerobot/common/robot_devices/robots`
available_robots = [

View File

@@ -106,7 +106,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
class AsyncImageWriter:
"""
This class abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
save images on disk asynchronously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it creates a threads pool of size `num_threads`.

View File

@@ -944,7 +944,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def stop_image_writer(self) -> None:
"""
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized.
"""
if self.image_writer is not None:
self.image_writer.stop()

View File

@@ -101,7 +101,7 @@ def decode_video_frames_torchvision(
keyframes_only = False
torchvision.set_video_backend(backend)
if backend == "pyav":
keyframes_only = True # pyav doesnt support accuracte seek
keyframes_only = True # pyav doesn't support accurate seek
# set a video stream reader
# TODO(rcadene): also load audio stream at the same time

View File

@@ -14,12 +14,10 @@
import abc
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import draccus
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -156,122 +154,3 @@ class XarmEnv(EnvConfig):
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}
@dataclass
class VideoRecordConfig:
"""Configuration for video recording in ManiSkill environments."""
enabled: bool = False
record_dir: str = "videos"
trajectory_name: str = "trajectory"
@dataclass
class EEActionSpaceConfig:
"""Configuration parameters for end-effector action space."""
x_step_size: float
y_step_size: float
z_step_size: float
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
control_mode: str = "gamepad"
@dataclass
class EnvTransformConfig:
"""Configuration for environment wrappers."""
ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
display_cameras: bool = False
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
resize_size: Optional[Tuple[int, int]] = None
control_time_s: float = 20.0
fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0
use_gripper: bool = False
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):
"""Configuration for the HILSerlRobotEnv environment."""
robot: Optional[RobotConfig] = None
wrapper: Optional[EnvTransformConfig] = None
fps: int = 10
name: str = "real_robot"
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
task: str = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
reward_classifier_pretrained_path: Optional[str] = None
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
def gym_kwargs(self) -> dict:
return {}
@EnvConfig.register_subclass("hil")
@dataclass
class HILEnvConfig(EnvConfig):
"""Configuration for the HIL environment."""
type: str = "hil"
name: str = "PandaPickCube"
task: str = "PandaPickCubeKeyboard-v0"
use_viewer: bool = True
gripper_penalty: float = 0.0
use_gamepad: bool = True
state_dim: int = 18
action_dim: int = 4
fps: int = 100
episode_length: int = 100
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"observation.image": OBS_IMAGE,
"observation.state": OBS_ROBOT,
}
)
################# args from hilserlrobotenv
reward_classifier_pretrained_path: Optional[str] = None
robot: Optional[RobotConfig] = None
wrapper: Optional[EnvTransformConfig] = None
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
############################
@property
def gym_kwargs(self) -> dict:
return {
"use_viewer": self.use_viewer,
"use_gamepad": self.use_gamepad,
"gripper_penalty": self.gripper_penalty,
}

View File

@@ -17,7 +17,7 @@ import importlib
import gymnasium as gym
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
elif env_type == "hil":
return HILEnvConfig(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
@@ -67,8 +65,5 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
env = env_cls(
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase
# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/vector/vectorize_observation.py#L19,
# env = ObservationProcessorWrapper(env=env)
return env

View File

@@ -47,10 +47,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img)
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
# This is the case for human-in-the-loop RL where there is only one environment.
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
@@ -66,18 +62,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[imgkey] = img
if "environment_state" in observations:
env_state = torch.from_numpy(observations["environment_state"]).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
return_observations["observation.environment_state"] = env_state
return_observations["observation.environment_state"] = torch.from_numpy(
observations["environment_state"]
).float()
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations["observation.state"] = agent_pos
# requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return return_observations

View File

@@ -14,9 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
import draccus
import torch
@@ -45,16 +44,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
return "adam"
@abc.abstractmethod
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
"""
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
NOTE: Multiple optimizers are useful when you have different models to optimize.
For example, you can have one optimizer for the policy and another one for the value function
in reinforcement learning settings.
Returns:
The optimizer or a dictionary of optimizers.
"""
def build(self) -> torch.optim.Optimizer:
raise NotImplementedError
@@ -104,76 +94,7 @@ class SGDConfig(OptimizerConfig):
return torch.optim.SGD(params, **kwargs)
@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):
"""Configuration for multiple Adam optimizers with different parameter groups.
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
Args:
lr: Default learning rate (used if not specified for a group)
weight_decay: Default weight decay (used if not specified for a group)
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
grad_clip_norm: Gradient clipping norm
"""
lr: float = 1e-3
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
"""Build multiple Adam optimizers.
Args:
params_dict: Dictionary mapping parameter group names to lists of parameters
The keys should match the keys in optimizer_groups
Returns:
Dictionary mapping parameter group names to their optimizers
"""
optimizers = {}
for name, params in params_dict.items():
# Get group-specific hyperparameters or use defaults
group_config = self.optimizer_groups.get(name, {})
# Create optimizer with merged parameters (defaults + group-specific)
optimizer_kwargs = {
"lr": group_config.get("lr", self.lr),
"betas": group_config.get("betas", (0.9, 0.999)),
"eps": group_config.get("eps", 1e-5),
"weight_decay": group_config.get("weight_decay", self.weight_decay),
}
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
return optimizers
def save_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> None:
"""Save optimizer state to disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to save the optimizer state.
"""
if isinstance(optimizer, dict):
# Handle dictionary of optimizers
for name, opt in optimizer.items():
optimizer_dir = save_dir / name
optimizer_dir.mkdir(exist_ok=True, parents=True)
_save_single_optimizer_state(opt, optimizer_dir)
else:
# Handle single optimizer
_save_single_optimizer_state(optimizer, save_dir)
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
"""Save a single optimizer's state to disk."""
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
state = optimizer.state_dict()
param_groups = state.pop("param_groups")
flat_state = flatten_dict(state)
@@ -181,44 +102,11 @@ def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
def load_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
"""Load optimizer state from disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to load the optimizer state from.
Returns:
The updated optimizer(s) with loaded state.
"""
if isinstance(optimizer, dict):
# Handle dictionary of optimizers
loaded_optimizers = {}
for name, opt in optimizer.items():
optimizer_dir = save_dir / name
if optimizer_dir.exists():
loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
else:
loaded_optimizers[name] = opt
return loaded_optimizers
else:
# Handle single optimizer
return _load_single_optimizer_state(optimizer, save_dir)
def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
"""Load a single optimizer's state from disk."""
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
# Handle case where 'state' key might not exist (for newly created optimizers)
if "state" in state:
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
else:
loaded_state_dict = {"state": {}}
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
if "param_groups" in current_state_dict:
param_groups = deserialize_json_into_object(

View File

@@ -15,5 +15,6 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig

View File

@@ -27,7 +27,7 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.configs.policies import PreTrainedConfig
@@ -60,14 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "sac":
from lerobot.common.policies.sac.modeling_sac import SACPolicy
elif name == "smolvla":
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SACPolicy
elif name == "reward_classifier":
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
return Classifier
return SmolVLAPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -85,8 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")

View File

@@ -151,7 +151,6 @@ class Normalize(nn.Module):
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# TODO: Remove this shallow copy
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
@@ -253,168 +252,3 @@ class Unnormalize(nn.Module):
else:
raise ValueError(norm_mode)
return batch
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
module: nn.Module,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> None:
"""Register statistics buffers (mean/std or min/max) on the given *module*.
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
but is factored out so it can be reused by both classes and stay in sync.
"""
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
shape: tuple[int, ...] = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# reduce spatial dimensions, keep channel dimension only
c, *_ = shape
shape = (c, 1, 1)
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.full(shape, torch.inf, dtype=torch.float32)
std = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
mean_data = stats[key]["mean"]
std_data = stats[key]["std"]
if isinstance(mean_data, torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
mean = mean_data.clone().to(dtype=torch.float32)
std = std_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_mean", mean)
module.register_buffer(f"{prefix}_std", std)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
min_data = stats[key]["min"]
max_data = stats[key]["max"]
if isinstance(min_data, torch.Tensor):
min_val = min_data.clone().to(dtype=torch.float32)
max_val = max_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_min", min_val)
module.register_buffer(f"{prefix}_max", max_val)
continue
raise ValueError(norm_mode)
class NormalizeBuffer(nn.Module):
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
batch[key] = batch[key] * 2 - 1
continue
raise ValueError(norm_mode)
return batch
class UnnormalizeBuffer(nn.Module):
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max_val - min_val) + min_val
continue
raise ValueError(norm_mode)
return batch

View File

@@ -357,7 +357,7 @@ class PI0Policy(PreTrainedPolicy):
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]

View File

@@ -516,7 +516,7 @@ class PI0FAST(nn.Module):
interpolate_like_pi=self.config.interpolate_like_pi,
)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]

View File

@@ -1,62 +0,0 @@
from dataclasses import dataclass, field
from typing import List
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass(name="reward_classifier")
@dataclass
class RewardClassifierConfig(PreTrainedConfig):
"""Configuration for the Reward Classifier model."""
name: str = "reward_classifier"
num_classes: int = 2
hidden_dim: int = 256
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
learning_rate: float = 1e-4
weight_decay: float = 0.01
grad_clip_norm: float = 1.0
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
}
)
@property
def observation_delta_indices(self) -> List | None:
return None
@property
def action_delta_indices(self) -> List | None:
return None
@property
def reward_delta_indices(self) -> List | None:
return None
def get_optimizer_preset(self) -> OptimizerConfig:
return AdamWConfig(
lr=self.learning_rate,
weight_decay=self.weight_decay,
grad_clip_norm=self.grad_clip_norm,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return None
def validate_features(self) -> None:
"""Validate feature configurations."""
has_image = any(key.startswith("observation.image") for key in self.input_features)
if not has_image:
raise ValueError(
"You must provide an image observation (key starting with 'observation.image') in the input features"
)

View File

@@ -1,301 +0,0 @@
import logging
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
from lerobot.common.constants import OBS_IMAGE
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self,
logits: Tensor,
probabilities: Optional[Tensor] = None,
hidden_states: Optional[Tensor] = None,
):
self.logits = logits
self.probabilities = probabilities
self.hidden_states = hidden_states
def __repr__(self):
return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})"
)
class SpatialLearnedEmbeddings(nn.Module):
def __init__(self, height, width, channel, num_features=8):
"""
PyTorch implementation of learned spatial embeddings
Args:
height: Spatial height of input features
width: Spatial width of input features
channel: Number of input channels
num_features: Number of output embedding dimensions
"""
super().__init__()
self.height = height
self.width = width
self.channel = channel
self.num_features = num_features
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
def forward(self, features):
"""
Forward pass for spatial embedding
Args:
features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch
Returns:
Output tensor of shape [B, C*F] or [C*F] if no batch
"""
features = features.last_hidden_state
original_shape = features.shape
if features.dim() == 3:
features = features.unsqueeze(0) # Add batch dim
features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1]
kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F]
# Element-wise multiplication and spatial reduction
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W
# Reshape to combine channel and feature dimensions
output = output.view(output.size(0), -1) # [B, C*F]
# Remove batch dim
if len(original_shape) == 3:
output = output.squeeze(0)
return output
class Classifier(PreTrainedPolicy):
"""Image classifier built on top of a pre-trained encoder."""
name = "reward_classifier"
config_class = RewardClassifierConfig
def __init__(
self,
config: RewardClassifierConfig,
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
):
from transformers import AutoModel
super().__init__(config)
self.config = config
# Initialize normalization (standardized with the policy framework)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# Set up encoder
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
self.encoder = encoder.vision_model
self.vision_config = encoder.config.vision_config
else:
self.encoder = encoder
self.vision_config = getattr(encoder, "config", None)
# Model type from config
self.is_cnn = self.config.model_type == "cnn"
# For CNNs, initialize backbone
if self.is_cnn:
self._setup_cnn_backbone()
self._freeze_encoder()
# Extract image keys from input_features
self.image_keys = [
key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE)
]
if self.is_cnn:
self.encoders = nn.ModuleDict()
for image_key in self.image_keys:
encoder = self._create_single_encoder()
self.encoders[image_key] = encoder
self._build_classifier_head()
def _setup_cnn_backbone(self):
"""Set up CNN encoder"""
if hasattr(self.encoder, "fc"):
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
for param in self.encoder.parameters():
param.requires_grad = False
def _create_single_encoder(self):
encoder = nn.Sequential(
self.encoder,
SpatialLearnedEmbeddings(
height=4,
width=4,
channel=self.feature_dim,
num_features=self.config.image_embedding_pooling_dim,
),
nn.Dropout(self.config.dropout_rate),
nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
nn.Tanh(),
)
return encoder
def _build_classifier_head(self) -> None:
"""Initialize the classifier head architecture."""
# Get input dimension based on model type
if self.is_cnn:
input_dim = self.config.latent_dim
else: # Transformer models
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(
self.config.hidden_dim,
1 if self.config.num_classes == 2 else self.config.num_classes,
),
)
def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""
with torch.no_grad():
if self.is_cnn:
# The HF ResNet applies pooling internally
outputs = self.encoders[image_key](x)
return outputs
else: # Transformer models
outputs = self.encoder(x)
return outputs.last_hidden_state[:, 0, :]
def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
"""Extract image tensors and label tensors from batch."""
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
labels = batch["next.reward"]
return images, labels
def predict(self, xs: list) -> ClassifierOutput:
"""Forward pass of the classifier for inference."""
encoder_outputs = torch.hstack(
[self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)]
)
logits = self.classifier_head(encoder_outputs)
if self.config.num_classes == 2:
logits = logits.squeeze(-1)
probabilities = torch.sigmoid(logits)
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
"""Standard forward pass for training compatible with train.py."""
# Normalize inputs if needed
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images and labels
images, labels = self.extract_images_and_labels(batch)
# Get predictions
outputs = self.predict(images)
# Calculate loss
if self.config.num_classes == 2:
# Binary classification
loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
else:
# Multi-class classification
loss = nn.functional.cross_entropy(outputs.logits, labels.long())
predictions = torch.argmax(outputs.logits, dim=1)
# Calculate accuracy for logging
correct = (predictions == labels).sum().item()
total = labels.size(0)
accuracy = 100 * correct / total
# Return loss and metrics for logging
output_dict = {
"accuracy": accuracy,
"correct": correct,
"total": total,
}
return loss, output_dict
def predict_reward(self, batch, threshold=0.5):
"""Eval method. Returns predicted reward with the decision threshold as argument."""
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images from batch dict
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
if self.config.num_classes == 2:
probs = self.predict(images).probabilities
logging.debug(f"Predicted reward images: {probs}")
return (probs > threshold).float()
else:
return torch.argmax(self.predict(images).probabilities, dim=1)
def get_optim_params(self):
"""Return optimizer parameters for the policy."""
return self.parameters()
def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not select actions.
"""
raise NotImplementedError("Reward classifiers do not select actions")
def reset(self):
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not select actions.
"""
pass

View File

@@ -1,243 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
def is_image_feature(key: str) -> bool:
"""Check if a feature key represents an image feature.
Args:
key: The feature key to check
Returns:
True if the key represents an image feature, False otherwise
"""
return key.startswith("observation.image")
@dataclass
class ConcurrencyConfig:
"""Configuration for the concurrency of the actor and learner.
Possible values are:
- "threads": Use threads for the actor and learner.
- "processes": Use processes for the actor and learner.
"""
actor: str = "threads"
learner: str = "threads"
@dataclass
class ActorLearnerConfig:
learner_host: str = "127.0.0.1"
learner_port: int = 50051
policy_parameters_push_frequency: int = 4
@dataclass
class CriticNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
final_activation: str | None = None
@dataclass
class ActorNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
@dataclass
class PolicyConfig:
use_tanh_squash: bool = True
log_std_min: float = 1e-5
log_std_max: float = 10.0
init_final: float = 0.05
@PreTrainedConfig.register_subclass("sac")
@dataclass
class SACConfig(PreTrainedConfig):
"""Soft Actor-Critic (SAC) configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
reinforcement learning framework. It learns a policy and a Q-function simultaneously
using experience collected from the environment.
This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific
hyperparameters.
"""
# Mapping of feature types to normalization modes
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ENV": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Statistics for normalizing different types of inputs
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
"observation.image": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
}
)
# Architecture specifics
# Device to run the model on (e.g., "cuda", "cpu")
device: str = "cpu"
# Device to store the model on
storage_device: str = "cpu"
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
vision_encoder_name: str | None = None
# Whether to freeze the vision encoder during training
freeze_vision_encoder: bool = True
# Hidden dimension size for the image encoder
image_encoder_hidden_dim: int = 32
# Whether to use a shared encoder for actor and critic
shared_encoder: bool = True
# Number of discrete actions, eg for gripper actions
num_discrete_actions: int | None = None
# Dimension of the image embedding pooling
image_embedding_pooling_dim: int = 8
# Training parameter
# Number of steps for online training
online_steps: int = 1000000
# Seed for the online environment
online_env_seed: int = 10000
# Capacity of the online replay buffer
online_buffer_capacity: int = 100000
# Capacity of the offline replay buffer
offline_buffer_capacity: int = 100000
# Whether to use asynchronous prefetching for the buffers
async_prefetch: bool = False
# Number of steps before learning starts
online_step_before_learning: int = 100
# Frequency of policy updates
policy_update_freq: int = 1
# SAC algorithm parameters
# Discount factor for the SAC algorithm
discount: float = 0.99
# Initial temperature value
temperature_init: float = 1.0
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Weight for the critic target update
critic_target_update_weight: float = 0.005
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
utd_ratio: int = 1
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256
# Target entropy for the SAC algorithm
target_entropy: float | None = None
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Network configuration
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for actor-learner architecture
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations
use_torch_compile: bool = True
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration
def get_optimizer_preset(self) -> MultiAdamConfig:
return MultiAdamConfig(
weight_decay=0.0,
optimizer_groups={
"actor": {"lr": self.actor_lr},
"critic": {"lr": self.critic_lr},
"temperature": {"lr": self.temperature_lr},
},
)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
has_image = any(is_image_feature(key) for key in self.input_features)
has_state = "observation.state" in self.input_features
if not (has_state or has_image):
raise ValueError(
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
)
if "action" not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
def image_features(self) -> list[str]:
return [key for key in self.input_features if is_image_feature(key)]
@property
def observation_delta_indices(self) -> list:
return None
@property
def action_delta_indices(self) -> list:
return None # SAC typically predicts one action at a time
@property
def reward_delta_indices(self) -> None:
return None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,154 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@PreTrainedConfig.register_subclass("smolvla")
@dataclass
class SmolVLAConfig(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (512, 512)
# Add empty images. Used by smolvla_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Decoding
num_steps: int = 10
# Attention utils
use_cache: bool = True
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = True
train_state_proj: bool = True
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
attention_mode: str = "cross_attn"
prefix_length: int = -1
pad_language_to: str = "longest" # "max_length"
num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers.
num_vlm_layers: int = 16 # Number of layers used in the VLM (first num_vlm_layers layers)
self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers
expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM)
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
)
def validate_features(self) -> None:
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list:
return [0]
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,801 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SmolVLA:
[Paper](https://huggingface.co/papers/2506.01844)
Designed by Hugging Face.
Install smolvla extra dependencies:
```bash
pip install -e ".[smolvla]"
```
Example of finetuning the smolvla pretrained model (`smolvla_base`):
```bash
python lerobot/scripts/train.py \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
and an action expert.
```bash
python lerobot/scripts/train.py \
--policy.type=smolvla \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
Example of using the smolvla pretrained model outside LeRobot training framework:
```python
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
```
"""
import math
from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoProcessor
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import (
Normalize,
Unnormalize,
)
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.common.policies.utils import (
populate_queues,
)
from lerobot.common.utils.utils import get_safe_dtype
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb
def sample_beta(alpha, beta, bsize, device):
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
return gamma1 / (gamma1 + gamma2)
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
def resize_with_pad(img, width, height, pad_value=-1):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] == new_dim:
return vector
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
new_vector[..., :current_dim] = vector
return new_vector
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with smolvla which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by smolvla to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class SmolVLAPolicy(PreTrainedPolicy):
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
config_class = SmolVLAConfig
name = "smolvla"
def __init__(
self,
config: SmolVLAConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
self.model = VLAFlowMatching(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps),
}
def get_optim_params(self) -> dict:
return self.parameters()
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._queues[ACTION]) == 0:
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
# For backward pass
loss = losses.mean()
# For backward pass
loss_dict["loss"] = loss
return loss, loss_dict
def prepare_images(self, batch):
"""Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
"""
images = []
img_masks = []
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
if f"{key}_padding_mask" in batch:
mask = batch[f"{key}_padding_mask"].bool()
else:
mask = torch.ones(bsize, dtype=torch.bool, device=device)
images.append(img)
img_masks.append(mask)
# Create image features not present in the batch
# as fully 0 padded images.
for num_empty_cameras in range(len(missing_img_keys)):
if num_empty_cameras >= self.config.empty_cameras:
break
img = torch.ones_like(img) * -1
mask = torch.zeros_like(mask)
images.append(img)
img_masks.append(mask)
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_ROBOT].device
tasks = batch["task"]
if len(tasks) == 1:
tasks = [tasks[0] for _ in range(batch[OBS_ROBOT].shape[0])]
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding=self.config.pad_language_to,
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
def prepare_state(self, batch):
"""Pad state"""
state = batch[OBS_ROBOT][:, -1, :] if batch[OBS_ROBOT].ndim > 2 else batch[OBS_ROBOT]
state = pad_vector(state, self.config.max_state_dim)
return state
def prepare_action(self, batch):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
def pad_tensor(tensor, max_len, pad_value=0):
"""
Efficiently pads a tensor along sequence dimension to match max_len.
Args:
tensor (torch.Tensor): Shape (B, L, ...) or (B, L).
max_len (int): Fixed sequence length.
pad_value (int/float): Value for padding.
Returns:
torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
"""
b, d = tensor.shape[:2]
# Create a padded tensor of max_len and copy the existing values
padded_tensor = torch.full(
(b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, :d] = tensor # Efficient in-place copy
return padded_tensor
class VLAFlowMatching(nn.Module):
"""
SmolVLA
[Paper]()
Designed by Hugging Face.
┌──────────────────────────────┐
│ actions │
│ ▲ │
│ ┌─────────┐ ┌─|────┐ │
│ | │────► │ │ │
│ | │ kv │ │ │
│ | │────► │Action│ │
│ | VLM │cache │Expert│ |
│ │ │────► | │ │
│ │ │ │ │ │
│ └▲──▲───▲─┘ └───▲──┘ |
│ │ | | │ |
│ | | | noise │
│ │ │ state │
│ │ language tokens │
│ image(s) │
└──────────────────────────────┘
"""
def __init__(self, config):
super().__init__()
self.config = config
self.vlm_with_expert = SmolVLMWithExpertModel(
model_id=self.config.vlm_model_name,
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only,
load_vlm_weights=self.config.load_vlm_weights,
attention_mode=self.config.attention_mode,
num_expert_layers=self.config.num_expert_layers,
num_vlm_layers=self.config.num_vlm_layers,
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
self.action_time_mlp_in = nn.Linear(
self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
)
self.action_time_mlp_out = nn.Linear(
self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
)
self.set_requires_grad()
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
self.global_image_start_token = torch.tensor(
[self.fake_image_token, self.global_image_token], dtype=torch.long
)
self.add_image_special_tokens = self.config.add_image_special_tokens
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
self.prefix_length = self.config.prefix_length
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
def sample_noise(self, shape, device):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for SmolVLM transformer processing.
"""
embs = []
pad_masks = []
att_masks = []
for _img_idx, (
img,
img_mask,
) in enumerate(zip(images, img_masks, strict=False)):
if self.add_image_special_tokens:
image_start_token = (
self.vlm_with_expert.embed_language_tokens(
self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
image_start_mask = torch.ones_like(
image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
)
att_masks += [0] * (image_start_mask.shape[-1])
embs.append(image_start_token)
pad_masks.append(image_start_mask)
img_emb = self.vlm_with_expert.embed_image(img)
img_emb = img_emb
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
bsize, num_img_embs = img_emb.shape[:2]
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
embs.append(img_emb)
pad_masks.append(img_mask)
att_masks += [0] * (num_img_embs)
if self.add_image_special_tokens:
image_end_token = (
self.vlm_with_expert.embed_language_tokens(
self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
image_end_mask = torch.ones_like(
image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
)
embs.append(image_end_token)
pad_masks.append(image_end_mask)
att_masks += [0] * (image_end_mask.shape[1])
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
# Normalize language embeddings
lang_emb_dim = lang_emb.shape[-1]
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
embs.append(lang_emb)
pad_masks.append(lang_masks)
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
state_emb = self.state_proj(state)
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
embs.append(state_emb)
bsize = state_emb.shape[0]
device = state_emb.device
states_seq_len = state_emb.shape[1]
state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1] * (states_seq_len)
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
att_masks = att_masks[None, :]
seq_len = pad_masks.shape[1]
if seq_len < self.prefix_length:
embs = pad_tensor(embs, self.prefix_length, pad_value=0)
pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0)
att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0)
att_masks = att_masks.expand(bsize, -1)
return embs, pad_masks, att_masks
def embed_suffix(self, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
# Fuse timestep + action information using an MLP
action_emb = self.action_in_proj(noisy_actions)
device = action_emb.device
bsize = action_emb.shape[0]
dtype = action_emb.dtype
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep,
self.vlm_with_expert.expert_hidden_size,
self.config.min_period,
self.config.max_period,
device=device,
)
time_emb = time_emb.type(dtype=dtype)
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
action_time_emb = F.silu(action_time_emb) # swish == silu
action_time_emb = self.action_time_mlp_out(action_time_emb)
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] * self.config.chunk_size
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
(_, suffix_out), _ = self.vlm_with_expert.forward(
attention_mask=att_2d_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
fill_kv_cache=False,
)
suffix_out = suffix_out[:, -self.config.chunk_size :]
# Original openpi code, upcast attention output
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
if noise is None:
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
_, past_key_values = self.vlm_with_expert.forward(
attention_mask=prefix_att_2d_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step
x_t += dt * v_t
time += dt
return x_t
def denoise_step(
self,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
outputs_embeds, _ = self.vlm_with_expert.forward(
attention_mask=full_att_2d_masks,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=self.config.use_cache,
fill_kv_cache=False,
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
return v_t

View File

@@ -0,0 +1,550 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import List, Optional
import torch
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForImageTextToText,
AutoProcessor,
SmolVLMForConditionalGeneration,
)
def apply_rope(x, positions, max_wavelength=10_000):
"""
Applies RoPE positions [B, L] to x [B, L, H, D].
"""
d_half = x.shape[-1] // 2
device = x.device
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
sin = torch.sin(radians) # .to(dtype=dtype)
cos = torch.cos(radians) # .to(dtype=dtype)
x1, x2 = x.split(d_half, dim=-1)
res = torch.empty_like(x)
res[..., :d_half] = x1 * cos - x2 * sin
res[..., d_half:] = x2 * cos + x1 * sin
return res.to(dtype)
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
return hidden_dim
class SmolVLMWithExpertModel(nn.Module):
def __init__(
self,
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
load_vlm_weights: bool = True,
train_expert_only: bool = True,
freeze_vision_encoder: bool = False,
attention_mode: str = "self_attn",
num_expert_layers: int = -1,
num_vlm_layers: int = -1,
self_attn_every_n_layers: int = -1,
expert_width_multiplier: float = 0.5,
):
super().__init__()
if load_vlm_weights:
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map="auto",
torch_dtype="bfloat16",
low_cpu_mem_usage=True,
)
config = self.vlm.config
else:
config = AutoConfig.from_pretrained(model_id)
self.vlm = SmolVLMForConditionalGeneration(config=config)
self.processor = AutoProcessor.from_pretrained(model_id)
if num_vlm_layers > 0:
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
self.config = config
# Smaller lm expert
lm_expert_config = copy.deepcopy(config.text_config)
hidden_size = lm_expert_config.hidden_size
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
lm_expert_config.num_hidden_layers = self.num_vlm_layers
if num_expert_layers > 0:
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
)
lm_expert_config.num_hidden_layers = num_expert_layers
self.lm_expert = AutoModel.from_config(lm_expert_config)
self.num_expert_layers = len(self.lm_expert.layers)
self.self_attn_every_n_layers = self_attn_every_n_layers
if "cross" in attention_mode:
# Reshape qkv projections to have the same input dimension as the vlm
for layer_idx in range(len(self.lm_expert.layers)):
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
continue
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
# Remove unused embed_tokens
self.lm_expert.embed_tokens = None
self.num_attention_heads = self.config.text_config.num_attention_heads
self.num_key_value_heads = self.config.text_config.num_key_value_heads
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.attention_mode = attention_mode
self.expert_hidden_size = lm_expert_config.hidden_size
self.set_requires_grad()
def get_vlm_model(self):
return self.vlm.model
def set_requires_grad(self):
if self.freeze_vision_encoder:
self.get_vlm_model().vision_model.eval()
for params in self.get_vlm_model().vision_model.parameters():
params.requires_grad = False
if self.train_expert_only:
self.vlm.eval()
for params in self.vlm.parameters():
params.requires_grad = False
else:
# To avoid unused params issue with distributed training
last_layers = [self.num_vlm_layers - 1]
if (
self.num_vlm_layers != self.num_expert_layers
and self.num_vlm_layers % self.num_expert_layers == 0
):
last_layers.append(self.num_vlm_layers - 2)
frozen_layers = [
"lm_head",
"text_model.model.norm.weight",
]
for layer in last_layers:
frozen_layers.append(f"text_model.model.layers.{layer}.")
for name, params in self.vlm.named_parameters():
if any(k in name for k in frozen_layers):
params.requires_grad = False
# To avoid unused params issue with distributed training
for name, params in self.lm_expert.named_parameters():
if "lm_head" in name:
params.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.get_vlm_model().vision_model.eval()
if self.train_expert_only:
self.vlm.eval()
def embed_image(self, image: torch.Tensor):
patch_attention_mask = None
# Get sequence from the vision encoder
image_hidden_states = (
self.get_vlm_model()
.vision_model(
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
patch_attention_mask=patch_attention_mask,
)
.last_hidden_state
)
# Modality projection & resampling
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
return image_hidden_states
def embed_language_tokens(self, tokens: torch.Tensor):
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
def forward_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
query_states = []
key_states = []
value_states = []
for i, hidden_states in enumerate(inputs_embeds):
layer = model_layers[i][layer_idx]
if hidden_states is None or layer is None:
continue
hidden_states = layer.input_layernorm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# B,L,H,D with L sequence length, H number of heads, D head dim
# concatenate on the number of embeddings/tokens
query_states = torch.cat(query_states, dim=1)
key_states = torch.cat(key_states, dim=1)
value_states = torch.cat(value_states, dim=1)
seq_len = query_states.shape[1]
if seq_len < position_ids.shape[1]:
_position_ids = position_ids[:, :seq_len]
_attention_mask = attention_mask[:, :seq_len, :seq_len]
else:
_position_ids = position_ids
_attention_mask = attention_mask
attention_mask_ = _attention_mask
position_ids_ = _position_ids
query_states = apply_rope(query_states, position_ids_)
key_states = apply_rope(key_states, position_ids_)
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask_, batch_size, head_dim, query_states, key_states, value_states
)
return [att_output], past_key_values
def forward_cross_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
attention_interface = self.get_attention_interface()
att_outputs = []
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
)
if len(inputs_embeds) == 2 and not past_key_values:
# Prefix attention
seq_len = inputs_embeds[0].shape[1]
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
layer = model_layers[0][layer_idx]
hidden_states = layer.input_layernorm(inputs_embeds[0])
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
# B,L,H,D with L sequence length, H number of heads, D head dim
query_states = apply_rope(query_state, position_id)
key_states = apply_rope(key_state, position_id)
att_output = attention_interface(
prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_outputs.append(att_output)
else:
expert_position_id = position_ids
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = past_key_values[layer_idx]["key_states"]
value_states = past_key_values[layer_idx]["value_states"]
# Expert
expert_layer = model_layers[1][layer_idx]
if expert_layer is not None:
expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
expert_input_shape = expert_hidden_states.shape[:-1]
expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
*key_states.shape[:2], -1
)
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
) # k_proj should have same dim as kv
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
*value_states.shape[:2], -1
)
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
)
expert_position_id = (
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
) # start from 0
expert_attention_mask = attention_mask[
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
] # take into account kv
expert_query_states = apply_rope(expert_query_state, expert_position_id)
att_output = attention_interface(
expert_attention_mask,
batch_size,
head_dim,
expert_query_states,
expert_key_states,
expert_value_states,
)
att_outputs.append(att_output)
else:
att_outputs.append(None)
# att_output = att_output.to(dtype=models[i].dtype)
return att_outputs, past_key_values
def get_model_layers(self, models: list) -> list:
vlm_layers = []
expert_layers = []
multiple_of = self.num_vlm_layers // self.num_expert_layers
for i in range(self.num_vlm_layers):
if multiple_of > 0 and i > 0 and i % multiple_of != 0:
expert_layer = None
else:
expert_layer_index = i // multiple_of if multiple_of > 0 else i
expert_layer = models[1].layers[expert_layer_index]
vlm_layers.append(models[0].layers[i])
expert_layers.append(expert_layer)
return [vlm_layers, expert_layers]
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: List[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
fill_kv_cache: Optional[bool] = None,
):
models = [self.get_vlm_model().text_model, self.lm_expert]
model_layers = self.get_model_layers(models)
for hidden_states in inputs_embeds:
# TODO this is very inefficient
# dtype is always the same, batch size too (if > 1 len)
# device could be trickier in multi gpu edge cases but that's it
if hidden_states is None:
continue
batch_size = hidden_states.shape[0]
# RMSNorm
num_layers = self.num_vlm_layers
head_dim = self.vlm.config.text_config.head_dim
for layer_idx in range(num_layers):
if (
fill_kv_cache
or "cross" not in self.attention_mode
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
):
att_outputs, past_key_values = self.forward_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
else:
att_outputs, past_key_values = self.forward_cross_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
outputs_embeds = []
start = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = model_layers[i][layer_idx]
att_output = (
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
) # in case of self_attn
if hidden_states is not None:
if layer is None:
outputs_embeds.append(hidden_states)
continue
end = start + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
att_out = att_output[:, start:end]
out_emb = layer.self_attn.o_proj(att_out)
out_emb += hidden_states
after_first_residual = out_emb.clone()
out_emb = layer.post_attention_layernorm(out_emb)
out_emb = layer.mlp(out_emb)
out_emb += after_first_residual
outputs_embeds.append(out_emb)
start = end if len(att_outputs) == 1 else 0
else:
outputs_embeds.append(None)
inputs_embeds = outputs_embeds
# final norm
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is not None:
out_emb = models[i].norm(hidden_states)
outputs_embeds.append(out_emb)
else:
outputs_embeds.append(None)
return outputs_embeds, past_key_values
def get_attention_interface(self):
attention_interface = self.eager_attention_forward
return attention_interface
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
num_att_heads = self.num_attention_heads
num_key_value_heads = self.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
# Attention here is upcasted to float32 to match the original eager implementation.
query_states = query_states.to(dtype=torch.float32)
key_states = key_states.to(dtype=torch.float32)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
att_weights *= head_dim**-0.5
att_weights = att_weights.to(dtype=torch.float32)
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output

View File

@@ -87,8 +87,6 @@ class RecordControlConfig(ControlConfig):
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Reset follower arms to an initial position.
reset_follower_arms: bool = False
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.

View File

@@ -24,7 +24,6 @@ from contextlib import nullcontext
from copy import copy
from functools import cache
import numpy as np
import rerun as rr
import torch
from deepdiff import DeepDiff
@@ -59,7 +58,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
log_dt("dt", dt_s)
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not robot.robot_type.startswith("stretch"):
if not robot.robot_type.startswith(("stretch", "realman")):
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
@@ -110,6 +109,10 @@ def predict_action(observation, policy, device, use_amp):
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
# Skip all observations that are not tensors (e.g. text)
if not isinstance(observation[name], torch.Tensor):
continue
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
@@ -130,11 +133,9 @@ def predict_action(observation, policy, device, use_amp):
def init_keyboard_listener():
"""
Initializes a keyboard listener to enable early termination of an episode
or environment reset by pressing the right arrow key ('->'). This may require
sudo permissions to allow the terminal to monitor keyboard events.
"""
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
@@ -163,7 +164,6 @@ def init_keyboard_listener():
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Error handling key press: {e}")
@@ -247,6 +247,11 @@ def control_loop(
timestamp = 0
start_episode_t = time.perf_counter()
# Controls starts, if policy is given it needs cleaning up
if policy is not None:
policy.reset()
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
@@ -254,13 +259,12 @@ def control_loop(
observation, action = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
action = None
observation["task"] = [single_task]
observation["robot_type"] = [policy.robot_type] if hasattr(policy, "robot_type") else [""]
if policy is not None:
pred_action = predict_action(
observation,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
@@ -268,14 +272,16 @@ def control_loop(
action = {"action": action}
if dataset is not None:
observation = {k: v for k, v in observation.items() if k not in ["task", "robot_type"]}
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
for k, v in action.items():
for i, vv in enumerate(v):
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
if action is not None:
for k, v in action.items():
for i, vv in enumerate(v):
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
@@ -308,17 +314,7 @@ def reset_environment(robot, events, reset_time_s, fps):
)
def reset_follower_position(robot_arm, target_position):
current_position = robot_arm.read("Present_Position")
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 50)
) # NOTE: 30 is just an arbitrary number
for pose in trajectory:
robot_arm.write("Goal_Position", pose)
busy_wait(0.015)
def stop_recording(robot, listener, display_cameras):
def stop_recording(robot, listener, display_data):
robot.disconnect()
if not is_headless() and listener is not None:
@@ -344,20 +340,12 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset,
robot: Robot,
fps: int,
use_videos: bool,
extra_features: dict = None,
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
) -> None:
features_from_robot = get_features_from_robot(robot, use_videos)
if extra_features is not None:
features_from_robot.update(extra_features)
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, features_from_robot),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
]
mismatches = []

View File

@@ -39,3 +39,12 @@ class FeetechMotorsBusConfig(MotorsBusConfig):
port: str
motors: dict[str, tuple[int, str]]
mock: bool = False
@MotorsBusConfig.register_subclass("realman")
@dataclass
class RealmanMotorsBusConfig(MotorsBusConfig):
ip: str
port: int
motors: dict[str, tuple[int, str]]
init_joint: dict[str, list]

View File

@@ -0,0 +1,150 @@
import time
from typing import Dict
from lerobot.common.robot_devices.motors.configs import RealmanMotorsBusConfig
from Robotic_Arm.rm_robot_interface import *
class RealmanMotorsBus:
"""
对Realman SDK的二次封装
"""
def __init__(self,
config: RealmanMotorsBusConfig):
self.rmarm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
self.handle = self.rmarm.rm_create_robot_arm(config.ip, config.port)
self.motors = config.motors
self.init_joint_position = config.init_joint['joint'] # [6 joints + 1 gripper]
self.safe_disable_position = config.init_joint['joint']
self.rmarm.rm_movej(self.init_joint_position[:-1], 5, 0, 0, 1)
time.sleep(3)
ret = self.rmarm.rm_get_current_arm_state()
self.init_pose = ret[1]['pose']
@property
def motor_names(self) -> list[str]:
return list(self.motors.keys())
@property
def motor_models(self) -> list[str]:
return [model for _, model in self.motors.values()]
@property
def motor_indices(self) -> list[int]:
return [idx for idx, _ in self.motors.values()]
def connect(self, enable=True) -> bool:
'''
使能机械臂并检测使能状态,尝试5s,如果使能超时则退出程序
'''
enable_flag = False
loop_flag = False
# 设置超时时间(秒)
timeout = 5
# 记录进入循环前的时间
start_time = time.time()
elapsed_time_flag = False
while not loop_flag:
elapsed_time = time.time() - start_time
print("--------------------")
if enable:
# 获取机械臂状态
ret = self.rmarm.rm_get_current_arm_state()
if ret[0] == 0: # 成功获取状态
enable_flag = True
else:
enable_flag = False
# 断开所有连接,销毁线程
RoboticArm.rm_destory()
print("使能状态:", enable_flag)
print("--------------------")
if(enable_flag == enable):
loop_flag = True
enable_flag = True
else:
loop_flag = False
enable_flag = False
# 检查是否超过超时时间
if elapsed_time > timeout:
print("超时....")
elapsed_time_flag = True
enable_flag = True
break
time.sleep(1)
resp = enable_flag
print(f"Returning response: {resp}")
return resp
def motor_names(self):
return
def set_calibration(self):
return
def revert_calibration(self):
return
def apply_calibration(self):
"""
移动到初始位置
"""
self.write(target_joint=self.init_joint_position)
def write(self, target_joint:list):
# self.rmarm.rm_movej(target_joint[:-1], 50, 0, 0, 1)
self.rmarm.rm_movej_follow(target_joint[:-1])
self.rmarm.rm_set_gripper_position(target_joint[-1], block=False, timeout=2)
def write_endpose(self, target_endpose: list, gripper: int):
self.rmarm.rm_movej_p(target_endpose, 50, 0, 0, 1)
self.rmarm.rm_set_gripper_position(gripper, block=False, timeout=2)
def write_joint_slow(self, target_joint: list):
self.rmarm.rm_movej(target_joint, 5, 0, 0, 0)
def write_joint_canfd(self, target_joint: list):
self.rmarm.rm_movej_canfd(target_joint, False)
def write_endpose_canfd(self, target_pose: list):
self.rmarm.rm_movep_canfd(target_pose, False)
def write_gripper(self, gripper: int):
self.rmarm.rm_set_gripper_position(gripper, False, 2)
def read(self) -> Dict:
"""
- 机械臂关节消息,单位1度;[-1, 1]
- 机械臂夹爪消息,[-1, 1]
"""
joint_msg = self.rmarm.rm_get_current_arm_state()[1]
joint_state = joint_msg['joint']
gripper_msg = self.rmarm.rm_get_gripper_state()[1]
gripper_state = gripper_msg['actpos']
return {
"joint_1": joint_state[0]/180,
"joint_2": joint_state[1]/180,
"joint_3": joint_state[2]/180,
"joint_4": joint_state[3]/180,
"joint_5": joint_state[4]/180,
"joint_6": joint_state[5]/180,
"gripper": (gripper_state-500)/500
}
def read_current_arm_joint_state(self):
return self.rmarm.rm_get_current_arm_state()[1]['joint']
def read_current_arm_endpose_state(self):
return self.rmarm.rm_get_current_arm_state()[1]['pose']
def safe_disconnect(self):
"""
Move to safe disconnect position
"""
self.write(target_joint=self.safe_disable_position)
# 断开所有连接,销毁线程
RoboticArm.rm_destory()

View File

@@ -44,6 +44,11 @@ def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig
motors_buses[key] = FeetechMotorsBus(cfg)
elif cfg.type == "realman":
from lerobot.common.robot_devices.motors.realman import RealmanMotorsBus
motors_buses[key] = RealmanMotorsBus(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
@@ -65,3 +70,7 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
else:
raise ValueError(f"The motor type '{motor_type}' is not valid.")
def get_motor_names(arm: dict[str, MotorsBus]) -> list:
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]

View File

@@ -27,6 +27,7 @@ from lerobot.common.robot_devices.motors.configs import (
DynamixelMotorsBusConfig,
FeetechMotorsBusConfig,
MotorsBusConfig,
RealmanMotorsBusConfig
)
@@ -674,3 +675,91 @@ class LeKiwiRobotConfig(RobotConfig):
)
mock: bool = False
@RobotConfig.register_subclass("realman")
@dataclass
class RealmanRobotConfig(RobotConfig):
inference_time: bool = False
max_gripper: int = 990
min_gripper: int = 10
servo_config_file: str = "/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_arm.yaml"
left_follower_arm: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": RealmanMotorsBusConfig(
ip = "192.168.3.18",
port = 8080,
motors={
# name: (index, model)
"joint_1": [1, "realman"],
"joint_2": [2, "realman"],
"joint_3": [3, "realman"],
"joint_4": [4, "realman"],
"joint_5": [5, "realman"],
"joint_6": [6, "realman"],
"gripper": [7, "realman"],
},
init_joint = {'joint': [-90, 90, 90, 90, 90, -90, 10]}
)
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
# "one": OpenCVCameraConfig(
# camera_index=4,
# fps=30,
# width=640,
# height=480,
# ),
"left": IntelRealSenseCameraConfig(
serial_number="153122077516",
fps=30,
width=640,
height=480,
use_depth=False
),
# "right": IntelRealSenseCameraConfig(
# serial_number="405622075165",
# fps=30,
# width=640,
# height=480,
# use_depth=False
# ),
"front": IntelRealSenseCameraConfig(
serial_number="145422072751",
fps=30,
width=640,
height=480,
use_depth=False
),
"high": IntelRealSenseCameraConfig(
serial_number="145422072193",
fps=30,
width=640,
height=480,
use_depth=False
),
}
)
# right_follower_arm: dict[str, MotorsBusConfig] = field(
# default_factory=lambda: {
# "main": RealmanMotorsBusConfig(
# ip = "192.168.3.19",
# port = 8080,
# motors={
# # name: (index, model)
# "joint_1": [1, "realman"],
# "joint_2": [2, "realman"],
# "joint_3": [3, "realman"],
# "joint_4": [4, "realman"],
# "joint_5": [5, "realman"],
# "joint_6": [6, "realman"],
# "gripper": (7, "realman"),
# },
# )
# }
# )

View File

@@ -0,0 +1,292 @@
"""
Teleoperation Realman with a PS5 controller and
"""
import time
import torch
import numpy as np
from dataclasses import dataclass, field, replace
from collections import deque
from lerobot.common.robot_devices.teleop.gamepad import HybridController
from lerobot.common.robot_devices.motors.utils import get_motor_names, make_motors_buses_from_configs
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.robots.configs import RealmanRobotConfig
class RealmanRobot:
def __init__(self, config: RealmanRobotConfig | None = None, **kwargs):
if config is None:
config = RealmanRobotConfig()
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
self.robot_type = self.config.type
self.inference_time = self.config.inference_time # if it is inference time
# build cameras
self.cameras = make_cameras_from_configs(self.config.cameras)
# build realman motors
self.piper_motors = make_motors_buses_from_configs(self.config.left_follower_arm)
self.arm = self.piper_motors['main']
# build init teleop info
self.init_info = {
'init_joint': self.arm.init_joint_position,
'init_pose': self.arm.init_pose,
'max_gripper': config.max_gripper,
'min_gripper': config.min_gripper,
'servo_config_file': config.servo_config_file
}
# build state-action cache
self.joint_queue = deque(maxlen=2)
self.last_endpose = self.arm.init_pose
# build gamepad teleop
if not self.inference_time:
self.teleop = HybridController(self.init_info)
else:
self.teleop = None
self.logs = {}
self.is_connected = False
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
action_names = get_motor_names(self.piper_motors)
state_names = get_motor_names(self.piper_motors)
return {
"action": {
"dtype": "float32",
"shape": (len(action_names),),
"names": action_names,
},
"observation.state": {
"dtype": "float32",
"shape": (len(state_names),),
"names": state_names,
},
}
@property
def has_camera(self):
return len(self.cameras) > 0
@property
def num_cameras(self):
return len(self.cameras)
def connect(self) -> None:
"""Connect RealmanArm and cameras"""
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
"RealmanArm is already connected. Do not run `robot.connect()` twice."
)
# connect RealmanArm
self.arm.connect(enable=True)
print("RealmanArm conneted")
# connect cameras
for name in self.cameras:
self.cameras[name].connect()
self.is_connected = self.is_connected and self.cameras[name].is_connected
print(f"camera {name} conneted")
print("All connected")
self.is_connected = True
self.run_calibration()
def disconnect(self) -> None:
"""move to home position, disenable piper and cameras"""
# move piper to home position, disable
if not self.inference_time:
self.teleop.stop()
# disconnect piper
self.arm.safe_disconnect()
print("RealmanArm disable after 5 seconds")
time.sleep(5)
self.arm.connect(enable=False)
# disconnect cameras
if len(self.cameras) > 0:
for cam in self.cameras.values():
cam.disconnect()
self.is_connected = False
def run_calibration(self):
"""move piper to the home position"""
if not self.is_connected:
raise ConnectionError()
self.arm.apply_calibration()
if not self.inference_time:
self.teleop.reset()
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise ConnectionError()
if self.teleop is None and self.inference_time:
self.teleop = HybridController(self.init_info)
# read target pose state as
before_read_t = time.perf_counter()
state = self.arm.read() # read current joint position from robot
action = self.teleop.get_action() # target joint position and pose end pos from gamepad
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
if action['control_mode'] == 'joint':
# 关节控制模式(主模式)
current_pose = self.arm.read_current_arm_endpose_state()
self.teleop.update_endpose_state(current_pose)
target_joints = action['joint_angles'][:-1]
self.arm.write_gripper(action['gripper'])
print(action['gripper'])
if action['master_controller_status']['infrared'] == 1:
if action['master_controller_status']['button'] == 1:
self.arm.write_joint_canfd(target_joints)
else:
self.arm.write_joint_slow(target_joints)
# do action
before_write_t = time.perf_counter()
self.joint_queue.append(list(self.arm.read().values()))
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
else:
target_pose = list(action['end_pose'])
# do action
before_write_t = time.perf_counter()
if self.last_endpose != target_pose:
self.arm.write_endpose_canfd(target_pose)
self.last_endpose = target_pose
self.arm.write_gripper(action['gripper'])
target_joints = self.arm.read_current_arm_joint_state()
self.joint_queue.append(list(self.arm.read().values()))
self.teleop.update_joint_state(target_joints)
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
if not record_data:
return
state = torch.as_tensor(list(self.joint_queue[0]), dtype=torch.float32)
action = torch.as_tensor(list(self.joint_queue[-1]), dtype=torch.float32)
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionnaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict, action_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor:
"""Write the predicted actions from policy to the motors"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"Piper is not connected. You need to run `robot.connect()`."
)
# send to motors, torch to list
target_joints = action.tolist()
len_joint = len(target_joints) - 1
target_joints = [target_joints[i]*180 for i in range(len_joint)] + [target_joints[-1]]
target_joints[-1] = int(target_joints[-1]*500 + 500)
self.arm.write(target_joints)
return action
def capture_observation(self) -> dict:
"""capture current images and joint positions"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"Piper is not connected. You need to run `robot.connect()`."
)
# read current joint positions
before_read_t = time.perf_counter()
state = self.arm.read() # 6 joints + 1 gripper
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
state = torch.as_tensor(list(state.values()), dtype=torch.float32)
# read images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionnaries and format to pytorch
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict
def teleop_safety_stop(self):
""" move to home position after record one episode """
self.run_calibration()
def __del__(self):
if self.is_connected:
self.disconnect()
if not self.inference_time:
self.teleop.stop()
if __name__ == '__main__':
robot = RealmanRobot()
robot.connect()
# robot.run_calibration()
while True:
robot.teleop_step(record_data=True)
robot.capture_observation()
dummy_action = torch.Tensor([-0.40586111280653214, 0.5522833506266276, 0.4998166826036241, -0.3539944542778863, -0.524433347913954, 0.9064999898274739, 0.482])
robot.send_action(dummy_action)
robot.disconnect()
print('ok')

View File

@@ -25,6 +25,7 @@ from lerobot.common.robot_devices.robots.configs import (
So100RobotConfig,
So101RobotConfig,
StretchRobotConfig,
RealmanRobotConfig
)
@@ -65,6 +66,9 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
return StretchRobotConfig(**kwargs)
elif robot_type == "lekiwi":
return LeKiwiRobotConfig(**kwargs)
elif robot_type == 'realman':
return RealmanRobotConfig(**kwargs)
else:
raise ValueError(f"Robot type '{robot_type}' is not available.")
@@ -78,6 +82,12 @@ def make_robot_from_config(config: RobotConfig):
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
return MobileManipulator(config)
elif isinstance(config, RealmanRobotConfig):
from lerobot.common.robot_devices.robots.realman import RealmanRobot
return RealmanRobot(config)
else:
from lerobot.common.robot_devices.robots.stretch import StretchRobot

View File

@@ -0,0 +1,466 @@
import pygame
import threading
import time
import serial
import binascii
import logging
import yaml
from typing import Dict
from Robotic_Arm.rm_robot_interface import *
class ServoArm:
def __init__(self, config_file="config.yaml"):
"""初始化机械臂的串口连接并发送初始数据。
Args:
config_file (str): 配置文件的路径。
"""
self.config = self._load_config(config_file)
self.port = self.config["port"]
self.baudrate = self.config["baudrate"]
self.joint_hex_data = self.config["joint_hex_data"]
self.control_hex_data = self.config["control_hex_data"]
self.arm_axis = self.config.get("arm_axis", 7)
try:
self.serial_conn = serial.Serial(self.port, self.baudrate, timeout=0)
self.bytes_to_send = binascii.unhexlify(self.joint_hex_data.replace(" ", ""))
self.serial_conn.write(self.bytes_to_send)
time.sleep(1)
self.connected = True
logging.info(f"串口连接成功: {self.port}")
except Exception as e:
logging.error(f"串口连接失败: {e}")
self.connected = False
def _load_config(self, config_file):
"""加载配置文件。
Args:
config_file (str): 配置文件的路径。
Returns:
dict: 配置文件内容。
"""
try:
with open(config_file, "r") as file:
config = yaml.safe_load(file)
return config
except Exception as e:
logging.error(f"配置文件加载失败: {e}")
# 返回默认配置
return {
"port": "/dev/ttyUSB0",
"baudrate": 460800,
"joint_hex_data": "55 AA 02 00 00 67",
"control_hex_data": "55 AA 08 00 00 B9",
"arm_axis": 6
}
def _bytes_to_signed_int(self, byte_data):
"""将字节数据转换为有符号整数。
Args:
byte_data (bytes): 字节数据。
Returns:
int: 有符号整数。
"""
return int.from_bytes(byte_data, byteorder="little", signed=True)
def _parse_joint_data(self, hex_received):
"""解析接收到的十六进制数据并提取关节数据。
Args:
hex_received (str): 接收到的十六进制字符串数据。
Returns:
dict: 解析后的关节数据。
"""
logging.debug(f"hex_received: {hex_received}")
joints = {}
for i in range(self.arm_axis):
start = 14 + i * 10
end = start + 8
joint_hex = hex_received[start:end]
joint_byte_data = bytearray.fromhex(joint_hex)
joint_value = self._bytes_to_signed_int(joint_byte_data) / 10000.0
joints[f"joint_{i+1}"] = joint_value
grasp_start = 14 + self.arm_axis*10
grasp_hex = hex_received[grasp_start:grasp_start+8]
grasp_byte_data = bytearray.fromhex(grasp_hex)
# 夹爪进行归一化处理
grasp_value = self._bytes_to_signed_int(grasp_byte_data)/1000
joints["grasp"] = grasp_value
return joints
def _parse_controller_data(self, hex_received):
status = {
'infrared': 0,
'button': 0
}
if len(hex_received) == 18:
status['infrared'] = self._bytes_to_signed_int(bytearray.fromhex(hex_received[12:14]))
status['button'] = self._bytes_to_signed_int(bytearray.fromhex(hex_received[14:16]))
# print(infrared)
return status
def get_joint_actions(self):
"""从串口读取数据并解析关节动作。
Returns:
dict: 包含关节数据的字典。
"""
if not self.connected:
return {}
try:
self.serial_conn.write(self.bytes_to_send)
time.sleep(0.02)
bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
if len(bytes_received) == 0:
return {}
hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
actions = self._parse_joint_data(hex_received)
return actions
except Exception as e:
logging.error(f"读取串口数据错误: {e}")
return {}
def get_controller_status(self):
bytes_to_send = binascii.unhexlify(self.control_hex_data.replace(" ", ""))
self.serial_conn.write(bytes_to_send)
time.sleep(0.02)
bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
# print("control status:", hex_received)
status = self._parse_controller_data(hex_received)
return status
def close(self):
"""关闭串口连接"""
if self.connected and hasattr(self, 'serial_conn'):
self.serial_conn.close()
self.connected = False
logging.info("串口连接已关闭")
class HybridController:
def __init__(self, init_info):
# 初始化pygame和手柄
pygame.init()
pygame.joystick.init()
# 检查是否有连接的手柄
if pygame.joystick.get_count() == 0:
raise Exception("未检测到手柄")
# 初始化手柄
self.joystick = pygame.joystick.Joystick(0)
self.joystick.init()
# 摇杆死区
self.deadzone = 0.15
# 控制模式: True为关节控制主模式False为末端控制
self.joint_control_mode = True
# 精细控制模式
self.fine_control_mode = False
# 初始化末端姿态和关节角度
self.init_joint = init_info['init_joint']
self.init_pose = init_info.get('init_pose', [0]*6)
self.max_gripper = init_info['max_gripper']
self.min_gripper = init_info['min_gripper']
servo_config_file = init_info['servo_config_file']
self.joint = self.init_joint.copy()
self.pose = self.init_pose.copy()
self.pose_speeds = [0.0] * 6
self.joint_speeds = [0.0] * 6
self.tozero = False
# 主臂关节状态
self.master_joint_actions = {}
self.master_controller_status = {}
self.use_master_arm = False
# 末端位姿限制
self.pose_limits = [
(-0.800, 0.800), # X (m)
(-0.800, 0.800), # Y (m)
(-0.800, 0.800), # Z (m)
(-3.14, 3.14), # RX (rad)
(-3.14, 3.14), # RY (rad)
(-3.14, 3.14) # RZ (rad)
]
# 关节角度限制 (度)
self.joint_limits = [
(-180, 180), # joint 1
(-180, 180), # joint 2
(-180, 180), # joint 3
(-180, 180), # joint 4
(-180, 180), # joint 5
(-180, 180) # joint 6
]
# 控制参数
self.linear_step = 0.002 # 线性移动步长(m)
self.angular_step = 0.01 # 角度步长(rad)
# 夹爪状态和速度
self.gripper_speed = 10
self.gripper = self.min_gripper
# 初始化串口通信(主臂关节状态获取)
self.servo_arm = None
if servo_config_file:
try:
self.servo_arm = ServoArm(servo_config_file)
self.use_master_arm = True
logging.info("串口主臂连接成功,启用主从控制模式")
except Exception as e:
logging.error(f"串口主臂连接失败: {e}")
self.use_master_arm = False
# 启动更新线程
self.running = True
self.thread = threading.Thread(target=self.update_controller)
self.thread.start()
print("混合控制器已启动")
print("主控制模式: 关节控制")
if self.use_master_arm:
print("主从控制: 启用")
print("Back按钮: 切换控制模式(关节/末端)")
print("L3按钮: 切换精细控制模式")
print("Start按钮: 重置到初始位置")
def _apply_nonlinear_mapping(self, value):
"""应用非线性映射以提高控制精度"""
sign = 1 if value >= 0 else -1
return sign * (abs(value) ** 2)
def _normalize_angle(self, angle):
"""将角度归一化到[-π, π]范围内"""
import math
while angle > math.pi:
angle -= 2 * math.pi
while angle < -math.pi:
angle += 2 * math.pi
return angle
def update_controller(self):
while self.running:
try:
pygame.event.pump()
except Exception as e:
print(f"控制器错误: {e}")
self.stop()
continue
# 检查控制模式切换 (Back按钮)
if self.joystick.get_button(6): # Back按钮
self.joint_control_mode = not self.joint_control_mode
mode_str = "关节控制" if self.joint_control_mode else "末端位姿控制"
print(f"切换到{mode_str}模式")
time.sleep(0.3) # 防止多次触发
# 检查精细控制模式切换 (L3按钮)
if self.joystick.get_button(10): # L3按钮
self.fine_control_mode = not self.fine_control_mode
print(f"切换到{'精细' if self.fine_control_mode else '普通'}控制模式")
time.sleep(0.3) # 防止多次触发
# 检查重置按钮 (Start按钮)
if self.joystick.get_button(7): # Start按钮
print("重置机械臂到初始位置...")
# print("init_joint", self.init_joint.copy())
self.tozero = True
self.joint = self.init_joint.copy()
self.pose = self.init_pose.copy()
self.pose_speeds = [0.0] * 6
self.joint_speeds = [0.0] * 6
self.gripper_speed = 10
self.gripper = self.min_gripper
print("机械臂已重置到初始位置")
time.sleep(0.3) # 防止多次触发
# 从串口获取主臂关节状态
if self.servo_arm and self.servo_arm.connected:
try:
self.master_joint_actions = self.servo_arm.get_joint_actions()
self.master_controller_status = self.servo_arm.get_controller_status()
if self.master_joint_actions:
logging.debug(f"主臂关节状态: {self.master_joint_actions}")
except Exception as e:
logging.error(f"获取主臂状态错误: {e}")
self.master_joint_actions = {}
# print(self.master_joint_actions)
# 根据控制模式更新相应的控制逻辑
if self.joint_control_mode:
# 关节控制模式下优先使用主臂数据Xbox作为辅助
self.update_joint_control()
else:
# 末端控制模式使用Xbox控制
self.update_end_pose()
time.sleep(0.02)
# print('gripper:', self.gripper)
def update_joint_control(self):
"""更新关节角度控制 - 优先使用主臂数据"""
if self.use_master_arm and self.master_joint_actions:
# 主从控制模式:直接使用主臂的关节角度
try:
# 将主臂关节角度映射到从臂
for i in range(6): # 假设只有6个关节需要控制
joint_key = f"joint_{i+1}"
if joint_key in self.master_joint_actions:
# 直接使用主臂的关节角度(已经是度数)
self.joint[i] = self.master_joint_actions[joint_key]
# 应用关节限制
min_val, max_val = self.joint_limits[i]
self.joint[i] = max(min_val, min(max_val, self.joint[i]))
# print(self.joint)
logging.debug(f"主臂关节映射到从臂: {self.joint[:6]}")
except Exception as e:
logging.error(f"主臂数据映射错误: {e}")
# 如果有主臂夹爪数据,使用主臂夹爪状态
if self.use_master_arm and "grasp" in self.master_joint_actions:
self.gripper = self.master_joint_actions["grasp"] * 1000
self.joint[-1] = self.gripper
def update_end_pose(self):
"""更新末端位姿控制"""
# 根据控制模式调整步长
current_linear_step = self.linear_step * (0.1 if self.fine_control_mode else 1.0)
current_angular_step = self.angular_step * (0.1 if self.fine_control_mode else 1.0)
# 方向键控制XY
hat = self.joystick.get_hat(0)
hat_up = hat[1] == 1 # Y+
hat_down = hat[1] == -1 # Y-
hat_left = hat[0] == -1 # X-
hat_right = hat[0] == 1 # X+
# 右摇杆控制Z
right_y_raw = -self.joystick.get_axis(4)
# 左摇杆控制RZ
left_y_raw = -self.joystick.get_axis(1)
# 应用死区
right_y = 0.0 if abs(right_y_raw) < self.deadzone else right_y_raw
left_y = 0.0 if abs(left_y_raw) < self.deadzone else left_y_raw
# 计算各轴速度
self.pose_speeds[1] = current_linear_step if hat_up else (-current_linear_step if hat_down else 0.0) # Y
self.pose_speeds[0] = -current_linear_step if hat_left else (current_linear_step if hat_right else 0.0) # X
# 设置Z速度右摇杆Y轴控制
z_mapping = self._apply_nonlinear_mapping(right_y)
self.pose_speeds[2] = z_mapping * current_linear_step # Z
# L1/R1控制RX旋转
LB = self.joystick.get_button(4) # RX-
RB = self.joystick.get_button(5) # RX+
self.pose_speeds[3] = (-current_angular_step if LB else (current_angular_step if RB else 0.0))
# △/□控制RY旋转
triangle = self.joystick.get_button(2) # RY+
square = self.joystick.get_button(3) # RY-
self.pose_speeds[4] = (current_angular_step if triangle else (-current_angular_step if square else 0.0))
# 左摇杆Y轴控制RZ旋转
rz_mapping = self._apply_nonlinear_mapping(left_y)
self.pose_speeds[5] = rz_mapping * current_angular_step * 2 # RZ
# 夹爪控制(圈/叉)
circle = self.joystick.get_button(1) # 夹爪开
cross = self.joystick.get_button(0) # 夹爪关
if circle:
self.gripper = min(self.max_gripper, self.gripper + self.gripper_speed)
elif cross:
self.gripper = max(self.min_gripper, self.gripper - self.gripper_speed)
# 更新末端位姿
for i in range(6):
self.pose[i] += self.pose_speeds[i]
# 角度归一化处理
for i in range(3, 6):
self.pose[i] = self._normalize_angle(self.pose[i])
def update_joint_state(self, joint):
self.joint = joint
# self.tozero = False
def update_endpose_state(self, end_pose):
self.pose = end_pose
# self.tozero = False
def update_tozero_state(self, tozero):
self.tozero = tozero
def get_action(self) -> Dict:
"""获取当前控制命令"""
return {
'control_mode': 'joint' if self.joint_control_mode else 'end_pose',
'use_master_arm': self.use_master_arm,
'master_joint_actions': self.master_joint_actions,
'master_controller_status': self.master_controller_status,
'end_pose': self.pose,
'joint_angles': self.joint,
'gripper': int(self.gripper),
'tozero': self.tozero
}
def stop(self):
"""停止控制器"""
self.running = False
if self.thread.is_alive():
self.thread.join()
if self.servo_arm:
self.servo_arm.close()
pygame.quit()
print("混合控制器已退出")
def reset(self):
"""重置到初始状态"""
self.joint = self.init_joint.copy()
self.pose = self.init_pose.copy()
self.pose_speeds = [0.0] * 6
self.joint_speeds = [0.0] * 6
self.gripper_speed = 10
self.gripper = self.min_gripper
print("已重置到初始状态")
# 使用示例
if __name__ == "__main__":
# 初始化睿尔曼机械臂
arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
# 创建机械臂连接
handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
print(f"机械臂连接ID: {handle.id}")
init_pose = arm.rm_get_current_arm_state()[1]['pose']
with open('/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/realman_mix.yaml', "r") as file:
config = yaml.safe_load(file)
config['init_pose'] = init_pose
arm_controller = HybridController(config)
try:
while True:
print(arm_controller.get_action())
time.sleep(0.1)
except KeyboardInterrupt:
arm_controller.stop()

View File

@@ -0,0 +1,4 @@
init_joint: [-90, 90, 90, -90, -90, 90]
max_gripper: 990
min_gripper: 10
servo_config_file: "/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_arm.yaml"

View File

@@ -0,0 +1,6 @@
port: /dev/ttyUSB0
right_port: /dev/ttyUSB1
baudrate: 460800
joint_hex_data: "55 AA 02 00 00 67"
control_hex_data: "55 AA 08 00 00 B9"
arm_axis: 6

View File

@@ -18,11 +18,9 @@ import os
import os.path as osp
import platform
import subprocess
import time
from copy import copy, deepcopy
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean
import numpy as np
import torch
@@ -109,17 +107,11 @@ def is_amp_available(device: str):
raise ValueError(f"Unknown device '{device}.")
def init_logging(log_file: Path | None = None, display_pid: bool = False):
def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
# NOTE: Display PID is useful for multi-process logging.
if display_pid:
pid_str = f"[PID: {os.getpid()}]"
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
else:
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
return message
logging.basicConfig(level=logging.INFO)
@@ -133,12 +125,6 @@ def init_logging(log_file: Path | None = None, display_pid: bool = False):
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)
if log_file is not None:
# Additionally write logs to file
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
logging.getLogger().addHandler(file_handler)
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]
@@ -189,7 +175,8 @@ def say(text, blocking=False):
cmd = ["say", text]
elif system == "Linux":
cmd = ["spd-say", text]
# cmd = ["spd-say", text]
cmd = ["edge-playback", "-t", text]
if blocking:
cmd.append("--wait")
@@ -242,114 +229,3 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
except TypeError:
# If a TypeError is raised, the string is not a valid dtype
return False
class TimerManager:
"""
Lightweight utility to measure elapsed time.
Examples
--------
```python
# Example 1: Using context manager
timer = TimerManager("Policy", log=False)
for _ in range(3):
with timer:
time.sleep(0.01)
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
```
```python
# Example 2: Using start/stop methods
timer = TimerManager("Policy", log=False)
timer.start()
time.sleep(0.01)
timer.stop()
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
```
"""
def __init__(
self,
label: str = "Elapsed-time",
log: bool = True,
logger: logging.Logger | None = None,
):
self.label = label
self.log = log
self.logger = logger
self._start: float | None = None
self._history: list[float] = []
def __enter__(self):
return self.start()
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def start(self):
self._start = time.perf_counter()
return self
def stop(self) -> float:
if self._start is None:
raise RuntimeError("Timer was never started.")
elapsed = time.perf_counter() - self._start
self._history.append(elapsed)
self._start = None
if self.log:
if self.logger is not None:
self.logger.info(f"{self.label}: {elapsed:.6f} s")
else:
logging.info(f"{self.label}: {elapsed:.6f} s")
return elapsed
def reset(self):
self._history.clear()
@property
def last(self) -> float:
return self._history[-1] if self._history else 0.0
@property
def avg(self) -> float:
return mean(self._history) if self._history else 0.0
@property
def total(self) -> float:
return sum(self._history)
@property
def count(self) -> int:
return len(self._history)
@property
def history(self) -> list[float]:
return deepcopy(self._history)
@property
def fps_history(self) -> list[float]:
return [1.0 / t for t in self._history]
@property
def fps_last(self) -> float:
return 0.0 if self.last == 0 else 1.0 / self.last
@property
def fps_avg(self) -> float:
return 0.0 if self.avg == 0 else 1.0 / self.avg
def percentile(self, p: float) -> float:
"""
Return the p-th percentile of recorded times.
"""
if not self._history:
return 0.0
return float(np.percentile(self._history, p))
def fps_percentile(self, p: float) -> float:
"""
FPS corresponding to the p-th percentile time.
"""
val = self.percentile(p)
return 0.0 if val == 0 else 1.0 / val

View File

@@ -30,10 +30,9 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
f"dataset:{cfg.dataset.repo_id}",
f"seed:{cfg.seed}",
]
if cfg.dataset is not None:
lst.append(f"dataset:{cfg.dataset.repo_id}")
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
return lst if return_list else "-".join(lst)
@@ -93,12 +92,6 @@ class WandBLogger:
resume="must" if cfg.resume else None,
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
)
run_id = wandb.run.id
# NOTE: We will override the cfg.wandb.run_id with the wandb run id.
# This is because we want to be able to resume the run from the wandb run id.
cfg.wandb.run_id = run_id
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@@ -115,26 +108,9 @@ class WandBLogger:
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
def log_dict(
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
):
def log_dict(self, d: dict, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
raise ValueError(mode)
if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.")
# NOTE: This is not simple. Wandb step must always monotonically increase and it
# increases with each wandb.log call, but in the case of asynchronous RL for example,
# multiple time steps is possible. For example, the interaction step with the environment,
# the training step, the evaluation step, etc. So we need to define a custom step key
# to log the correct step for each metric.
if custom_step_key is not None:
if self._wandb_custom_step_key is None:
self._wandb_custom_step_key = set()
new_custom_key = f"{mode}/{custom_step_key}"
if new_custom_key not in self._wandb_custom_step_key:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items():
if not isinstance(v, (int, float, str)):
@@ -142,18 +118,7 @@ class WandBLogger:
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
# Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
self._wandb.log(data)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:

View File

@@ -34,10 +34,11 @@ TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
dataset: DatasetConfig
env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
output_dir: Path | None = None
job_name: str | None = None
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
@@ -106,7 +107,7 @@ class TrainPipelineConfig(HubMixin):
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir
if self.dataset is not None and isinstance(self.dataset.repo_id, list):
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):

View File

@@ -23,7 +23,6 @@ class FeatureType(str, Enum):
VISUAL = "VISUAL"
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"
class NormalizationMode(str, Enum):

View File

@@ -273,7 +273,6 @@ def record(
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
if not robot.is_connected:
robot.connect()
@@ -290,6 +289,9 @@ def record(
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
# import pdb
# pdb.set_trace()
recorded_episodes = 0
while True:
if recorded_episodes >= cfg.num_episodes:

View File

@@ -1,722 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Actor server runner for distributed HILSerl robot policy training.
This script implements the actor component of the distributed HILSerl architecture.
It executes the policy in the robot environment, collects experience,
and sends transitions to the learner server for policy updates.
Examples of usage:
- Start an actor server for real robot training with human-in-the-loop intervention:
```bash
python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
```
- Run with a specific robot type for a pick and place task:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--robot.type=so100 \
--task=pick_and_place
```
- Set a custom workspace bound for the robot's end-effector:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
```
- Run with specific camera crop parameters:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
```
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
server is started before launching the actor.
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
reduce interventions as the policy improves.
**WORKFLOW**:
1. Determine robot workspace bounds using `find_joint_limits.py`
2. Record demonstrations with `gym_manipulator.py` in record mode
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
4. Start the learner server with the training configuration
5. Start this actor server with the same configuration
6. Use human interventions to guide policy learning
For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
import logging
import os
import time
from functools import lru_cache
from queue import Empty
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
TimerManager,
get_safe_torch_device,
init_logging,
)
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import Transition
from lerobot.scripts.server.gym_manipulator import make_robot_env
from lerobot.scripts.server.network_utils import (
bytes_to_state_dict,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.scripts.server.utils import (
get_last_item_from_queue,
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
)
ACTOR_SHUTDOWN_TIMEOUT = 30
#################################################
# Main entry point #
#################################################
@parser.wrap()
def actor_cli(cfg: TrainPipelineConfig):
cfg.validate()
display_pid = False
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
display_pid = True
# Create logs directory to ensure it exists
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=display_pid)
logging.info(f"Actor logging initialized, writing to {log_file}")
shutdown_event = setup_process_handlers(use_threads(cfg))
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
logging.info("[ACTOR] Establishing connection with Learner")
if not establish_learner_connection(learner_client, shutdown_event):
logging.error("[ACTOR] Failed to establish connection with Learner")
return
if not use_threads(cfg):
# If we use multithreading, we can reuse the channel
grpc_channel.close()
grpc_channel = None
logging.info("[ACTOR] Connection with Learner established")
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
concurrency_entity = None
if use_threads(cfg):
from threading import Thread
concurrency_entity = Thread
else:
from multiprocessing import Process
concurrency_entity = Process
receive_policy_process = concurrency_entity(
target=receive_policy,
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process = concurrency_entity(
target=send_transitions,
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
daemon=True,
)
interactions_process = concurrency_entity(
target=send_interactions,
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process.start()
interactions_process.start()
receive_policy_process.start()
act_with_policy(
cfg=cfg,
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
transitions_queue=transitions_queue,
interactions_queue=interactions_queue,
)
logging.info("[ACTOR] Policy process joined")
logging.info("[ACTOR] Closing queues")
transitions_queue.close()
interactions_queue.close()
parameters_queue.close()
transitions_process.join()
logging.info("[ACTOR] Transitions process joined")
interactions_process.join()
logging.info("[ACTOR] Interactions process joined")
receive_policy_process.join()
logging.info("[ACTOR] Receive policy process joined")
logging.info("[ACTOR] join queues")
transitions_queue.cancel_join_thread()
interactions_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[ACTOR] queues closed")
#################################################
# Core algorithm functions #
#################################################
def act_with_policy(
cfg: TrainPipelineConfig,
shutdown_event: any, # Event,
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
):
"""
Executes policy interaction within the environment.
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
Args:
cfg: Configuration settings for the interaction process.
shutdown_event: Event to check if the process should shutdown.
parameters_queue: Queue to receive updated network parameters from the learner.
transitions_queue: Queue to send transitions to the learner.
interactions_queue: Queue to send interactions to the learner.
"""
# Initialize logging for multiprocessing
if not use_threads(cfg):
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor policy process logging initialized")
logging.info("make_env online")
online_env = make_robot_env(cfg=cfg.env)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
policy: SACPolicy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
list_transition_to_send_to_learner = []
episode_intervention = False
# Add counters for intervention rate calculation
episode_intervention_steps = 0
episode_total_steps = 0
policy_timer = TimerManager("Policy inference", log=False)
for interaction_step in range(cfg.policy.online_steps):
start_time = time.perf_counter()
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down act_with_policy")
return
if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with policy_timer:
action = policy.select_action(batch=obs)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate
episode_total_steps += 1
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
if "is_intervention" in info and info["is_intervention"]:
# NOTE: The action space for demonstration before hand is with the full action space
# but sometimes for example we want to deactivate the gripper
action = info["action_intervention"]
episode_intervention = True
# Increment intervention steps counter
episode_intervention_steps += 1
list_transition_to_send_to_learner.append(
Transition(
state=obs,
action=action,
reward=reward,
next_state=next_obs,
done=done,
truncated=truncated, # TODO: (azouitine) Handle truncation properly
complementary_info=info,
)
)
# assign obs to the next obs and continue the rollout
obs = next_obs
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
transitions=list_transition_to_send_to_learner,
transitions_queue=transitions_queue,
)
list_transition_to_send_to_learner = []
stats = get_frequency_stats(policy_timer)
policy_timer.reset()
# Calculate intervention rate
intervention_rate = 0.0
if episode_total_steps > 0:
intervention_rate = episode_intervention_steps / episode_total_steps
# Send episodic reward to the learner
interactions_queue.put(
python_object_to_bytes(
{
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
"Episode intervention": int(episode_intervention),
"Intervention rate": intervention_rate,
**stats,
}
)
)
# Reset intervention counters
sum_reward_episode = 0.0
episode_intervention = False
episode_intervention_steps = 0
episode_total_steps = 0
obs, info = online_env.reset()
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
busy_wait(1 / cfg.env.fps - dt_time)
#################################################
# Communication Functions - Group all gRPC/messaging functions #
#################################################
def establish_learner_connection(
stub: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
attempts: int = 30,
):
"""Establish a connection with the learner.
Args:
stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
shutdown_event (Event): The event to check if the connection should be established.
attempts (int): The number of attempts to establish the connection.
Returns:
bool: True if the connection is established, False otherwise.
"""
for _ in range(attempts):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down establish_learner_connection")
return False
# Force a connection attempt and check state
try:
logging.info("[ACTOR] Send ready message to Learner")
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
return True
except grpc.RpcError as e:
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
time.sleep(2)
return False
@lru_cache(maxsize=1)
def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
import json
"""
Returns a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it.
"""
service_config = {
"methodConfig": [
{
"name": [{}], # Applies to ALL methods in ALL services
"retryPolicy": {
"maxAttempts": 5, # Max retries (total attempts = 5)
"initialBackoff": "0.1s", # First retry after 0.1s
"maxBackoff": "2s", # Max wait time between retries
"backoffMultiplier": 2, # Exponential backoff factor
"retryableStatusCodes": [
"UNAVAILABLE",
"DEADLINE_EXCEEDED",
], # Retries on network failures
},
}
]
}
service_config_json = json.dumps(service_config)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.enable_retries", 1),
("grpc.service_config", service_config_json),
],
)
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
logging.info("[ACTOR] Learner service client created")
return stub, channel
def receive_policy(
cfg: TrainPipelineConfig,
parameters_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
"""Receive parameters from the learner.
Args:
cfg (TrainPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shutdown.
"""
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor receive policy process logging initialized")
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(use_threads=False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
receive_bytes_in_chunks(
iterator,
parameters_queue,
shutdown_event,
log_prefix="[ACTOR] parameters",
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Received policy loop stopped")
def send_transitions(
cfg: TrainPipelineConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
"""
Sends transitions to the learner.
This function continuously retrieves messages from the queue and processes:
- Transition Data:
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
"""
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor transitions process logging initialized")
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming transitions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Transitions process stopped")
def send_interactions(
cfg: TrainPipelineConfig,
interactions_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
"""
Sends interactions to the learner.
This function continuously retrieves messages from the queue and processes:
- Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
"""
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor interactions process logging initialized")
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming interactions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=5)
except Empty:
logging.debug("[ACTOR] Transition queue is empty")
continue
yield from send_bytes_in_chunks(
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
)
return hilserl_pb2.Empty()
def interactions_stream(
shutdown_event: Event, # type: ignore
interactions_queue: Queue,
) -> hilserl_pb2.Empty:
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=5)
except Empty:
logging.debug("[ACTOR] Interaction queue is empty")
continue
yield from send_bytes_in_chunks(
message,
hilserl_pb2.InteractionMessage,
log_prefix="[ACTOR] Send interactions",
)
return hilserl_pb2.Empty()
#################################################
# Policy functions #
#################################################
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
if not parameters_queue.empty():
logging.info("[ACTOR] Load new parameters from Learner.")
bytes_state_dict = get_last_item_from_queue(parameters_queue)
state_dict = bytes_to_state_dict(bytes_state_dict)
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.load_state_dict(state_dict)
#################################################
# Utilities functions #
#################################################
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
"""Send transitions to learner in smaller chunks to avoid network issues.
Args:
transitions: List of transitions to send
message_queue: Queue to send messages to learner
chunk_size: Size of each chunk to send
"""
transition_to_send_to_learner = []
for transition in transitions:
tr = move_transition_to_device(transition=transition, device="cpu")
for key, value in tr["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
transition_to_send_to_learner.append(tr)
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
"""Get the frequency statistics of the policy.
Args:
timer (TimerManager): The timer with collected metrics.
Returns:
dict[str, float]: The frequency statistics of the policy.
"""
stats = {}
if timer.count > 1:
avg_fps = timer.fps_avg
p90_fps = timer.fps_percentile(90)
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
stats = {
"Policy frequency [Hz]": avg_fps,
"Policy frequency 90th-p [Hz]": p90_fps,
}
return stats
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
if policy_fps < cfg.env.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
)
def use_threads(cfg: TrainPipelineConfig) -> bool:
return cfg.policy.concurrency.actor == "threads"
if __name__ == "__main__":
actor_cli()

View File

@@ -1,820 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from contextlib import suppress
from typing import Callable, Optional, Sequence, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.utils import Transition
class BatchTransition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: torch.Tensor
next_state: dict[str, torch.Tensor]
done: torch.Tensor
truncated: torch.Tensor
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
"""
Perform a per-image random crop over a batch of images in a vectorized way.
(Same as shown previously.)
"""
B, C, H, W = images.shape # noqa: N806
crop_h, crop_w = output_size
if crop_h > H or crop_w > W:
raise ValueError(
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
)
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
return cropped
def random_shift(images: torch.Tensor, pad: int = 4):
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
_, _, h, w = images.shape
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
return random_crop_vectorized(images=images, output_size=(h, w))
class ReplayBuffer:
def __init__(
self,
capacity: int,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
optimize_memory: bool = False,
):
"""
Replay buffer for storing transitions.
It will allocate tensors on the specified device, when the first transition is added.
NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or
and use the `storage_device` flag to store the buffer on a different device.
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
and returns a batch of augmented images. If None, a default augmentation function is used.
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
Using "cpu" can help save GPU memory.
optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
"""
if capacity <= 0:
raise ValueError("Capacity must be greater than 0.")
self.capacity = capacity
self.device = device
self.storage_device = storage_device
self.position = 0
self.size = 0
self.initialized = False
self.optimize_memory = optimize_memory
# Track episode boundaries for memory optimization
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
# If no state_keys provided, default to an empty list
self.state_keys = state_keys if state_keys is not None else []
self.image_augmentation_function = image_augmentation_function
if image_augmentation_function is None:
base_function = functools.partial(random_shift, pad=4)
self.image_augmentation_function = torch.compile(base_function)
self.use_drq = use_drq
def _initialize_storage(
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Initialize the storage tensors based on the first transition."""
# Determine shapes from the first transition
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
action_shape = action.squeeze(0).shape
# Pre-allocate tensors for storage
self.states = {
key: torch.empty((self.capacity, *shape), device=self.storage_device)
for key, shape in state_shapes.items()
}
self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
if not self.optimize_memory:
# Standard approach: store states and next_states separately
self.next_states = {
key: torch.empty((self.capacity, *shape), device=self.storage_device)
for key, shape in state_shapes.items()
}
else:
# Memory-optimized approach: don't allocate next_states buffer
# Just create a reference to states for consistent API
self.next_states = self.states # Just a reference for API consistency
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
# Initialize storage for complementary_info
self.has_complementary_info = complementary_info is not None
self.complementary_info_keys = []
self.complementary_info = {}
if self.has_complementary_info:
self.complementary_info_keys = list(complementary_info.keys())
# Pre-allocate tensors for each key in complementary_info
for key, value in complementary_info.items():
if isinstance(value, torch.Tensor):
value_shape = value.squeeze(0).shape
self.complementary_info[key] = torch.empty(
(self.capacity, *value_shape), device=self.storage_device
)
elif isinstance(value, (int, float)):
# Handle scalar values similar to reward
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else:
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
self.initialized = True
def __len__(self):
return self.size
def add(
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
reward: float,
next_state: dict[str, torch.Tensor],
done: bool,
truncated: bool,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
# Initialize storage if this is the first transition
if not self.initialized:
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
# Store the transition in pre-allocated tensors
for key in self.states:
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
if not self.optimize_memory:
# Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward
self.dones[self.position] = done
self.truncateds[self.position] = truncated
# Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info:
# Store the complementary_info
for key in self.complementary_info_keys:
if key in complementary_info:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int, float)):
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
if not self.initialized:
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
batch_size = min(batch_size, self.size)
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
# Random indices for sampling - create on the same device as storage
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
# Create batched state and next_state
batch_state = {}
batch_next_state = {}
# First pass: load all state tensors to target device
for key in self.states:
batch_state[key] = self.states[key][idx].to(self.device)
if not self.optimize_memory:
# Standard approach - load next_states directly
batch_next_state[key] = self.next_states[key][idx].to(self.device)
else:
# Memory-optimized approach - get next_state from the next index
next_idx = (idx + 1) % self.capacity
batch_next_state[key] = self.states[key][next_idx].to(self.device)
# Apply image augmentation in a batched way if needed
if self.use_drq and image_keys:
# Concatenate all images from state and next_state
all_images = []
for key in image_keys:
all_images.append(batch_state[key])
all_images.append(batch_next_state[key])
# Optimization: Batch all images and apply augmentation once
all_images_tensor = torch.cat(all_images, dim=0)
augmented_images = self.image_augmentation_function(all_images_tensor)
# Split the augmented images back to their sources
for i, key in enumerate(image_keys):
# Calculate offsets for the current image key:
# For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
# States start at index i*2*batch_size and take up batch_size slots
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
return BatchTransition(
state=batch_state,
action=batch_actions,
reward=batch_rewards,
next_state=batch_next_state,
done=batch_dones,
truncated=batch_truncateds,
complementary_info=batch_complementary_info,
)
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""
Creates an infinite iterator that yields batches of transitions.
Will automatically restart when internal iterator is exhausted.
Args:
batch_size (int): Size of batches to sample
async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
queue_size (int): Number of batches to prefetch (default: 2)
Yields:
BatchTransition: Batched transitions
"""
while True: # Create an infinite loop
if async_prefetch:
# Get the standard iterator
iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
else:
iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
# Yield all items from the iterator
with suppress(StopIteration):
yield from iterator
def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
"""
Creates an iterator that prefetches batches in a background thread.
Args:
queue_size (int): Number of batches to prefetch (default: 2)
batch_size (int): Size of batches to sample (default: 128)
Yields:
BatchTransition: Prefetched batch transitions
"""
import queue
import threading
# Use thread-safe queue
data_queue = queue.Queue(maxsize=queue_size)
running = [True] # Use list to allow modification in nested function
def prefetch_worker():
while running[0]:
try:
# Sample data and add to queue
data = self.sample(batch_size)
data_queue.put(data, block=True, timeout=0.5)
except queue.Full:
continue
except Exception as e:
print(f"Prefetch error: {e}")
break
# Start prefetching thread
thread = threading.Thread(target=prefetch_worker, daemon=True)
thread.start()
try:
while running[0]:
try:
yield data_queue.get(block=True, timeout=0.5)
except queue.Empty:
if not thread.is_alive():
break
finally:
# Clean up
running[0] = False
thread.join(timeout=1.0)
def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
"""
Creates a simple non-threaded iterator that yields batches.
Args:
batch_size (int): Size of batches to sample
queue_size (int): Number of initial batches to prefetch
Yields:
BatchTransition: Batch transitions
"""
import collections
queue = collections.deque()
def enqueue(n):
for _ in range(n):
data = self.sample(batch_size)
queue.append(data)
enqueue(queue_size)
while queue:
yield queue.popleft()
enqueue(1)
@classmethod
def from_lerobot_dataset(
cls,
lerobot_dataset: LeRobotDataset,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
optimize_memory: bool = False,
) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
Args:
lerobot_dataset (LeRobotDataset): The dataset to convert.
device (str): The device for sampling tensors. Defaults to "cuda:0".
state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
image_augmentation_function (Optional[Callable]): Function for image augmentation.
If None, uses default random shift with pad=4.
use_drq (bool): Whether to use DrQ image augmentation when sampling.
storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
optimize_memory (bool): If True, reduces memory usage by not duplicating state data.
Returns:
ReplayBuffer: The replay buffer with dataset transitions.
"""
if capacity is None:
capacity = len(lerobot_dataset)
if capacity < len(lerobot_dataset):
raise ValueError(
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
)
# Create replay buffer with image augmentation and DrQ settings
replay_buffer = cls(
capacity=capacity,
device=device,
state_keys=state_keys,
image_augmentation_function=image_augmentation_function,
use_drq=use_drq,
storage_device=storage_device,
optimize_memory=optimize_memory,
)
# Convert dataset to transitions
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
# Initialize the buffer with the first transition to set up storage tensors
if list_transition:
first_transition = list_transition[0]
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition["action"].to(device)
# Get complementary info if available
first_complementary_info = None
if (
"complementary_info" in first_transition
and first_transition["complementary_info"] is not None
):
first_complementary_info = {
k: v.to(device) for k, v in first_transition["complementary_info"].items()
}
replay_buffer._initialize_storage(
state=first_state, action=first_action, complementary_info=first_complementary_info
)
# Fill the buffer with all transitions
for data in list_transition:
for k, v in data.items():
if isinstance(v, dict):
for key, tensor in v.items():
v[key] = tensor.to(storage_device)
elif isinstance(v, torch.Tensor):
data[k] = v.to(storage_device)
action = data["action"]
replay_buffer.add(
state=data["state"],
action=action,
reward=data["reward"],
next_state=data["next_state"],
done=data["done"],
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
complementary_info=data.get("complementary_info", None),
)
return replay_buffer
def to_lerobot_dataset(
self,
repo_id: str,
fps=1,
root=None,
task_name="from_replay_buffer",
) -> LeRobotDataset:
"""
Converts all transitions in this ReplayBuffer into a single LeRobotDataset object.
"""
if self.size == 0:
raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
# Create features dictionary for the dataset
features = {
"index": {"dtype": "int64", "shape": [1]}, # global index across episodes
"episode_index": {"dtype": "int64", "shape": [1]}, # which episode
"frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode
"timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy
"task_index": {"dtype": "int64", "shape": [1]},
}
# Add "action"
sample_action = self.actions[0]
act_info = guess_feature_info(t=sample_action, name="action")
features["action"] = act_info
# Add "reward" and "done"
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
features["next.done"] = {"dtype": "bool", "shape": (1,)}
# Add state keys
for key in self.states:
sample_val = self.states[key][0]
f_info = guess_feature_info(t=sample_val, name=key)
features[key] = f_info
# Add complementary_info keys if available
if self.has_complementary_info:
for key in self.complementary_info_keys:
sample_val = self.complementary_info[key][0]
if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
sample_val = sample_val.unsqueeze(0)
f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
features[f"complementary_info.{key}"] = f_info
# Create an empty LeRobotDataset
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=fps,
root=root,
robot=None, # TODO: (azouitine) Handle robot
robot_type=None,
features=features,
use_videos=True,
)
# Start writing images if needed
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
# Convert transitions into episodes and frames
episode_index = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
frame_idx_in_episode = 0
for idx in range(self.size):
actual_idx = (self.position - self.size + idx) % self.capacity
frame_dict = {}
# Fill the data for state keys
for key in self.states:
frame_dict[key] = self.states[key][actual_idx].cpu()
# Fill action, reward, done
frame_dict["action"] = self.actions[actual_idx].cpu()
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
# Add complementary_info if available
if self.has_complementary_info:
for key in self.complementary_info_keys:
val = self.complementary_info[key][actual_idx]
# Convert tensors to CPU
if isinstance(val, torch.Tensor):
if val.ndim == 0:
val = val.unsqueeze(0)
frame_dict[f"complementary_info.{key}"] = val.cpu()
# Non-tensor values can be used directly
else:
frame_dict[f"complementary_info.{key}"] = val
# Add task field which is required by LeRobotDataset
frame_dict["task"] = task_name
# Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict)
# Move to next frame
frame_idx_in_episode += 1
# If we reached an episode boundary, call save_episode, reset counters
if self.dones[actual_idx] or self.truncateds[actual_idx]:
lerobot_dataset.save_episode()
episode_index += 1
frame_idx_in_episode = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index=episode_index
)
# Save any remaining frames in the buffer
if lerobot_dataset.episode_buffer["size"] > 0:
lerobot_dataset.save_episode()
lerobot_dataset.stop_image_writer()
return lerobot_dataset
@staticmethod
def _lerobotdataset_to_transitions(
dataset: LeRobotDataset,
state_keys: Optional[Sequence[str]] = None,
) -> list[Transition]:
"""
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
Args:
dataset (LeRobotDataset):
The dataset to convert. Each item in the dataset is expected to have
at least the following keys:
{
"action": ...
"next.reward": ...
"next.done": ...
"episode_index": ...
}
plus whatever your 'state_keys' specify.
state_keys (Optional[Sequence[str]]):
The dataset keys to include in 'state' and 'next_state'. Their names
will be kept as-is in the output transitions. E.g.
["observation.state", "observation.environment_state"].
If None, you must handle or define default keys.
Returns:
transitions (List[Transition]):
A list of Transition dictionaries with the same length as `dataset`.
"""
if state_keys is None:
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
transitions = []
num_frames = len(dataset)
# Check if the dataset has "next.done" key
sample = dataset[0]
has_done_key = "next.done" in sample
# Check for complementary_info keys
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
has_complementary_info = len(complementary_info_keys) > 0
# If not, we need to infer it from episode boundaries
if not has_done_key:
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
for i in tqdm(range(num_frames)):
current_sample = dataset[i]
# ----- 1) Current state -----
current_state: dict[str, torch.Tensor] = {}
for key in state_keys:
val = current_sample[key]
current_state[key] = val.unsqueeze(0) # Add batch dimension
# ----- 2) Action -----
action = current_sample["action"].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
# Determine done flag - use next.done if available, otherwise infer from episode boundaries
if has_done_key:
done = bool(current_sample["next.done"].item()) # ensure bool
else:
# If this is the last frame or if next frame is in a different episode, mark as done
done = False
if i == num_frames - 1:
done = True
elif i < num_frames - 1:
next_sample = dataset[i + 1]
if next_sample["episode_index"] != current_sample["episode_index"]:
done = True
# TODO: (azouitine) Handle truncation (using the same value as done for now)
truncated = done
# ----- 4) Next state -----
# If not done and the next sample is in the same episode, we pull the next sample's state.
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
next_state = current_state # default
if not done and (i < num_frames - 1):
next_sample = dataset[i + 1]
if next_sample["episode_index"] == current_sample["episode_index"]:
# Build next_state from the same keys
next_state_data: dict[str, torch.Tensor] = {}
for key in state_keys:
val = next_sample[key]
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
next_state = next_state_data
# ----- 5) Complementary info (if available) -----
complementary_info = None
if has_complementary_info:
complementary_info = {}
for key in complementary_info_keys:
# Strip the "complementary_info." prefix to get the actual key
clean_key = key[len("complementary_info.") :]
val = current_sample[key]
# Handle tensor and non-tensor values differently
if isinstance(val, torch.Tensor):
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
else:
# TODO: (azouitine) Check if it's necessary to convert to tensor
# For non-tensor values, use directly
complementary_info[clean_key] = val
# ----- Construct the Transition -----
transition = Transition(
state=current_state,
action=action,
reward=reward,
next_state=next_state,
done=done,
truncated=truncated,
complementary_info=complementary_info,
)
transitions.append(transition)
return transitions
# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t, name: str):
"""
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
Otherwise default to appropriate dtype for numeric.
"""
shape = tuple(t.shape)
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
if len(shape) == 3 and shape[0] in [1, 3]:
return {
"dtype": "image",
"shape": shape,
}
else:
# Otherwise treat as numeric
return {
"dtype": "float32",
"shape": shape,
}
def concatenate_batch_transitions(
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
# Concatenate state fields
left_batch_transitions["state"] = {
key: torch.cat(
[left_batch_transitions["state"][key], right_batch_transition["state"][key]],
dim=0,
)
for key in left_batch_transitions["state"]
}
# Concatenate basic fields
left_batch_transitions["action"] = torch.cat(
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
)
left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
)
# Concatenate next_state fields
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
dim=0,
)
for key in left_batch_transitions["next_state"]
}
# Concatenate done and truncated fields
left_batch_transitions["done"] = torch.cat(
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
)
left_batch_transitions["truncated"] = torch.cat(
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
dim=0,
)
# Handle complementary_info
left_info = left_batch_transitions.get("complementary_info")
right_info = right_batch_transition.get("complementary_info")
# Only process if right_info exists
if right_info is not None:
# Initialize left complementary_info if needed
if left_info is None:
left_batch_transitions["complementary_info"] = right_info
else:
# Concatenate each field
for key in right_info:
if key in left_info:
left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
else:
left_info[key] = right_info[key]
return left_batch_transitions

View File

@@ -1,303 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
from copy import deepcopy
from pathlib import Path
from typing import Dict, Tuple
import cv2
# import torch.nn.functional as F # noqa: N812
import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def select_rect_roi(img):
"""
Allows the user to draw a rectangular ROI on the image.
The user must click and drag to draw the rectangle.
- While dragging, the rectangle is dynamically drawn.
- On mouse button release, the rectangle is fixed.
- Press 'c' to confirm the selection.
- Press 'r' to reset the selection.
- Press ESC to cancel.
Returns:
A tuple (top, left, height, width) representing the rectangular ROI,
or None if no valid ROI is selected.
"""
# Create a working copy of the image
clone = img.copy()
working_img = clone.copy()
roi = None # Will store the final ROI as (top, left, height, width)
drawing = False
index_x, index_y = -1, -1 # Initial click coordinates
def mouse_callback(event, x, y, flags, param):
nonlocal index_x, index_y, drawing, roi, working_img
if event == cv2.EVENT_LBUTTONDOWN:
# Start drawing: record starting coordinates
drawing = True
index_x, index_y = x, y
elif event == cv2.EVENT_MOUSEMOVE:
if drawing:
# Compute the top-left and bottom-right corners regardless of drag direction
top = min(index_y, y)
left = min(index_x, x)
bottom = max(index_y, y)
right = max(index_x, x)
# Show a temporary image with the current rectangle drawn
temp = working_img.copy()
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", temp)
elif event == cv2.EVENT_LBUTTONUP:
# Finish drawing
drawing = False
top = min(index_y, y)
left = min(index_x, x)
bottom = max(index_y, y)
right = max(index_x, x)
height = bottom - top
width = right - left
roi = (top, left, height, width) # (top, left, height, width)
# Draw the final rectangle on the working image and display it
working_img = clone.copy()
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", working_img)
# Create the window and set the callback
cv2.namedWindow("Select ROI")
cv2.setMouseCallback("Select ROI", mouse_callback)
cv2.imshow("Select ROI", working_img)
print("Instructions for ROI selection:")
print(" - Click and drag to draw a rectangular ROI.")
print(" - Press 'c' to confirm the selection.")
print(" - Press 'r' to reset and draw again.")
print(" - Press ESC to cancel the selection.")
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
while True:
key = cv2.waitKey(1) & 0xFF
# Confirm ROI if one has been drawn
if key == ord("c") and roi is not None:
break
# Reset: clear the ROI and restore the original image
elif key == ord("r"):
working_img = clone.copy()
roi = None
cv2.imshow("Select ROI", working_img)
# Cancel selection for this image
elif key == 27: # ESC key
roi = None
break
cv2.destroyWindow("Select ROI")
return roi
def select_square_roi_for_images(images: dict) -> dict:
"""
For each image in the provided dictionary, open a window to allow the user
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
(top, left, height, width) representing the ROI.
Parameters:
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
Returns:
dict: Mapping of image keys to the selected rectangular ROI.
"""
selected_rois = {}
for key, img in images.items():
if img is None:
print(f"Image for key '{key}' is None, skipping.")
continue
print(f"\nSelect rectangular ROI for image with key: '{key}'")
roi = select_rect_roi(img)
if roi is None:
print(f"No valid ROI selected for '{key}'.")
else:
selected_rois[key] = roi
print(f"ROI for '{key}': {roi}")
return selected_rois
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
"""
Find the first row in the dataset and extract the image in order to be used for the crop.
"""
row = dataset[0]
image_dict = {}
for k in row:
if "image" in k:
image_dict[k] = deepcopy(row[k])
return image_dict
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset: LeRobotDataset,
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
new_repo_id: str,
new_dataset_root: str,
resize_size: Tuple[int, int] = (128, 128),
push_to_hub: bool = False,
) -> LeRobotDataset:
"""
Converts an existing LeRobotDataset by iterating over its episodes and frames,
applying cropping and resizing to image observations, and saving a new dataset
with the transformed data.
Args:
original_dataset (LeRobotDataset): The source dataset.
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
A dictionary mapping observation keys to crop parameters (top, left, height, width).
new_repo_id (str): Repository id for the new dataset.
new_dataset_root (str): The root directory where the new dataset will be written.
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
Defaults to (128, 128).
Returns:
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
and resized.
"""
# 1. Create a new (empty) LeRobotDataset for writing.
new_dataset = LeRobotDataset.create(
repo_id=new_repo_id,
fps=original_dataset.fps,
root=new_dataset_root,
robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"],
use_videos=len(original_dataset.meta.video_keys) > 0,
)
# Update the metadata for every image key that will be cropped:
# (Here we simply set the shape to be the final resize_size.)
for key in crop_params_dict:
if key in new_dataset.meta.info["features"]:
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
prev_episode_index = 0
for frame_idx in tqdm(range(len(original_dataset))):
frame = original_dataset[frame_idx]
# Create a copy of the frame to add to the new dataset
new_frame = {}
for key, value in frame.items():
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index"):
continue
if key in ("next.done", "next.reward"):
# if not isinstance(value, str) and len(value.shape) == 0:
value = value.unsqueeze(0)
if key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
# Apply crop then resize.
cropped = F.crop(value, top, left, height, width)
value = F.resize(cropped, resize_size)
value = value.clamp(0, 1)
new_frame[key] = value
new_dataset.add_frame(new_frame)
if frame["episode_index"].item() != prev_episode_index:
# Save the episode
new_dataset.save_episode()
prev_episode_index = frame["episode_index"].item()
if push_to_hub:
new_dataset.push_to_hub()
return new_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser.add_argument(
"--repo-id",
type=str,
default="lerobot",
help="The repository id of the LeRobot dataset to process.",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="The root directory of the LeRobot dataset.",
)
parser.add_argument(
"--crop-params-path",
type=str,
default=None,
help="The path to the JSON file containing the ROIs.",
)
parser.add_argument(
"--push-to-hub",
type=bool,
default=False,
help="Whether to push the new dataset to the hub.",
)
args = parser.parse_args()
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path) as f:
rois = json.load(f)
# Print the selected rectangular ROIs
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
for key, roi in rois.items():
print(f"{key}: {roi}")
new_repo_id = args.repo_id + "_cropped_resized"
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=new_repo_id,
new_dataset_root=new_dataset_root,
resize_size=(128, 128),
push_to_hub=args.push_to_hub,
)
meta_dir = new_dataset_root / "meta"
meta_dir.mkdir(exist_ok=True)
with open(meta_dir / "crop_params.json", "w") as f:
json.dump(rois, f, indent=4)

View File

@@ -1,802 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import sys
import time
import numpy as np
import torch
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.server.kinematics import RobotKinematics
class InputController:
"""Base class for input controllers that generate motion deltas."""
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
"""
Initialize the controller.
Args:
x_step_size: Base movement step size in meters
y_step_size: Base movement step size in meters
z_step_size: Base movement step size in meters
"""
self.x_step_size = x_step_size
self.y_step_size = y_step_size
self.z_step_size = z_step_size
self.running = True
self.episode_end_status = None # None, "success", or "failure"
self.intervention_flag = False
self.open_gripper_command = False
self.close_gripper_command = False
def start(self):
"""Start the controller and initialize resources."""
pass
def stop(self):
"""Stop the controller and release resources."""
pass
def get_deltas(self):
"""Get the current movement deltas (dx, dy, dz) in meters."""
return 0.0, 0.0, 0.0
def should_quit(self):
"""Return True if the user has requested to quit."""
return not self.running
def update(self):
"""Update controller state - call this once per frame."""
pass
def __enter__(self):
"""Support for use in 'with' statements."""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Ensure resources are released when exiting 'with' block."""
self.stop()
def get_episode_end_status(self):
"""
Get the current episode end status.
Returns:
None if episode should continue, "success" or "failure" otherwise
"""
status = self.episode_end_status
self.episode_end_status = None # Reset after reading
return status
def should_intervene(self):
"""Return True if intervention flag was set."""
return self.intervention_flag
def gripper_command(self):
"""Return the current gripper command."""
if self.open_gripper_command == self.close_gripper_command:
return "no-op"
elif self.open_gripper_command:
return "open"
elif self.close_gripper_command:
return "close"
class KeyboardController(InputController):
"""Generate motion deltas from keyboard input."""
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
super().__init__(x_step_size, y_step_size, z_step_size)
self.key_states = {
"forward_x": False,
"backward_x": False,
"forward_y": False,
"backward_y": False,
"forward_z": False,
"backward_z": False,
"quit": False,
"success": False,
"failure": False,
}
self.listener = None
def start(self):
"""Start the keyboard listener."""
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.up:
self.key_states["forward_x"] = True
elif key == keyboard.Key.down:
self.key_states["backward_x"] = True
elif key == keyboard.Key.left:
self.key_states["forward_y"] = True
elif key == keyboard.Key.right:
self.key_states["backward_y"] = True
elif key == keyboard.Key.shift:
self.key_states["backward_z"] = True
elif key == keyboard.Key.shift_r:
self.key_states["forward_z"] = True
elif key == keyboard.Key.esc:
self.key_states["quit"] = True
self.running = False
return False
elif key == keyboard.Key.enter:
self.key_states["success"] = True
self.episode_end_status = "success"
elif key == keyboard.Key.backspace:
self.key_states["failure"] = True
self.episode_end_status = "failure"
except AttributeError:
pass
def on_release(key):
try:
if key == keyboard.Key.up:
self.key_states["forward_x"] = False
elif key == keyboard.Key.down:
self.key_states["backward_x"] = False
elif key == keyboard.Key.left:
self.key_states["forward_y"] = False
elif key == keyboard.Key.right:
self.key_states["backward_y"] = False
elif key == keyboard.Key.shift:
self.key_states["backward_z"] = False
elif key == keyboard.Key.shift_r:
self.key_states["forward_z"] = False
elif key == keyboard.Key.enter:
self.key_states["success"] = False
elif key == keyboard.Key.backspace:
self.key_states["failure"] = False
except AttributeError:
pass
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
self.listener.start()
print("Keyboard controls:")
print(" Arrow keys: Move in X-Y plane")
print(" Shift and Shift_R: Move in Z axis")
print(" Enter: End episode with SUCCESS")
print(" Backspace: End episode with FAILURE")
print(" ESC: Exit")
def stop(self):
"""Stop the keyboard listener."""
if self.listener and self.listener.is_alive():
self.listener.stop()
def get_deltas(self):
"""Get the current movement deltas from keyboard state."""
delta_x = delta_y = delta_z = 0.0
if self.key_states["forward_x"]:
delta_x += self.x_step_size
if self.key_states["backward_x"]:
delta_x -= self.x_step_size
if self.key_states["forward_y"]:
delta_y += self.y_step_size
if self.key_states["backward_y"]:
delta_y -= self.y_step_size
if self.key_states["forward_z"]:
delta_z += self.z_step_size
if self.key_states["backward_z"]:
delta_z -= self.z_step_size
return delta_x, delta_y, delta_z
def should_quit(self):
"""Return True if ESC was pressed."""
return self.key_states["quit"]
def should_save(self):
"""Return True if Enter was pressed (save episode)."""
return self.key_states["success"] or self.key_states["failure"]
class GamepadController(InputController):
"""Generate motion deltas from gamepad input."""
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1):
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.joystick = None
self.intervention_flag = False
def start(self):
"""Initialize pygame and the gamepad."""
import pygame
pygame.init()
pygame.joystick.init()
if pygame.joystick.get_count() == 0:
logging.error("No gamepad detected. Please connect a gamepad and try again.")
self.running = False
return
self.joystick = pygame.joystick.Joystick(0)
self.joystick.init()
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
print("Gamepad controls:")
print(" Left analog stick: Move in X-Y plane")
print(" Right analog stick (vertical): Move in Z axis")
print(" B/Circle button: Exit")
print(" Y/Triangle button: End episode with SUCCESS")
print(" A/Cross button: End episode with FAILURE")
print(" X/Square button: Rerecord episode")
def stop(self):
"""Clean up pygame resources."""
import pygame
if pygame.joystick.get_init():
if self.joystick:
self.joystick.quit()
pygame.joystick.quit()
pygame.quit()
def update(self):
"""Process pygame events to get fresh gamepad readings."""
import pygame
for event in pygame.event.get():
if event.type == pygame.JOYBUTTONDOWN:
if event.button == 3:
self.episode_end_status = "success"
# A button (1) for failure
elif event.button == 1:
self.episode_end_status = "failure"
# X button (0) for rerecord
elif event.button == 0:
self.episode_end_status = "rerecord_episode"
# RB button (6) for closing gripper
elif event.button == 6:
self.close_gripper_command = True
# LT button (7) for opening gripper
elif event.button == 7:
self.open_gripper_command = True
# Reset episode status on button release
elif event.type == pygame.JOYBUTTONUP:
if event.button in [0, 2, 3]:
self.episode_end_status = None
elif event.button == 6:
self.close_gripper_command = False
elif event.button == 7:
self.open_gripper_command = False
# Check for RB button (typically button 5) for intervention flag
if self.joystick.get_button(5):
self.intervention_flag = True
else:
self.intervention_flag = False
def get_deltas(self):
"""Get the current movement deltas from gamepad state."""
import pygame
try:
# Read joystick axes
# Left stick X and Y (typically axes 0 and 1)
x_input = self.joystick.get_axis(0) # Left/Right
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
# Right stick Y (typically axis 3 or 4)
z_input = self.joystick.get_axis(3) # Up/Down for Z
# Apply deadzone to avoid drift
x_input = 0 if abs(x_input) < self.deadzone else x_input
y_input = 0 if abs(y_input) < self.deadzone else y_input
z_input = 0 if abs(z_input) < self.deadzone else z_input
# Calculate deltas (note: may need to invert axes depending on controller)
delta_x = -y_input * self.y_step_size # Forward/backward
delta_y = -x_input * self.x_step_size # Left/right
delta_z = -z_input * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
except pygame.error:
logging.error("Error reading gamepad. Is it still connected?")
return 0.0, 0.0, 0.0
class GamepadControllerHID(InputController):
"""Generate motion deltas from gamepad input using HIDAPI."""
def __init__(
self,
x_step_size=0.01,
y_step_size=0.01,
z_step_size=0.01,
deadzone=0.1,
vendor_id=0x046D,
product_id=0xC219,
):
"""
Initialize the HID gamepad controller.
Args:
step_size: Base movement step size in meters
z_scale: Scaling factor for Z-axis movement
deadzone: Joystick deadzone to prevent drift
vendor_id: USB vendor ID of the gamepad (default: Logitech)
product_id: USB product ID of the gamepad (default: RumblePad 2)
"""
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.vendor_id = vendor_id
self.product_id = product_id
self.device = None
self.device_info = None
# Movement values (normalized from -1.0 to 1.0)
self.left_x = 0.0
self.left_y = 0.0
self.right_x = 0.0
self.right_y = 0.0
# Button states
self.buttons = {}
self.quit_requested = False
self.save_requested = False
def find_device(self):
"""Look for the gamepad device by vendor and product ID."""
import hid
devices = hid.enumerate()
for device in devices:
if device["vendor_id"] == self.vendor_id and device["product_id"] == self.product_id:
logging.info(f"Found gamepad: {device.get('product_string', 'Unknown')}")
return device
logging.error(
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and product ID 0x{self.product_id:04X} found"
)
return None
def start(self):
"""Connect to the gamepad using HIDAPI."""
import hid
self.device_info = self.find_device()
if not self.device_info:
self.running = False
return
try:
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
self.device = hid.device()
self.device.open_path(self.device_info["path"])
self.device.set_nonblocking(1)
manufacturer = self.device.get_manufacturer_string()
product = self.device.get_product_string()
logging.info(f"Connected to {manufacturer} {product}")
logging.info("Gamepad controls (HID mode):")
logging.info(" Left analog stick: Move in X-Y plane")
logging.info(" Right analog stick: Move in Z axis (vertical)")
logging.info(" Button 1/B/Circle: Exit")
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
logging.info(" Button 3/X/Square: End episode with FAILURE")
except OSError as e:
logging.error(f"Error opening gamepad: {e}")
logging.error("You might need to run this with sudo/admin privileges on some systems")
self.running = False
def stop(self):
"""Close the HID device connection."""
if self.device:
self.device.close()
self.device = None
def update(self):
"""
Read and process the latest gamepad data.
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
"""
for _ in range(10):
self._update()
def _update(self):
"""Read and process the latest gamepad data."""
if not self.device or not self.running:
return
try:
# Read data from the gamepad
data = self.device.read(64)
# Interpret gamepad data - this will vary by controller model
# These offsets are for the Logitech RumblePad 2
if data and len(data) >= 8:
# Normalize joystick values from 0-255 to -1.0-1.0
self.left_x = (data[1] - 128) / 128.0
self.left_y = (data[2] - 128) / 128.0
self.right_x = (data[3] - 128) / 128.0
self.right_y = (data[4] - 128) / 128.0
# Apply deadzone
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
# Parse button states (byte 5 in the Logitech RumblePad 2)
buttons = data[5]
# Check if RB is pressed then the intervention flag should be set
self.intervention_flag = data[6] in [2, 6, 10, 14]
# Check if RT is pressed
self.open_gripper_command = data[6] in [8, 10, 12]
# Check if LT is pressed
self.close_gripper_command = data[6] in [4, 6, 12]
# Check if Y/Triangle button (bit 7) is pressed for saving
# Check if X/Square button (bit 5) is pressed for failure
# Check if A/Cross button (bit 4) is pressed for rerecording
if buttons & 1 << 7:
self.episode_end_status = "success"
elif buttons & 1 << 5:
self.episode_end_status = "failure"
elif buttons & 1 << 4:
self.episode_end_status = "rerecord_episode"
else:
self.episode_end_status = None
except OSError as e:
logging.error(f"Error reading from gamepad: {e}")
def get_deltas(self):
"""Get the current movement deltas from gamepad state."""
# Calculate deltas - invert as needed based on controller orientation
delta_x = -self.left_y * self.x_step_size # Forward/backward
delta_y = -self.left_x * self.y_step_size # Left/right
delta_z = -self.right_y * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
def should_quit(self):
"""Return True if quit button was pressed."""
return self.quit_requested
def should_save(self):
"""Return True if save button was pressed."""
return self.save_requested
def test_forward_kinematics(robot, fps=10):
logging.info("Testing Forward Kinematics")
timestep = time.perf_counter()
kinematics = RobotKinematics(robot.robot_type)
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
robot.teleop_step()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = kinematics.fk_gripper_tip(joint_positions)
logging.info(f"EE Position: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def test_inverse_kinematics(robot, fps=10):
logging.info("Testing Inverse Kinematics")
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
desired_ee_pos = ee_pos
target_joint_state = RobotKinematics.ik(joint_positions, desired_ee_pos, position_only=True)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Target Joint State: {target_joint_state}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
logging.info("Testing Inverse Kinematics")
kinematics = RobotKinematics(robot.robot_type)
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = kinematics.fk_gripper_tip(joint_positions)
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
desired_ee_pos = leader_ee
target_joint_state = kinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
logging.info("Testing Delta End-Effector Control")
timestep = time.perf_counter()
# Initial position capture
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
kinematics = RobotKinematics(robot.robot_type)
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
desired_ee_pos = np.diag(np.ones(4))
joint_positions = robot.follower_arms["main"].read("Present_Position")
fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions)
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
# Get leader state for teleoperation
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
# Get current state
# obs = robot.capture_observation()
# joint_positions = obs["observation.state"].cpu().numpy()
joint_positions = robot.follower_arms["main"].read("Present_Position")
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
# Calculate delta between leader and follower end-effectors
# Scaling factor can be adjusted for sensitivity
scaling_factor = 1.0
ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05)
# Apply delta to current position
desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0
desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0
desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3]
# Compute joint targets via inverse kinematics
target_joint_state = kinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
)
initial_leader_ee = leader_ee.copy()
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
# Logging
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, fk_func=None):
"""
Control a robot using delta end-effector movements from any input controller.
Args:
robot: Robot instance to control
controller: InputController instance (keyboard, gamepad, etc.)
fps: Control frequency in Hz
bounds: Optional position limits
fk_func: Forward kinematics function to use
"""
if fk_func is None:
fk_func = RobotKinematics.fk_gripper_tip
logging.info(f"Testing Delta End-Effector Control with {controller.__class__.__name__}")
# Initial position capture
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
kinematics = RobotKinematics(robot.robot_type)
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
# Initialize desired position with current position
desired_ee_pos = np.eye(4) # Identity matrix
timestep = time.perf_counter()
with controller:
while not controller.should_quit() and time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
# Process input events
controller.update()
# Get current robot state
joint_positions = robot.follower_arms["main"].read("Present_Position")
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
# Get movement deltas from the controller
delta_x, delta_y, delta_z = controller.get_deltas()
# Update desired position
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z
# Apply bounds if provided
if bounds is not None:
desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])
# Only send commands if there's actual movement
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
# Compute joint targets via inverse kinematics
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_gym_env(env, controller, fps: int = 30):
"""
Control a robot through a gym environment using keyboard inputs.
Args:
env: A gym environment created with make_robot_env
fps: Target control frequency
"""
logging.info("Testing Keyboard Control of Gym Environment")
print("Keyboard controls:")
print(" Arrow keys: Move in X-Y plane")
print(" Shift and Shift_R: Move in Z axis")
print(" ESC: Exit")
# Reset the environment to get initial observation
obs, info = env.reset()
try:
with controller:
while not controller.should_quit():
loop_start_time = time.perf_counter()
# Process input events
controller.update()
# Get movement deltas from the controller
delta_x, delta_y, delta_z = controller.get_deltas()
# Create the action vector
action = np.array([delta_x, delta_y, delta_z])
# Skip if no movement
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
# Step the environment - pass action as a tensor with intervention flag
action_tensor = torch.from_numpy(action.astype(np.float32))
obs, reward, terminated, truncated, info = env.step((action_tensor, False))
# Log information
logging.info(f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]")
logging.info(f"Reward: {reward}")
# Reset if episode ended
if terminated or truncated:
logging.info("Episode ended, resetting environment")
obs, info = env.reset()
# Maintain target frame rate
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
finally:
# Close the environment
env.close()
if __name__ == "__main__":
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvTransformConfig, HILSerlRobotEnvConfig
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import make_robot_env
init_logging()
parser = argparse.ArgumentParser(description="Test end-effector control")
parser.add_argument(
"--mode",
type=str,
default="keyboard",
choices=[
"keyboard",
"gamepad",
"keyboard_gym",
"gamepad_gym",
"leader_delta",
"leader",
],
help="Control mode to use",
)
parser.add_argument(
"--robot-type",
type=str,
default="so100",
help="Robot type (so100, koch, aloha, etc.)",
)
args = parser.parse_args()
robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
robot = make_robot_from_config(robot_config)
if not robot.is_connected:
robot.connect()
# Example bounds
bounds = {
"max": np.array([0.32170487, 0.201285, 0.10273342]),
"min": np.array([0.16631757, -0.08237468, 0.03364977]),
}
try:
# Determine controller type based on mode prefix
controller = None
if args.mode.startswith("keyboard"):
controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
elif args.mode.startswith("gamepad"):
if sys.platform == "darwin":
controller = GamepadControllerHID(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
else:
controller = GamepadController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
# Handle mode categories
if args.mode in ["keyboard", "gamepad"]:
# Direct robot control modes
teleoperate_delta_inverse_kinematics(robot, controller, bounds=bounds, fps=10)
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
# Gym environment control modes
cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvTransformConfig())
cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
)
cfg.wrapper.ee_action_space_params.use_gamepad = False
cfg.device = "cpu"
env = make_robot_env(cfg, robot)
teleoperate_gym_env(env, controller, fps=cfg.fps)
elif args.mode == "leader_delta":
# Leader-follower modes don't use controllers
teleoperate_delta_inverse_kinematics_with_leader(robot)
elif args.mode == "leader":
teleoperate_inverse_kinematics_with_leader(robot)
finally:
if robot.is_connected:
robot.disconnect()

View File

@@ -1,135 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import time
import cv2
import numpy as np
from lerobot.common.robot_devices.control_utils import is_headless
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics
def find_joint_bounds(
robot,
control_time_s=30,
display_cameras=False,
):
if not robot.is_connected:
robot.connect()
start_episode_t = time.perf_counter()
pos_list = []
while True:
observation, action = robot.teleop_step(record_data=True)
# Wait for 5 seconds to stabilize the robot initial position
if time.perf_counter() - start_episode_t < 5:
continue
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s:
max = np.max(np.stack(pos_list), 0)
min = np.min(np.stack(pos_list), 0)
print(f"Max angle position per joint {max}")
print(f"Min angle position per joint {min}")
break
def find_ee_bounds(
robot,
control_time_s=30,
display_cameras=False,
):
if not robot.is_connected:
robot.connect()
start_episode_t = time.perf_counter()
ee_list = []
while True:
observation, action = robot.teleop_step(record_data=True)
# Wait for 5 seconds to stabilize the robot initial position
if time.perf_counter() - start_episode_t < 5:
continue
kinematics = RobotKinematics(robot.robot_type)
joint_positions = robot.follower_arms["main"].read("Present_Position")
print(f"Joint positions: {joint_positions}")
ee_list.append(kinematics.fk_gripper_tip(joint_positions)[:3, 3])
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s:
max = np.max(np.stack(ee_list), 0)
min = np.min(np.stack(ee_list), 0)
print(f"Max ee position {max}")
print(f"Min ee position {min}")
break
if __name__ == "__main__":
# Create argparse for script-specific arguments
parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict
parser.add_argument(
"--mode",
type=str,
default="joint",
choices=["joint", "ee"],
help="Mode to run the script in. Can be 'joint' or 'ee'.",
)
parser.add_argument(
"--control-time-s",
type=int,
default=30,
help="Time step to use for control.",
)
parser.add_argument(
"--robot-type",
type=str,
default="so100",
help="Robot type (so100, koch, aloha, etc.)",
)
# Only parse known args, leaving robot config args for Hydra if used
args = parser.parse_args()
# Create robot with the appropriate config
robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
robot = make_robot_from_config(robot_config)
if args.mode == "joint":
find_joint_bounds(robot, args.control_time_s)
elif args.mode == "ee":
find_ee_bounds(robot, args.control_time_s)
if robot.is_connected:
robot.disconnect()

File diff suppressed because it is too large Load Diff

View File

@@ -1,54 +0,0 @@
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package hil_serl;
// LearnerService: the Actor calls this to push transitions.
// The Learner implements this service.
service LearnerService {
// Actor -> Learner to store transitions
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
rpc StreamParameters(Empty) returns (stream Parameters);
rpc SendTransitions(stream Transition) returns (Empty);
rpc SendInteractions(stream InteractionMessage) returns (Empty);
rpc Ready(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Transition {
TransferState transfer_state = 1;
bytes data = 2;
}
message Parameters {
TransferState transfer_state = 1;
bytes data = 2;
}
message InteractionMessage {
TransferState transfer_state = 1;
bytes data = 2;
}
message Empty {}

View File

@@ -1,46 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: hilserl.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'hilserl.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=275
_globals['_TRANSFERSTATE']._serialized_end=371
_globals['_TRANSITION']._serialized_start=27
_globals['_TRANSITION']._serialized_end=102
_globals['_PARAMETERS']._serialized_start=104
_globals['_PARAMETERS']._serialized_end=179
_globals['_INTERACTIONMESSAGE']._serialized_start=181
_globals['_INTERACTIONMESSAGE']._serialized_end=264
_globals['_EMPTY']._serialized_start=266
_globals['_EMPTY']._serialized_end=273
_globals['_LEARNERSERVICE']._serialized_start=374
_globals['_LEARNERSERVICE']._serialized_end=696
# @@protoc_insertion_point(module_scope)

View File

@@ -1,276 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import hilserl_pb2 as hilserl__pb2
GRPC_GENERATED_VERSION = '1.70.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in hilserl_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class LearnerServiceStub(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendInteractionMessage = channel.unary_unary(
'/hil_serl.LearnerService/SendInteractionMessage',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.StreamParameters = channel.unary_stream(
'/hil_serl.LearnerService/StreamParameters',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Parameters.FromString,
_registered_method=True)
self.SendTransitions = channel.stream_unary(
'/hil_serl.LearnerService/SendTransitions',
request_serializer=hilserl__pb2.Transition.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.SendInteractions = channel.stream_unary(
'/hil_serl.LearnerService/SendInteractions',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/hil_serl.LearnerService/Ready',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
class LearnerServiceServicer(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
def SendInteractionMessage(self, request, context):
"""Actor -> Learner to store transitions
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StreamParameters(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendTransitions(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendInteractions(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_LearnerServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
servicer.SendInteractionMessage,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'StreamParameters': grpc.unary_stream_rpc_method_handler(
servicer.StreamParameters,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Parameters.SerializeToString,
),
'SendTransitions': grpc.stream_unary_rpc_method_handler(
servicer.SendTransitions,
request_deserializer=hilserl__pb2.Transition.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'SendInteractions': grpc.stream_unary_rpc_method_handler(
servicer.SendInteractions,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'hil_serl.LearnerService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class LearnerService(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
@staticmethod
def SendInteractionMessage(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/SendInteractionMessage',
hilserl__pb2.InteractionMessage.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def StreamParameters(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/hil_serl.LearnerService/StreamParameters',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.Parameters.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendTransitions(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.LearnerService/SendTransitions',
hilserl__pb2.Transition.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendInteractions(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.LearnerService/SendInteractions',
hilserl__pb2.InteractionMessage.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/Ready',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -1,546 +0,0 @@
# ruff: noqa: N806, N815, N803
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy.spatial.transform import Rotation
def skew_symmetric(w):
"""Creates the skew-symmetric matrix from a 3D vector."""
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
def rodrigues_rotation(w, theta):
"""Computes the rotation matrix using Rodrigues' formula."""
w_hat = skew_symmetric(w)
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
def screw_axis_to_transform(S, theta):
"""Converts a screw axis to a 4x4 transformation matrix."""
S_w = S[:3]
S_v = S[3:]
if np.allclose(S_w, 0) and np.linalg.norm(S_v) == 1: # Pure translation
T = np.eye(4)
T[:3, 3] = S_v * theta
elif np.linalg.norm(S_w) == 1: # Rotation and translation
w_hat = skew_symmetric(S_w)
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v
T = np.eye(4)
T[:3, :3] = R
T[:3, 3] = t
else:
raise ValueError("Invalid screw axis parameters")
return T
def pose_difference_se3(pose1, pose2):
"""
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space, combining rotation (SO(3)) and translation.
Each 4x4 matrix has the following structure, a 3x3 rotation matrix in the top-left and a 3x1 translation vector in the top-right:
[R11 R12 R13 tx]
[R21 R22 R23 ty]
[R31 R32 R33 tz]
[ 0 0 0 1]
where Rij is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector.
pose1 - pose2
Args:
pose1: A 4x4 numpy array representing the first pose.
pose2: A 4x4 numpy array representing the second pose.
Returns:
A tuple (translation_diff, rotation_diff) where:
- translation_diff is a 3x1 numpy array representing the translational difference.
- rotation_diff is a 3x1 numpy array representing the rotational difference in axis-angle representation.
"""
# Extract rotation matrices from poses
R1 = pose1[:3, :3]
R2 = pose2[:3, :3]
# Calculate translational difference
translation_diff = pose1[:3, 3] - pose2[:3, 3]
# Calculate rotational difference using scipy's Rotation library
R_diff = Rotation.from_matrix(R1 @ R2.T)
rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation
return np.concatenate([translation_diff, rotation_diff])
def se3_error(target_pose, current_pose):
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
R_target = target_pose[:3, :3]
R_current = current_pose[:3, :3]
R_error = R_target @ R_current.T
rot_error = Rotation.from_matrix(R_error).as_rotvec()
return np.concatenate([pos_error, rot_error])
class RobotKinematics:
"""Robot kinematics class supporting multiple robot models."""
# Robot measurements dictionary
ROBOT_MEASUREMENTS = {
"koch": {
"gripper": [0.239, -0.001, 0.024],
"wrist": [0.209, 0, 0.024],
"forearm": [0.108, 0, 0.02],
"humerus": [0, 0, 0.036],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"so100": {
"gripper": [0.320, 0, 0.050],
"wrist": [0.278, 0, 0.050],
"forearm": [0.143, 0, 0.044],
"humerus": [0.031, 0, 0.072],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"moss": {
"gripper": [0.246, 0.013, 0.111],
"wrist": [0.245, 0.002, 0.064],
"forearm": [0.122, 0, 0.064],
"humerus": [0.001, 0.001, 0.063],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
}
def __init__(self, robot_type="so100"):
"""Initialize kinematics for the specified robot type.
Args:
robot_type: String specifying the robot model ("koch", "so100", or "moss")
"""
if robot_type not in self.ROBOT_MEASUREMENTS:
raise ValueError(
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
)
self.robot_type = robot_type
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
# Initialize all transformation matrices and screw axes
self._setup_transforms()
def _create_translation_matrix(self, x=0, y=0, z=0):
"""Create a 4x4 translation matrix."""
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
def _setup_transforms(self):
"""Setup all transformation matrices and screw axes for the robot."""
# Set up rotation matrices (constant across robot types)
# Gripper orientation
self.gripper_X0 = np.array(
[
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, -1, 0, 0],
[0, 0, 0, 1],
]
)
# Wrist orientation
self.wrist_X0 = np.array(
[
[0, -1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
)
# Base orientation
self.base_X0 = np.array(
[
[0, 0, 1, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
]
)
# Gripper
# Screw axis of gripper frame wrt base frame
self.S_BG = np.array(
[
1,
0,
0,
0,
self.measurements["gripper"][2],
-self.measurements["gripper"][1],
]
)
# Gripper origin to centroid transform
self.X_GoGc = self._create_translation_matrix(x=0.07)
# Gripper origin to tip transform
self.X_GoGt = self._create_translation_matrix(x=0.12)
# 0-position gripper frame pose wrt base
self.X_BoGo = self._create_translation_matrix(
x=self.measurements["gripper"][0],
y=self.measurements["gripper"][1],
z=self.measurements["gripper"][2],
)
# Wrist
# Screw axis of wrist frame wrt base frame
self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]])
# 0-position origin to centroid transform
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
# 0-position wrist frame pose wrt base
self.X_BR = self._create_translation_matrix(
x=self.measurements["wrist"][0],
y=self.measurements["wrist"][1],
z=self.measurements["wrist"][2],
)
# Forearm
# Screw axis of forearm frame wrt base frame
self.S_BF = np.array(
[
0,
1,
0,
-self.measurements["forearm"][2],
0,
self.measurements["forearm"][0],
]
)
# Forearm origin + centroid transform
self.X_FoFc = self._create_translation_matrix(x=0.036) # spellchecker:disable-line
# 0-position forearm frame pose wrt base
self.X_BF = self._create_translation_matrix(
x=self.measurements["forearm"][0],
y=self.measurements["forearm"][1],
z=self.measurements["forearm"][2],
)
# Humerus
# Screw axis of humerus frame wrt base frame
self.S_BH = np.array(
[
0,
-1,
0,
self.measurements["humerus"][2],
0,
-self.measurements["humerus"][0],
]
)
# Humerus origin to centroid transform
self.X_HoHc = self._create_translation_matrix(x=0.0475)
# 0-position humerus frame pose wrt base
self.X_BH = self._create_translation_matrix(
x=self.measurements["humerus"][0],
y=self.measurements["humerus"][1],
z=self.measurements["humerus"][2],
)
# Shoulder
# Screw axis of shoulder frame wrt Base frame
self.S_BS = np.array([0, 0, -1, 0, 0, 0])
# Shoulder origin to centroid transform
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
# 0-position shoulder frame pose wrt base
self.X_BS = self._create_translation_matrix(
x=self.measurements["shoulder"][0],
y=self.measurements["shoulder"][1],
z=self.measurements["shoulder"][2],
)
# Base
# Base origin to centroid transform
self.X_BoBc = self._create_translation_matrix(y=0.015)
# World to base transform
self.X_WoBo = self._create_translation_matrix(
x=self.measurements["base"][0],
y=self.measurements["base"][1],
z=self.measurements["base"][2],
)
# Pre-compute gripper post-multiplication matrix
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
def fk_base(self):
"""Forward kinematics for the base frame."""
return self.X_WoBo @ self.X_BoBc @ self.base_X0
def fk_shoulder(self, robot_pos_deg):
"""Forward kinematics for the shoulder frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return self.X_WoBo @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) @ self.X_SoSc @ self.X_BS
def fk_humerus(self, robot_pos_deg):
"""Forward kinematics for the humerus frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ self.X_HoHc
@ self.X_BH
)
def fk_forearm(self, robot_pos_deg):
"""Forward kinematics for the forearm frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ self.X_FoFc # spellchecker:disable-line
@ self.X_BF
)
def fk_wrist(self, robot_pos_deg):
"""Forward kinematics for the wrist frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ self.X_RoRc
@ self.X_BR
@ self.wrist_X0
)
def fk_gripper(self, robot_pos_deg):
"""Forward kinematics for the gripper frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
@ self._fk_gripper_post
)
def fk_gripper_tip(self, robot_pos_deg):
"""Forward kinematics for the gripper tip frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
@ self.X_GoGt
@ self.X_BoGo
@ self.gripper_X0
)
def compute_jacobian(self, robot_pos_deg, fk_func=None):
"""Finite differences to compute the Jacobian.
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
in the jth joint's velocity.
Args:
robot_pos_deg: Current joint positions in degrees
fk_func: Forward kinematics function to use (defaults to fk_gripper)
"""
if fk_func is None:
fk_func = self.fk_gripper
eps = 1e-8
jac = np.zeros(shape=(6, 5))
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
for el_ix in range(len(robot_pos_deg[:-1])):
delta *= 0
delta[el_ix] = eps / 2
Sdot = (
pose_difference_se3(
fk_func(robot_pos_deg[:-1] + delta),
fk_func(robot_pos_deg[:-1] - delta),
)
/ eps
)
jac[:, el_ix] = Sdot
return jac
def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
"""Finite differences to compute the positional Jacobian.
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
in the jth joint's velocity.
Args:
robot_pos_deg: Current joint positions in degrees
fk_func: Forward kinematics function to use (defaults to fk_gripper)
"""
if fk_func is None:
fk_func = self.fk_gripper
eps = 1e-8
jac = np.zeros(shape=(3, 5))
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
for el_ix in range(len(robot_pos_deg[:-1])):
delta *= 0
delta[el_ix] = eps / 2
Sdot = (
fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
) / eps
jac[:, el_ix] = Sdot
return jac
def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None):
"""Inverse kinematics using gradient descent.
Args:
current_joint_state: Initial joint positions in degrees
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
position_only: If True, only match end-effector position, not orientation
fk_func: Forward kinematics function to use (defaults to fk_gripper)
Returns:
Joint positions in degrees that achieve the desired end-effector pose
"""
if fk_func is None:
fk_func = self.fk_gripper
# Do gradient descent.
max_iterations = 5
learning_rate = 1
for _ in range(max_iterations):
current_ee_pose = fk_func(current_joint_state)
if not position_only:
error = se3_error(desired_ee_pose, current_ee_pose)
jac = self.compute_jacobian(current_joint_state, fk_func)
else:
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
jac = self.compute_positional_jacobian(current_joint_state, fk_func)
delta_angles = np.linalg.pinv(jac) @ error
current_joint_state[:-1] += learning_rate * delta_angles
if np.linalg.norm(error) < 5e-3:
return current_joint_state
return current_joint_state
if __name__ == "__main__":
import time
def run_test(robot_type):
"""Run test suite for a specific robot type."""
print(f"\n--- Testing {robot_type.upper()} Robot ---")
# Initialize kinematics for this robot
robot = RobotKinematics(robot_type)
# Test 1: Forward kinematics consistency
print("Test 1: Forward kinematics consistency")
test_angles = np.array([30, 45, -30, 20, 10, 0]) # Example joint angles in degrees
# Calculate FK for different joints
shoulder_pose = robot.fk_shoulder(test_angles)
humerus_pose = robot.fk_humerus(test_angles)
forearm_pose = robot.fk_forearm(test_angles)
wrist_pose = robot.fk_wrist(test_angles)
gripper_pose = robot.fk_gripper(test_angles)
gripper_tip_pose = robot.fk_gripper_tip(test_angles)
# Check that poses form a consistent kinematic chain (positions should be progressively further from origin)
distances = [
np.linalg.norm(shoulder_pose[:3, 3]),
np.linalg.norm(humerus_pose[:3, 3]),
np.linalg.norm(forearm_pose[:3, 3]),
np.linalg.norm(wrist_pose[:3, 3]),
np.linalg.norm(gripper_pose[:3, 3]),
np.linalg.norm(gripper_tip_pose[:3, 3]),
]
# Check if distances generally increase along the chain
is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1))
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
print(f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}")
# Test 2: Jacobian computation
print("Test 2: Jacobian computation")
jacobian = robot.compute_jacobian(test_angles)
positional_jacobian = robot.compute_positional_jacobian(test_angles)
# Check shapes
jacobian_shape_ok = jacobian.shape == (6, 5)
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
print(f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}")
# Test 3: Inverse kinematics
print("Test 3: Inverse kinematics (position only)")
# Generate target pose from known joint angles
original_angles = np.array([10, 20, 30, -10, 5, 0])
target_pose = robot.fk_gripper(original_angles)
# Start IK from a different position
initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
# Measure IK performance
start_time = time.time()
computed_angles = robot.ik(initial_guess.copy(), target_pose)
ik_time = time.time() - start_time
# Compute resulting pose from IK solution
result_pose = robot.fk_gripper(computed_angles)
# Calculate position error
pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
passed = pos_error < 0.01 # Accept errors less than 1cm
print(f" IK computation time: {ik_time:.4f} seconds")
print(f" Position error: {pos_error:.4f}")
print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}")
return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed
# Run tests for all robot types
results = {}
for robot_type in ["koch", "so100", "moss"]:
results[robot_type] = run_test(robot_type)
# Print overall summary
print("\n=== Test Summary ===")
all_passed = all(results.values())
for robot_type, passed in results.items():
print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")

File diff suppressed because it is too large Load Diff

View File

@@ -1,80 +0,0 @@
import logging
from multiprocessing import Event, Queue
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def __init__(
self,
shutdown_event: Event, # type: ignore
parameters_queue: Queue,
seconds_between_pushes: float,
transition_queue: Queue,
interaction_message_queue: Queue,
):
self.shutdown_event = shutdown_event
self.parameters_queue = parameters_queue
self.seconds_between_pushes = seconds_between_pushes
self.transition_queue = transition_queue
self.interaction_message_queue = interaction_message_queue
def StreamParameters(self, request, context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
while not self.shutdown_event.is_set():
logging.info("[LEARNER] Push parameters to the Actor")
buffer = self.parameters_queue.get()
yield from send_bytes_in_chunks(
buffer,
hilserl_pb2.Parameters,
log_prefix="[LEARNER] Sending parameters",
silent=True,
)
logging.info("[LEARNER] Parameters sent")
self.shutdown_event.wait(self.seconds_between_pushes)
logging.info("[LEARNER] Stream parameters finished")
return hilserl_pb2.Empty()
def SendTransitions(self, request_iterator, _context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
receive_bytes_in_chunks(
request_iterator,
self.transition_queue,
self.shutdown_event,
log_prefix="[LEARNER] transitions",
)
logging.debug("[LEARNER] Finished receiving transitions")
return hilserl_pb2.Empty()
def SendInteractions(self, request_iterator, _context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive interactions from the Actor")
receive_bytes_in_chunks(
request_iterator,
self.interaction_message_queue,
self.shutdown_event,
log_prefix="[LEARNER] interactions",
)
logging.debug("[LEARNER] Finished receiving interactions")
return hilserl_pb2.Empty()
def Ready(self, request, context): # noqa: N802
return hilserl_pb2.Empty()

View File

@@ -1,142 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import pickle # nosec B403: Safe usage for internal serialization only
from multiprocessing import Event, Queue
from typing import Any
import torch
from lerobot.scripts.server import hilserl_pb2
from lerobot.scripts.server.utils import Transition
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
def bytes_buffer_size(buffer: io.BytesIO) -> int:
buffer.seek(0, io.SEEK_END)
result = buffer.tell()
buffer.seek(0)
return result
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging_method = logging.info if not silent else logging.debug
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
bytes_buffer = io.BytesIO()
step = 0
logging.info(f"{log_prefix} Starting receiver")
for item in iterator:
logging.debug(f"{log_prefix} Received item")
if shutdown_event.is_set():
logging.info(f"{log_prefix} Shutting down receiver")
return
if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step 0")
step = 0
continue
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data)
step += 1
logging.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
queue.put(bytes_buffer.getvalue())
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
step = 0
logging.debug(f"{log_prefix} Queue updated")
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer.getvalue()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer, weights_only=False) # nosec B614: Using weights_only=False relies on pickle which has security implications.
# This is currently safe as we only deserialize trusted internal data.
# TODO: Verify if weights_only=True would work for our use case (safer default in torch 2.6+)
def python_object_to_bytes(python_object: Any) -> bytes:
return pickle.dumps(python_object)
def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
# Add validation checks here
return obj
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
transitions = torch.load(buffer, weights_only=False) # nosec B614: Safe usage of torch.load
# Add validation checks here
return transitions
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()

View File

@@ -1,141 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import signal
import sys
from queue import Empty
from typing import TypedDict
import torch
from torch.multiprocessing import Queue
shutdown_event_counter = 0
def setup_process_handlers(use_threads: bool) -> any:
if use_threads:
from threading import Event
else:
from multiprocessing import Event
shutdown_event = Event()
# Define signal handler
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
global shutdown_event_counter
shutdown_event_counter += 1
if shutdown_event_counter > 1:
logging.info("Force shutdown")
sys.exit(1)
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
return shutdown_event
def get_last_item_from_queue(queue: Queue):
item = queue.get()
counter = 1
# Drain queue and keep only the most recent parameters
try:
while True:
item = queue.get_nowait()
counter += 1
except Empty:
pass
logging.debug(f"Drained {counter} items from queue")
return item
class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
device = torch.device(device)
non_blocking = device.type == "cuda"
# Move state tensors to device
transition["state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
}
# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
# Move next_state tensors to device
transition["next_state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
}
# Move complementary_info tensors if present
if transition.get("complementary_info") is not None:
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
return transition
def move_state_dict_to_device(state_dict, device="cpu"):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the CPU.
"""
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
else:
return state_dict

View File

@@ -63,13 +63,13 @@ dependencies = [
"opencv-python-headless>=4.9.0",
"packaging>=24.2",
"av>=14.2.0",
"pymunk>=6.6.0",
"pymunk>=6.6.0,<7.0.0",
"pynput>=1.7.7",
"pyzmq>=26.2.1",
"rerun-sdk>=0.21.0",
"termcolor>=2.4.0",
"torch>=2.2.1,<2.7",
"torchcodec==0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torch>=2.2.1",
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchvision>=0.21.0",
"wandb>=0.16.3",
"zarr>=2.17.0",
@@ -84,9 +84,9 @@ dora = [
]
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
hilserl = ["transformers>=4.48", "gym-hil>=0.1.3", "protobuf>=5.29.3", "grpcio>=1.70.0"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
pi0 = ["transformers>=4.48.0"]
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
stretch = [
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
@@ -99,40 +99,13 @@ umi = ["imagecodecs>=2024.1.1"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
[tool.poetry]
requires-poetry = ">=2.1"
[tool.ruff]
line-length = 110
target-version = "py310"
exclude = [
"tests/data",
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
"*_pb2.py",
"*_pb2_grpc.py",
]
exclude = ["tests/artifacts/**/*.safetensors"]
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]

156
realman.md Normal file
View File

@@ -0,0 +1,156 @@
# Install
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```
Install 🤗 LeRobot:
```bash
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install edge-tts
sudo apt install mpv -y
# pip uninstall numpy
# pip install numpy==1.26.0
# pip install pynput
```
/!\ For Linux only, ffmpeg and opencv requires conda install for now. Run this exact sequence of commands:
```bash
conda install ffmpeg=7.1.1 -c conda-forge
# pip uninstall opencv-python
# conda install "opencv>=4.10.0"
```
Install Realman SDK:
```bash
pip install Robotic_Arm==1.0.4.1
pip install pygame
```
# piper集成lerobot
见lerobot_piper_tutorial/1. 🤗 LeRobot新增机械臂的一般流程.pdf
# Teleoperate
```bash
cd piper_scripts/
bash can_activate.sh can0 1000000
cd ..
python lerobot/scripts/control_robot.py \
--robot.type=piper \
--robot.inference_time=false \
--control.type=teleoperate
```
# Record
Set dataset root path
```bash
HF_USER=$PWD/data
echo $HF_USER
```
```bash
python lerobot/scripts/control_robot.py \
--robot.type=realman \
--robot.inference_time=false \
--control.type=record \
--control.fps=15 \
--control.single_task="move" \
--control.repo_id=maic/test \
--control.num_episodes=2 \
--control.warmup_time_s=2 \
--control.episode_time_s=10 \
--control.reset_time_s=10 \
--control.play_sounds=true \
--control.push_to_hub=false \
--control.display_data=true
```
Press right arrow -> at any time during episode recording to early stop and go to resetting. Same during resetting, to early stop and to go to the next episode recording.
Press left arrow <- at any time during episode recording or resetting to early stop, cancel the current episode, and re-record it.
Press escape ESC at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.
# visualize
```bash
python lerobot/scripts/visualize_dataset.py \
--repo-id ${HF_USER}/test \
--episode-index 0
```
# Replay
```bash
python lerobot/scripts/control_robot.py \
--robot.type=piper \
--robot.inference_time=false \
--control.type=replay \
--control.fps=30 \
--control.repo_id=${HF_USER}/test \
--control.episode=0
```
# Caution
1. In lerobots/common/datasets/video_utils, the vcodec is set to **libopenh264**, please find your vcodec by **ffmpeg -codecs**
# Train
具体的训练流程见lerobot_piper_tutorial/2. 🤗 AutoDL训练.pdf
```bash
python lerobot/scripts/train.py \
--dataset.repo_id=${HF_USER}/jack \
--policy.type=act \
--output_dir=outputs/train/act_jack \
--job_name=act_jack \
--device=cuda \
--wandb.enable=true
```
# FT smolvla
python lerobot/scripts/train.py \
--dataset.repo_id=maic/move_the_bottle_into_ultrasonic_device_with_realman_single \
--policy.path=lerobot/smolvla_base \
--output_dir=outputs/train/smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
--job_name=smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
--policy.device=cuda \
--wandb.enable=false \
--steps=200000 \
--batch_size=16
# Inference
还是使用control_robot.py中的record loop配置 **--robot.inference_time=true** 可以将手柄移出。
```bash
python lerobot/scripts/control_robot.py \
--robot.type=realman \
--robot.inference_time=true \
--control.type=record \
--control.fps=30 \
--control.single_task="move the bottle into ultrasonic device with realman single" \
--control.repo_id=maic/move_the_bottle_into_ultrasonic_device_with_realman_single \
--control.num_episodes=1 \
--control.warmup_time_s=2 \
--control.episode_time_s=30 \
--control.reset_time_s=10 \
--control.push_to_hub=false \
--control.policy.path=outputs/train/act_move_the_bottle_into_ultrasonic_device_with_realman_single/checkpoints/100000/pretrained_model
```
```bash
python lerobot/scripts/control_robot.py \
--robot.type=realman \
--robot.inference_time=true \
--control.type=record \
--control.fps=30 \
--control.single_task="move the bottle into ultrasonic device with realman single" \
--control.repo_id=maic/eval_smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
--control.num_episodes=1 \
--control.warmup_time_s=2 \
--control.episode_time_s=60 \
--control.reset_time_s=10 \
--control.push_to_hub=false \
--control.policy.path=outputs/train/smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single/checkpoints/160000/pretrained_model \
--control.display_data=true
```

View File

@@ -0,0 +1,31 @@
from Robotic_Arm.rm_robot_interface import *
armleft = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
armright = RoboticArm()
lefthandle = armleft.rm_create_robot_arm("169.254.128.18", 8080)
print("机械臂ID", lefthandle.id)
righthandle = armright.rm_create_robot_arm("169.254.128.19", 8080)
print("机械臂ID", righthandle.id)
# software_info = armleft.rm_get_arm_software_info()
# if software_info[0] == 0:
# print("\n================== Arm Software Information ==================")
# print("Arm Model: ", software_info[1]['product_version'])
# print("Algorithm Library Version: ", software_info[1]['algorithm_info']['version'])
# print("Control Layer Software Version: ", software_info[1]['ctrl_info']['version'])
# print("Dynamics Version: ", software_info[1]['dynamic_info']['model_version'])
# print("Planning Layer Software Version: ", software_info[1]['plan_info']['version'])
# print("==============================================================\n")
# else:
# print("\nFailed to get arm software information, Error code: ", software_info[0], "\n")
print("Left: ", armleft.rm_get_current_arm_state())
print("Left: ", armleft.rm_get_arm_all_state())
armleft.rm_movej_p()
# print("Right: ", armright.rm_get_current_arm_state())
# 断开所有连接,销毁线程
RoboticArm.rm_destory()

View File

@@ -0,0 +1,15 @@
from Robotic_Arm.rm_robot_interface import *
import time
# 实例化RoboticArm类
arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
# 创建机械臂连接打印连接id
handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
print(handle.id)
print(arm.rm_movep_follow([-0.330512, 0.255993, -0.161205, 3.141, 0.0, -1.57]))
time.sleep(2)
# print(arm.rm_movep_follow([0.3, 0, 0.3, 3.14, 0, 0]))
# time.sleep(2)
arm.rm_delete_robot_arm()

View File

View File

@@ -0,0 +1,4 @@
__pycache__/
*.pyc
*.pyo
*.pt

View File

@@ -0,0 +1,33 @@
[tool.poetry]
name = "shadow_camera"
version = "0.1.0"
description = "camera class, currently includes realsense"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
#include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
"Operating System :: POSIX :: Linux amd64",
"Programming Language :: Python :: 3.10",
]
[tool.poetry.dependencies]
python = ">=3.9"
numpy = ">=2.0.1"
opencv-python = ">=4.10.0.84"
pyrealsense2 = ">=2.55.1.6486"
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
pytest = ">=8.3"
black = ">=24.10.0"
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
[tool.poetry.group.dev.dependencies]
[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,38 @@
from abc import ABCMeta, abstractmethod
class BaseCamera(metaclass=ABCMeta):
"""摄像头基类"""
def __init__(self):
pass
@abstractmethod
def start_camera(self):
"""启动相机"""
pass
@abstractmethod
def stop_camera(self):
"""停止相机"""
pass
@abstractmethod
def set_resolution(self, resolution_width, resolution_height):
"""设置相机彩色图像分辨率"""
pass
@abstractmethod
def set_frame_rate(self, fps):
"""设置相机彩色图像帧率"""
pass
@abstractmethod
def read_frame(self):
"""读取一帧彩色图像和深度图像"""
pass
@abstractmethod
def get_camera_intrinsics(self):
"""获取彩色图像和深度图像的内参"""
pass

View File

@@ -0,0 +1,38 @@
from shadow_camera import base_camera
import cv2
class OpenCVCamera(base_camera.BaseCamera):
"""基于OpenCV的摄像头类"""
def __init__(self, device_id=0):
"""初始化视频捕获
参数:
device_id: 摄像头设备ID
"""
self.cap = cv2.VideoCapture(device_id)
def get_frame(self):
"""获取当前帧
返回:
frame: 当前帧的图像数据,取不到时返回None
"""
ret, frame = self.cap.read()
return frame if ret else None
def get_frame_info(self):
"""获取当前帧信息
返回:
dict: 帧信息字典
"""
width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
channels = int(self.cap.get(cv2.CAP_PROP_FRAME_CHANNELS))
return {
'width': width,
'height': height,
'channels': channels
}

View File

@@ -0,0 +1,280 @@
import time
import logging
import numpy as np
import pyrealsense2 as rs
import base_camera
# 设置日志配置
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class RealSenseCamera(base_camera.BaseCamera):
"""Intel RealSense相机类"""
def __init__(self, serial_num=None, is_depth_frame=False):
"""
初始化相机对象
:param serial_num: 相机序列号默认为None
"""
super().__init__()
self._color_resolution = [640, 480]
self._depth_resolution = [640, 480]
self._color_frames_rate = 30
self._depth_frames_rate = 15
self.timestamp = 0
self.color_timestamp = 0
self.depth_timestamp = 0
self._colorizer = rs.colorizer()
self._config = rs.config()
self.is_depth_frame = is_depth_frame
self.camera_on = False
self.serial_num = serial_num
def get_serial_num(self):
serial_num = {}
context = rs.context()
devices = context.query_devices() # 获取所有设备
if len(context.devices) > 0:
for i, device in enumerate(devices):
serial_num[i] = device.get_info(rs.camera_info.serial_number)
logging.info(f"Detected serial numbers: {serial_num}")
return serial_num
def _set_config(self):
if self.serial_num is not None:
logging.info(f"Setting device with serial number: {self.serial_num}")
self._config.enable_device(self.serial_num)
self._config.enable_stream(
rs.stream.color,
self._color_resolution[0],
self._color_resolution[1],
rs.format.rgb8,
self._color_frames_rate,
)
if self.is_depth_frame:
self._config.enable_stream(
rs.stream.depth,
self._depth_resolution[0],
self._depth_resolution[1],
rs.format.z16,
self._depth_frames_rate,
)
def start_camera(self):
"""
启动相机并获取内参信息,如果后续调用帧对齐,则内参均为彩色内参
"""
self._pipeline = rs.pipeline()
if self.is_depth_frame:
self.point_cloud = rs.pointcloud()
self._align = rs.align(rs.stream.color)
self._set_config()
self.profile = self._pipeline.start(self._config)
if self.is_depth_frame:
self._depth_intrinsics = (
self.profile.get_stream(rs.stream.depth)
.as_video_stream_profile()
.get_intrinsics()
)
self._color_intrinsics = (
self.profile.get_stream(rs.stream.color)
.as_video_stream_profile()
.get_intrinsics()
)
self.camera_on = True
logging.info("Camera started successfully")
logging.info(
f"Camera started with color resolution: {self._color_resolution}, depth resolution: {self._depth_resolution}"
)
logging.info(
f"Color FPS: {self._color_frames_rate}, Depth FPS: {self._depth_frames_rate}"
)
def stop_camera(self):
"""
停止相机
"""
self._pipeline.stop()
self.camera_on = False
logging.info("Camera stopped")
def set_resolution(self, color_resolution, depth_resolution):
self._color_resolution = color_resolution
self._depth_resolution = depth_resolution
logging.info(
"Optional color resolution:"
"[320, 180] [320, 240] [424, 240] [640, 360] [640, 480]"
"[848, 480] [960, 540] [1280, 720] [1920, 1080]"
)
logging.info(
"Optional depth resolution:"
"[256, 144] [424, 240] [480, 270] [640, 360] [640, 400]"
"[640, 480] [848, 100] [848, 480] [1280, 720] [1280, 800]"
)
logging.info(f"Set color resolution to: {color_resolution}")
logging.info(f"Set depth resolution to: {depth_resolution}")
def set_frame_rate(self, color_fps, depth_fps):
self._color_frames_rate = color_fps
self._depth_frames_rate = depth_fps
logging.info("Optional color fps: 6 15 30 60 ")
logging.info("Optional depth fps: 6 15 30 60 90 100 300")
logging.info(f"Set color FPS to: {color_fps}")
logging.info(f"Set depth FPS to: {depth_fps}")
# TODO: 调节白平衡进行补偿
# def set_exposure(self, exposure):
def read_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
"""
读取一帧彩色图像和深度图像
:return: 彩色图像和深度图像的NumPy数组
"""
while not self.camera_on:
time.sleep(0.5)
color_image = None
depth_image = None
colorized_depth = None
point_cloud = None
try:
frames = self._pipeline.wait_for_frames()
if is_color:
color_frame = frames.get_color_frame()
color_image = np.asanyarray(color_frame.get_data())
else:
color_image = None
if is_depth:
depth_frame = frames.get_depth_frame()
depth_image = np.asanyarray(depth_frame.get_data())
else:
depth_image = None
colorized_depth = (
np.asanyarray(self._colorizer.colorize(depth_frame).get_data())
if is_colorized_depth
else None
)
point_cloud = (
np.asanyarray(self.point_cloud.calculate(depth_frame).get_vertices())
if is_point_cloud
else None
)
# 获取时间戳单位为ms对齐后color时间戳 > depth = aligned选择color
self.color_timestamp = color_frame.get_timestamp()
if self.is_depth_frame:
self.depth_timestamp = depth_frame.get_timestamp()
except Exception as e:
logging.warning(e)
if "Frame didn't arrive within 5000" in str(e):
logging.warning("Frame didn't arrive within 5000ms, resetting device")
self.stop_camera()
self.start_camera()
return color_image, depth_image, colorized_depth, point_cloud
def read_align_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
"""
读取一帧对齐的彩色图像和深度图像
:return: 彩色图像和深度图像的NumPy数组
"""
while not self.camera_on:
time.sleep(0.5)
try:
frames = self._pipeline.wait_for_frames()
aligned_frames = self._align.process(frames)
aligned_color_frame = aligned_frames.get_color_frame()
self._aligned_depth_frame = aligned_frames.get_depth_frame()
color_image = np.asanyarray(aligned_color_frame.get_data())
depth_image = np.asanyarray(self._aligned_depth_frame.get_data())
colorized_depth = (
np.asanyarray(
self._colorizer.colorize(self._aligned_depth_frame).get_data()
)
if is_colorized_depth
else None
)
if is_point_cloud:
points = self.point_cloud.calculate(self._aligned_depth_frame)
# 将元组数据转换为 NumPy 数组
point_cloud = np.array(
[[point[0], point[1], point[2]] for point in points.get_vertices()]
)
else:
point_cloud = None
# 获取时间戳单位为ms对齐后color时间戳 > depth = aligned选择color
self.timestamp = aligned_color_frame.get_timestamp()
return color_image, depth_image, colorized_depth, point_cloud
except Exception as e:
if "Frame didn't arrive within 5000" in str(e):
logging.warning("Frame didn't arrive within 5000ms, resetting device")
self.stop_camera()
self.start_camera()
# device = self.profile.get_device()
# device.hardware_reset()
def get_camera_intrinsics(self):
"""
获取彩色图像和深度图像的内参信息
:return: 彩色图像和深度图像的内参信息
"""
# 宽高:.width, .height; 焦距:.fx, .fy; 像素坐标:.ppx, .ppy; 畸变系数:.coeffs
logging.info("Getting camera intrinsics")
logging.info(
"Width and height: .width, .height; Focal length: .fx, .fy; Pixel coordinates: .ppx, .ppy; Distortion coefficient: .coeffs"
)
return self._color_intrinsics, self._depth_intrinsics
def get_3d_camera_coordinate(self, depth_pixel, align=False):
"""
获取深度相机坐标系下的三维坐标
:param depth_pixel:深度像素坐标
:param align: 是否对齐
:return: 深度值和相机坐标
"""
if not hasattr(self, "_aligned_depth_frame"):
raise AttributeError(
"Aligned depth frame not set. Call read_align_frame() first."
)
distance = self._aligned_depth_frame.get_distance(
depth_pixel[0], depth_pixel[1]
)
intrinsics = self._color_intrinsics if align else self._depth_intrinsics
camera_coordinate = rs.rs2_deproject_pixel_to_point(
intrinsics, depth_pixel, distance
)
return distance, camera_coordinate
if __name__ == "__main__":
camera = RealSenseCamera(is_depth_frame=False)
camera.get_serial_num()
camera.start_camera()
# camera.set_frame_rate(60, 60)
color_image, depth_image, colorized_depth, point_cloud = camera.read_frame()
camera.stop_camera()
logging.info(f"Color image shape: {color_image.shape}")
# logging.info(f"Depth image shape: {depth_image.shape}")
# logging.info(f"Colorized depth image shape: {colorized_depth.shape}")
# logging.info(f"Point cloud shape: {point_cloud.shape}")
logging.info(f"Color timestamp: {camera.timestamp}")
# logging.info(f"Depth timestamp: {camera.depth_timestamp}")
logging.info(f"Color timestamp: {camera.color_timestamp}")
# logging.info(f"Depth timestamp: {camera.depth_timestamp}")
logging.info("Test passed")

View File

@@ -0,0 +1,101 @@
import pyrealsense2 as rs
import numpy as np
import h5py
import time
import threading
import keyboard # 用于监听键盘输入
# 全局变量
is_recording = False # 标志位,控制录制状态
color_images = [] # 存储彩色图像
depth_images = [] # 存储深度图像
timestamps = [] # 存储时间戳
# 配置D435相机
def configure_camera():
pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30) # 彩色图像流
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30) # 深度图像流
pipeline.start(config)
return pipeline
# 监听键盘输入,控制录制状态
def listen_for_keyboard():
global is_recording
while True:
if keyboard.is_pressed('s'): # 按下 's' 开始录制
is_recording = True
print("Recording started.")
time.sleep(0.5) # 防止重复触发
elif keyboard.is_pressed('q'): # 按下 'q' 停止录制
is_recording = False
print("Recording stopped.")
time.sleep(0.5) # 防止重复触发
elif keyboard.is_pressed('e'): # 按下 'e' 退出程序
print("Exiting program.")
exit()
time.sleep(0.1)
# 采集图像数据
def capture_frames(pipeline):
global is_recording, color_images, depth_images, timestamps
try:
while True:
if is_recording:
frames = pipeline.wait_for_frames()
color_frame = frames.get_color_frame()
depth_frame = frames.get_depth_frame()
if not color_frame or not depth_frame:
continue
# 获取当前时间戳
timestamp = time.time()
# 将图像转换为numpy数组
color_image = np.asanyarray(color_frame.get_data())
depth_image = np.asanyarray(depth_frame.get_data())
# 存储数据
color_images.append(color_image)
depth_images.append(depth_image)
timestamps.append(timestamp)
print(f"Captured frame at {timestamp}")
else:
time.sleep(0.1) # 如果未录制,等待一段时间
finally:
pipeline.stop()
# 保存为HDF5文件
def save_to_hdf5(color_images, depth_images, timestamps, filename="output.h5"):
with h5py.File(filename, "w") as f:
f.create_dataset("color_images", data=np.array(color_images), compression="gzip")
f.create_dataset("depth_images", data=np.array(depth_images), compression="gzip")
f.create_dataset("timestamps", data=np.array(timestamps), compression="gzip")
print(f"Data saved to {filename}")
# 主函数
def main():
global is_recording, color_images, depth_images, timestamps
# 启动键盘监听线程
keyboard_thread = threading.Thread(target=listen_for_keyboard)
keyboard_thread.daemon = True
keyboard_thread.start()
# 配置相机
pipeline = configure_camera()
# 开始采集图像
capture_frames(pipeline)
# 录制结束后保存数据
if color_images and depth_images and timestamps:
save_to_hdf5(color_images, depth_images, timestamps, "mobile_aloha_data.h5")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,152 @@
import os
import cv2
import time
import numpy as np
from os import path
import pyrealsense2 as rs
from shadow_camera import realsense
import logging
def test_camera():
camera = realsense.RealSenseCamera('241122071186')
camera.start_camera()
while True:
# result = camera.read_align_frame()
# if result is None:
# print('is None')
# continue
# start_time = time.time()
color_image, depth_image, colorized_depth, vtx = camera.read_frame()
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
print(f"color_image: {color_image.shape}")
# print(f"Time: {end_time - start_time}")
cv2.imshow("bgr_image", color_image)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
camera.stop_camera()
def test_get_serial_num():
camera = realsense.RealSenseCamera()
device = camera.get_serial_num()
class CameraCapture:
def __init__(self, camera_serial_num=None, save_dir="./save"):
self._camera_serial_num = camera_serial_num
self._color_save_dir = path.join(save_dir, "color")
self._depth_save_dir = path.join(save_dir, "depth")
os.makedirs(save_dir, exist_ok=True)
os.makedirs(self._color_save_dir, exist_ok=True)
os.makedirs(self._depth_save_dir, exist_ok=True)
def get_serial_num(self):
self._camera_serial_num = {}
camera_names = ["left", "right", "head", "table"]
context = rs.context()
devices = context.query_devices() # 获取所有设备
if len(context.devices) > 0:
for i, device in enumerate(devices):
self._camera_serial_num[camera_names[i]] = device.get_info(
rs.camera_info.serial_number
)
print(self._camera_serial_num)
return self._camera_serial_num
def start_camera(self):
if self._camera_serial_num is None:
self.get_serial_num()
self._camera_left = realsense.RealSenseCamera(self._camera_serial_num["left"])
self._camera_right = realsense.RealSenseCamera(self._camera_serial_num["right"])
self._camera_head = realsense.RealSenseCamera(self._camera_serial_num["head"])
self._camera_left.start_camera()
self._camera_right.start_camera()
self._camera_head.start_camera()
def stop_camera(self):
self._camera_left.stop_camera()
self._camera_right.stop_camera()
self._camera_head.stop_camera()
def _save_datas(self, timestamp, color_image, depth_image, camera_name):
color_filename = path.join(
self._color_save_dir, f"{timestamp}" + camera_name + ".jpg"
)
depth_filename = path.join(
self._depth_save_dir, f"{timestamp}" + camera_name + ".png"
)
cv2.imwrite(color_filename, color_image)
cv2.imwrite(depth_filename, depth_image)
def capture_images(self):
while True:
(
color_image_left,
depth_image_left,
_,
_,
) = self._camera_left.read_align_frame()
(
color_image_right,
depth_image_right,
_,
_,
) = self._camera_right.read_align_frame()
(
color_image_head,
depth_image_head,
_,
point_cloud3,
) = self._camera_head.read_align_frame()
bgr_color_image_left = cv2.cvtColor(color_image_left, cv2.COLOR_RGB2BGR)
bgr_color_image_right = cv2.cvtColor(color_image_right, cv2.COLOR_RGB2BGR)
bgr_color_image_head = cv2.cvtColor(color_image_head, cv2.COLOR_RGB2BGR)
timestamp = time.time() * 1000
cv2.imshow("Camera left", bgr_color_image_left)
cv2.imshow("Camera right", bgr_color_image_right)
cv2.imshow("Camera head", bgr_color_image_head)
# self._save_datas(
# timestamp, bgr_color_image_left, depth_image_left, "left"
# )
# self._save_datas(
# timestamp, bgr_color_image_right, depth_image_right, "right"
# )
# self._save_datas(
# timestamp, bgr_color_image_head, depth_image_head, "head"
# )
if cv2.waitKey(1) & 0xFF == ord("q"):
break
cv2.destroyAllWindows()
if __name__ == "__main__":
#test_camera()
test_get_serial_num()
"""
输入相机序列号制定左右相机:
dict{'left': '241222075132', 'right': '242322076532', 'head': '242322076532'}
保存路径:
str./save
输入为空,自动分配相机序列号(不指定左、右、头部),保存路径为./save
"""
# capture = CameraCapture()
# capture.get_serial_num()
# test_get_serial_num()
# capture.start_camera()
# capture.capture_images()
# capture.stop_camera()

View File

@@ -0,0 +1,71 @@
import pytest
import pyrealsense2 as rs
from shadow_camera.realsense import RealSenseCamera
class TestRealSenseCamera:
@pytest.fixture(autouse=True)
def setup_camera(self):
self.camera = RealSenseCamera()
def test_get_serial_num(self):
serial_nums = self.camera.get_serial_num()
assert isinstance(serial_nums, dict)
assert len(serial_nums) > 0
def test_start_stop_camera(self):
self.camera.start_camera()
assert self.camera.camera_on is True
self.camera.stop_camera()
assert self.camera.camera_on is False
def test_set_resolution(self):
color_resolution = [1280, 720]
depth_resolution = [1280, 720]
self.camera.set_resolution(color_resolution, depth_resolution)
assert self.camera._color_resolution == color_resolution
assert self.camera._depth_resolution == depth_resolution
def test_set_frame_rate(self):
color_fps = 60
depth_fps = 60
self.camera.set_frame_rate(color_fps, depth_fps)
assert self.camera._color_frames_rate == color_fps
assert self.camera._depth_frames_rate == depth_fps
def test_read_frame(self):
self.camera.start_camera()
color_image, depth_image, colorized_depth, point_cloud = (
self.camera.read_frame()
)
assert color_image is not None
assert depth_image is not None
self.camera.stop_camera()
def test_read_align_frame(self):
self.camera.start_camera()
color_image, depth_image, colorized_depth, point_cloud = (
self.camera.read_align_frame()
)
assert color_image is not None
assert depth_image is not None
self.camera.stop_camera()
def test_get_camera_intrinsics(self):
self.camera.start_camera()
color_intrinsics, depth_intrinsics = self.camera.get_camera_intrinsics()
assert color_intrinsics is not None
assert depth_intrinsics is not None
self.camera.stop_camera()
def test_get_3d_camera_coordinate(self):
self.camera.start_camera()
# 先调用 read_align_frame 方法以确保 _aligned_depth_frame 被设置
self.camera.read_align_frame()
depth_pixel = [320, 240]
distance, camera_coordinate = self.camera.get_3d_camera_coordinate(
depth_pixel, align=True
)
assert distance > 0
assert len(camera_coordinate) == 3
self.camera.stop_camera()

View File

@@ -0,0 +1,10 @@
__pycache__/
build/
devel/
dist/
data/
.catkin_workspace
*.pyc
*.pyo
*.pt
.vscode/

View File

@@ -0,0 +1,89 @@
# ACT: Action Chunking with Transformers
### *New*: [ACT tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing)
TL;DR: if your ACT policy is jerky or pauses in the middle of an episode, just train for longer! Success rate and smoothness can improve way after loss plateaus.
#### Project Website: https://tonyzhaozh.github.io/aloha/
This repo contains the implementation of ACT, together with 2 simulated environments:
Transfer Cube and Bimanual Insertion. You can train and evaluate ACT in sim or real.
For real, you would also need to install [ALOHA](https://github.com/tonyzhaozh/aloha).
### Updates:
You can find all scripted/human demo for simulated environments [here](https://drive.google.com/drive/folders/1gPR03v05S1xiInoVJn7G7VJ9pDCnxq9O?usp=share_link).
### Repo Structure
- ``imitate_episodes.py`` Train and Evaluate ACT
- ``policy.py`` An adaptor for ACT policy
- ``detr`` Model definitions of ACT, modified from DETR
- ``sim_env.py`` Mujoco + DM_Control environments with joint space control
- ``ee_sim_env.py`` Mujoco + DM_Control environments with EE space control
- ``scripted_policy.py`` Scripted policies for sim environments
- ``constants.py`` Constants shared across files
- ``utils.py`` Utils such as data loading and helper functions
- ``visualize_episodes.py`` Save videos from a .hdf5 dataset
### Installation
conda create -n aloha python=3.8.10
conda activate aloha
pip install torchvision
pip install torch
pip install pyquaternion
pip install pyyaml
pip install rospkg
pip install pexpect
pip install mujoco==2.3.7
pip install dm_control==1.0.14
pip install opencv-python
pip install matplotlib
pip install einops
pip install packaging
pip install h5py
pip install ipython
cd act/detr && pip install -e .
### Example Usages
To set up a new terminal, run:
conda activate aloha
cd <path to act repo>
### Simulated experiments
We use ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
To generated 50 episodes of scripted data, run:
python3 record_sim_episodes.py \
--task_name sim_transfer_cube_scripted \
--dataset_dir <data save dir> \
--num_episodes 50
To can add the flag ``--onscreen_render`` to see real-time rendering.
To visualize the episode after it is collected, run
python3 visualize_episodes.py --dataset_dir <data save dir> --episode_idx 0
To train ACT:
# Transfer Cube task
python3 imitate_episodes.py \
--task_name sim_transfer_cube_scripted \
--ckpt_dir <ckpt dir> \
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
--num_epochs 2000 --lr 1e-5 \
--seed 0
To evaluate the policy, run the same command but add ``--eval``. This loads the best validation checkpoint.
The success rate should be around 90% for transfer cube, and around 50% for insertion.
To enable temporal ensembling, add flag ``--temporal_agg``.
Videos will be saved to ``<ckpt_dir>`` for each rollout.
You can also add ``--onscreen_render`` to see real-time rendering during evaluation.
For real-world data where things can be harder to model, train for at least 5000 epochs or 3-4 times the length after the loss has plateaued.
Please refer to [tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing) for more info.

View File

@@ -0,0 +1,74 @@
robot_env: {
# TODO change the path to the correct one
rm_left_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_left_arm.yaml',
rm_right_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_right_arm.yaml',
arm_axis: 6,
head_camera: '215222076892',
bottom_camera: '215222076981',
left_camera: '152122078151',
right_camera: '152122073489',
# init_left_arm_angle: [0.226, 21.180, 91.304, -0.515, 67.486, 2.374, 0.9],
# init_right_arm_angle: [-1.056, 33.057, 84.376, -0.204, 66.357, -3.236, 0.9]
init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
}
dataset_dir: '/home/rm/aloha/shadow_rm_aloha/data/dataset/20250103'
checkpoint_dir: '/home/rm/aloha/shadow_rm_act/data'
# checkpoint_name: 'policy_best.ckpt'
checkpoint_name: 'policy_9500.ckpt'
state_dim: 14
save_episode: True
num_rollouts: 50 #训练期间要收集的 rollout轨迹数量
real_robot: True
policy_class: 'ACT'
onscreen_render: False
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right']
episode_len: 300 #episode 的最大长度(时间步数)。
task_name: 'aloha_01_11.28'
temporal_agg: False #是否使用时间聚合
batch_size: 8 #训练期间每批的样本数。
seed: 1000 #随机种子。
chunk_size: 30 #用于处理序列的块大小
eval_every: 1 #每隔 eval_every 步评估一次模型。
num_steps: 10000 #训练的总步数。
validate_every: 1 #每隔 validate_every 步验证一次模型。
save_every: 500 #每隔 save_every 步保存一次检查点。
load_pretrain: False #是否加载预训练模型。
resume_ckpt_path:
name_filter: # TODO
skip_mirrored_data: False #是否跳过镜像数据(例如用于基于对称性的数据增强)。
stats_dir:
sample_weights:
train_ratio: 0.8 #用于训练的数据比例(其余数据用于验证)
policy_config: {
hidden_dim: 512, # Size of the embeddings (dimension of the transformer)
state_dim: 14, # Dimension of the state
position_embedding: 'sine', # ('sine', 'learned').Type of positional embedding to use on top of the image features
lr_backbone: 1.0e-5,
masks: False, # If true, the model masks the non-visible pixels
backbone: 'resnet18',
dilation: False, # If true, we replace stride with dilation in the last convolutional block (DC5)
dropout: 0.1, # Dropout applied in the transformer
nheads: 8,
dim_feedforward: 3200, # Intermediate size of the feedforward layers in the transformer blocks
enc_layers: 4, # Number of encoding layers in the transformer
dec_layers: 7, # Number of decoding layers in the transformer
pre_norm: False, # If true, apply LayerNorm to the input instead of the output of the MultiheadAttention and FeedForward
num_queries: 30,
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right'],
vq: False,
vq_class: none,
vq_dim: 64,
action_dim: 14,
no_encoder: False,
lr: 1.0e-5,
weight_decay: 1.0e-4,
kl_weight: 10,
# lr_drop: 200,
# clip_max_norm: 0.1,
}

View File

@@ -0,0 +1,267 @@
import numpy as np
import collections
import os
from constants import DT, XML_DIR, START_ARM_POSE
from constants import PUPPET_GRIPPER_POSITION_CLOSE
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
from src.shadow_act.utils.utils import sample_box_pose, sample_insertion_pose
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
import IPython
e = IPython.embed
def make_ee_sim_env(task_name):
"""
Environment for simulated robot bi-manual manipulation, with end-effector control.
Action space: [left_arm_pose (7), # position and quaternion for end effector
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_pose (7), # position and quaternion for end effector
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
if 'sim_transfer_cube' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_transfer_cube.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = TransferCubeEETask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
elif 'sim_insertion' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_insertion.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = InsertionEETask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
else:
raise NotImplementedError
return env
class BimanualViperXEETask(base.Task):
def __init__(self, random=None):
super().__init__(random=random)
def before_step(self, action, physics):
a_len = len(action) // 2
action_left = action[:a_len]
action_right = action[a_len:]
# set mocap position and quat
# left
np.copyto(physics.data.mocap_pos[0], action_left[:3])
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
# right
np.copyto(physics.data.mocap_pos[1], action_right[:3])
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
# set gripper
g_left_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_left[7])
g_right_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_right[7])
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
def initialize_robots(self, physics):
# reset joint position
physics.named.data.qpos[:16] = START_ARM_POSE
# reset mocap to align with end effector
# to obtain these numbers:
# (1) make an ee_sim env and reset to the same start_pose
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
# repeat the same for right side
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
# right
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
# reset gripper control
close_gripper_control = np.array([
PUPPET_GRIPPER_POSITION_CLOSE,
-PUPPET_GRIPPER_POSITION_CLOSE,
PUPPET_GRIPPER_POSITION_CLOSE,
-PUPPET_GRIPPER_POSITION_CLOSE,
])
np.copyto(physics.data.ctrl, close_gripper_control)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
super().initialize_episode(physics)
@staticmethod
def get_qpos(physics):
qpos_raw = physics.data.qpos.copy()
left_qpos_raw = qpos_raw[:8]
right_qpos_raw = qpos_raw[8:16]
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
@staticmethod
def get_qvel(physics):
qvel_raw = physics.data.qvel.copy()
left_qvel_raw = qvel_raw[:8]
right_qvel_raw = qvel_raw[8:16]
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
@staticmethod
def get_env_state(physics):
raise NotImplementedError
def get_observation(self, physics):
# note: it is important to do .copy()
obs = collections.OrderedDict()
obs['qpos'] = self.get_qpos(physics)
obs['qvel'] = self.get_qvel(physics)
obs['env_state'] = self.get_env_state(physics)
obs['images'] = dict()
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
# used in scripted policy to obtain starting pose
obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()
# used when replaying joint trajectory
obs['gripper_ctrl'] = physics.data.ctrl.copy()
return obs
def get_reward(self, physics):
raise NotImplementedError
class TransferCubeEETask(BimanualViperXEETask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
self.initialize_robots(physics)
# randomize box position
cube_pose = sample_box_pose()
box_start_idx = physics.model.name2id('red_box_joint', 'joint')
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
# print(f"randomized cube position to {cube_position}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether left gripper is holding the box
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_table = ("red_box", "table") in all_contact_pairs
reward = 0
if touch_right_gripper:
reward = 1
if touch_right_gripper and not touch_table: # lifted
reward = 2
if touch_left_gripper: # attempted transfer
reward = 3
if touch_left_gripper and not touch_table: # successful transfer
reward = 4
return reward
class InsertionEETask(BimanualViperXEETask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
self.initialize_robots(physics)
# randomize peg and socket position
peg_pose, socket_pose = sample_insertion_pose()
id2index = lambda j_id: 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
peg_start_id = physics.model.name2id('red_peg_joint', 'joint')
peg_start_idx = id2index(peg_start_id)
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
# print(f"randomized cube position to {cube_position}")
socket_start_id = physics.model.name2id('blue_socket_joint', 'joint')
socket_start_idx = id2index(socket_start_id)
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
# print(f"randomized cube position to {cube_position}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether peg touches the pin
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
peg_touch_table = ("red_peg", "table") in all_contact_pairs
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
("socket-2", "table") in all_contact_pairs or \
("socket-3", "table") in all_contact_pairs or \
("socket-4", "table") in all_contact_pairs
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
("red_peg", "socket-2") in all_contact_pairs or \
("red_peg", "socket-3") in all_contact_pairs or \
("red_peg", "socket-4") in all_contact_pairs
pin_touched = ("red_peg", "pin") in all_contact_pairs
reward = 0
if touch_left_gripper and touch_right_gripper: # touch both
reward = 1
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
reward = 2
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
reward = 3
if pin_touched: # successful insertion
reward = 4
return reward

View File

@@ -0,0 +1,36 @@
[tool.poetry]
name = "shadow_act"
version = "0.1.0"
description = "Embodied data, ACT and other methods; training and verification function packages"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
# include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
"Operating System :: POSIX :: Linux amd64",
"Programming Language :: Python :: 3.10",
]
[tool.poetry.dependencies]
python = ">=3.9"
wandb = ">=0.18.0"
einops = ">=0.8.0"
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
pytest = ">=8.3"
black = ">=24.10.0"
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
[tool.poetry.group.dev.dependencies]
[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,189 @@
import time
import os
import numpy as np
import argparse
import matplotlib.pyplot as plt
import h5py
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
from sim_env import make_sim_env, BOX_POSE
from scripted_policy import PickAndTransferPolicy, InsertionPolicy
import IPython
e = IPython.embed
def main(args):
"""
Generate demonstration data in simulation.
First rollout the policy (defined in ee space) in ee_sim_env. Obtain the joint trajectory.
Replace the gripper joint positions with the commanded joint position.
Replay this joint trajectory (as action sequence) in sim_env, and record all observations.
Save this episode of data, and continue to next episode of data collection.
"""
task_name = args['task_name']
dataset_dir = args['dataset_dir']
num_episodes = args['num_episodes']
onscreen_render = args['onscreen_render']
inject_noise = False
render_cam_name = 'angle'
if not os.path.isdir(dataset_dir):
os.makedirs(dataset_dir, exist_ok=True)
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
if task_name == 'sim_transfer_cube_scripted':
policy_cls = PickAndTransferPolicy
elif task_name == 'sim_insertion_scripted':
policy_cls = InsertionPolicy
else:
raise NotImplementedError
success = []
for episode_idx in range(num_episodes):
print(f'{episode_idx=}')
print('Rollout out EE space scripted policy')
# setup the environment
env = make_ee_sim_env(task_name)
ts = env.reset()
episode = [ts]
policy = policy_cls(inject_noise)
# setup plotting
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
plt.ion()
for step in range(episode_len):
action = policy(ts)
ts = env.step(action)
episode.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images'][render_cam_name])
plt.pause(0.002)
plt.close()
episode_return = np.sum([ts.reward for ts in episode[1:]])
episode_max_reward = np.max([ts.reward for ts in episode[1:]])
if episode_max_reward == env.task.max_reward:
print(f"{episode_idx=} Successful, {episode_return=}")
else:
print(f"{episode_idx=} Failed")
joint_traj = [ts.observation['qpos'] for ts in episode]
# replace gripper pose with gripper control
gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
joint[6] = left_ctrl
joint[6+7] = right_ctrl
subtask_info = episode[0].observation['env_state'].copy() # box pose at step 0
# clear unused variables
del env
del episode
del policy
# setup the environment
print('Replaying joint commands')
env = make_sim_env(task_name)
BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env
ts = env.reset()
episode_replay = [ts]
# setup plotting
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
plt.ion()
for t in range(len(joint_traj)): # note: this will increase episode length by 1
action = joint_traj[t]
ts = env.step(action)
episode_replay.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images'][render_cam_name])
plt.pause(0.02)
episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
if episode_max_reward == env.task.max_reward:
success.append(1)
print(f"{episode_idx=} Successful, {episode_return=}")
else:
success.append(0)
print(f"{episode_idx=} Failed")
plt.close()
"""
For each timestep:
observations
- images
- each_cam_name (480, 640, 3) 'uint8'
- qpos (14,) 'float64'
- qvel (14,) 'float64'
action (14,) 'float64'
"""
data_dict = {
'/observations/qpos': [],
'/observations/qvel': [],
'/action': [],
}
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'] = []
# because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
# truncate here to be consistent
joint_traj = joint_traj[:-1]
episode_replay = episode_replay[:-1]
# len(joint_traj) i.e. actions: max_timesteps
# len(episode_replay) i.e. time steps: max_timesteps + 1
max_timesteps = len(joint_traj)
while joint_traj:
action = joint_traj.pop(0)
ts = episode_replay.pop(0)
data_dict['/observations/qpos'].append(ts.observation['qpos'])
data_dict['/observations/qvel'].append(ts.observation['qvel'])
data_dict['/action'].append(action)
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
# HDF5
t0 = time.time()
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
root.attrs['sim'] = True
obs = root.create_group('observations')
image = obs.create_group('images')
for cam_name in camera_names:
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
# compression='gzip',compression_opts=2,)
# compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
qpos = obs.create_dataset('qpos', (max_timesteps, 14))
qvel = obs.create_dataset('qvel', (max_timesteps, 14))
action = root.create_dataset('action', (max_timesteps, 14))
for name, array in data_dict.items():
root[name][...] = array
print(f'Saving: {time.time() - t0:.1f} secs\n')
print(f'Saved to {dataset_dir}')
print(f'Success: {np.sum(success)} / {len(success)}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
parser.add_argument('--onscreen_render', action='store_true')
main(vars(parser.parse_args()))

View File

@@ -0,0 +1,194 @@
import numpy as np
import matplotlib.pyplot as plt
from pyquaternion import Quaternion
from constants import SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
import IPython
e = IPython.embed
class BasePolicy:
def __init__(self, inject_noise=False):
self.inject_noise = inject_noise
self.step_count = 0
self.left_trajectory = None
self.right_trajectory = None
def generate_trajectory(self, ts_first):
raise NotImplementedError
@staticmethod
def interpolate(curr_waypoint, next_waypoint, t):
t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"])
curr_xyz = curr_waypoint['xyz']
curr_quat = curr_waypoint['quat']
curr_grip = curr_waypoint['gripper']
next_xyz = next_waypoint['xyz']
next_quat = next_waypoint['quat']
next_grip = next_waypoint['gripper']
xyz = curr_xyz + (next_xyz - curr_xyz) * t_frac
quat = curr_quat + (next_quat - curr_quat) * t_frac
gripper = curr_grip + (next_grip - curr_grip) * t_frac
return xyz, quat, gripper
def __call__(self, ts):
# generate trajectory at first timestep, then open-loop execution
if self.step_count == 0:
self.generate_trajectory(ts)
# obtain left and right waypoints
if self.left_trajectory[0]['t'] == self.step_count:
self.curr_left_waypoint = self.left_trajectory.pop(0)
next_left_waypoint = self.left_trajectory[0]
if self.right_trajectory[0]['t'] == self.step_count:
self.curr_right_waypoint = self.right_trajectory.pop(0)
next_right_waypoint = self.right_trajectory[0]
# interpolate between waypoints to obtain current pose and gripper command
left_xyz, left_quat, left_gripper = self.interpolate(self.curr_left_waypoint, next_left_waypoint, self.step_count)
right_xyz, right_quat, right_gripper = self.interpolate(self.curr_right_waypoint, next_right_waypoint, self.step_count)
# Inject noise
if self.inject_noise:
scale = 0.01
left_xyz = left_xyz + np.random.uniform(-scale, scale, left_xyz.shape)
right_xyz = right_xyz + np.random.uniform(-scale, scale, right_xyz.shape)
action_left = np.concatenate([left_xyz, left_quat, [left_gripper]])
action_right = np.concatenate([right_xyz, right_quat, [right_gripper]])
self.step_count += 1
return np.concatenate([action_left, action_right])
class PickAndTransferPolicy(BasePolicy):
def generate_trajectory(self, ts_first):
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
box_info = np.array(ts_first.observation['env_state'])
box_xyz = box_info[:3]
box_quat = box_info[3:]
# print(f"Generate trajectory for {box_xyz=}")
gripper_pick_quat = Quaternion(init_mocap_pose_right[3:])
gripper_pick_quat = gripper_pick_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
meet_left_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)
meet_xyz = np.array([0, 0.5, 0.25])
self.left_trajectory = [
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
{"t": 100, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # approach meet position
{"t": 260, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # move to meet position
{"t": 310, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 0}, # close gripper
{"t": 360, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # move left
{"t": 400, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # stay
]
self.right_trajectory = [
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
{"t": 90, "xyz": box_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat.elements, "gripper": 1}, # approach the cube
{"t": 130, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 1}, # go down
{"t": 170, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 0}, # close gripper
{"t": 200, "xyz": meet_xyz + np.array([0.05, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 0}, # approach meet position
{"t": 220, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 0}, # move to meet position
{"t": 310, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 1}, # open gripper
{"t": 360, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # move to right
{"t": 400, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # stay
]
class InsertionPolicy(BasePolicy):
def generate_trajectory(self, ts_first):
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
peg_info = np.array(ts_first.observation['env_state'])[:7]
peg_xyz = peg_info[:3]
peg_quat = peg_info[3:]
socket_info = np.array(ts_first.observation['env_state'])[7:]
socket_xyz = socket_info[:3]
socket_quat = socket_info[3:]
gripper_pick_quat_right = Quaternion(init_mocap_pose_right[3:])
gripper_pick_quat_right = gripper_pick_quat_right * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
gripper_pick_quat_left = Quaternion(init_mocap_pose_right[3:])
gripper_pick_quat_left = gripper_pick_quat_left * Quaternion(axis=[0.0, 1.0, 0.0], degrees=60)
meet_xyz = np.array([0, 0.5, 0.15])
lift_right = 0.00715
self.left_trajectory = [
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
{"t": 120, "xyz": socket_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # approach the cube
{"t": 170, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # go down
{"t": 220, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # close gripper
{"t": 285, "xyz": meet_xyz + np.array([-0.1, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # approach meet position
{"t": 340, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements,"gripper": 0}, # insertion
{"t": 400, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # insertion
]
self.right_trajectory = [
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
{"t": 120, "xyz": peg_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # approach the cube
{"t": 170, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # go down
{"t": 220, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # close gripper
{"t": 285, "xyz": meet_xyz + np.array([0.1, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # approach meet position
{"t": 340, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
{"t": 400, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
]
def test_policy(task_name):
# example rolling out pick_and_transfer policy
onscreen_render = True
inject_noise = False
# setup the environment
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
if 'sim_transfer_cube' in task_name:
env = make_ee_sim_env('sim_transfer_cube')
elif 'sim_insertion' in task_name:
env = make_ee_sim_env('sim_insertion')
else:
raise NotImplementedError
for episode_idx in range(2):
ts = env.reset()
episode = [ts]
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images']['angle'])
plt.ion()
policy = PickAndTransferPolicy(inject_noise)
for step in range(episode_len):
action = policy(ts)
ts = env.step(action)
episode.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images']['angle'])
plt.pause(0.02)
plt.close()
episode_return = np.sum([ts.reward for ts in episode[1:]])
if episode_return > 0:
print(f"{episode_idx=} Successful, {episode_return=}")
else:
print(f"{episode_idx=} Failed")
if __name__ == '__main__':
test_task_name = 'sim_transfer_cube_scripted'
test_policy(test_task_name)

View File

@@ -0,0 +1,278 @@
import numpy as np
import os
import collections
import matplotlib.pyplot as plt
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from constants import DT, XML_DIR, START_ARM_POSE
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
import IPython
e = IPython.embed
BOX_POSE = [None] # to be changed from outside
def make_sim_env(task_name):
"""
Environment for simulated robot bi-manual manipulation, with joint position control
Action space: [left_arm_qpos (6), # absolute joint position
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
if 'sim_transfer_cube' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_transfer_cube.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = TransferCubeTask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
elif 'sim_insertion' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_insertion.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = InsertionTask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
else:
raise NotImplementedError
return env
class BimanualViperXTask(base.Task):
def __init__(self, random=None):
super().__init__(random=random)
def before_step(self, action, physics):
left_arm_action = action[:6]
right_arm_action = action[7:7+6]
normalized_left_gripper_action = action[6]
normalized_right_gripper_action = action[7+6]
left_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_left_gripper_action)
right_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_right_gripper_action)
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
env_action = np.concatenate([left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action])
super().before_step(env_action, physics)
return
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
super().initialize_episode(physics)
@staticmethod
def get_qpos(physics):
qpos_raw = physics.data.qpos.copy()
left_qpos_raw = qpos_raw[:8]
right_qpos_raw = qpos_raw[8:16]
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
@staticmethod
def get_qvel(physics):
qvel_raw = physics.data.qvel.copy()
left_qvel_raw = qvel_raw[:8]
right_qvel_raw = qvel_raw[8:16]
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
@staticmethod
def get_env_state(physics):
raise NotImplementedError
def get_observation(self, physics):
obs = collections.OrderedDict()
obs['qpos'] = self.get_qpos(physics)
obs['qvel'] = self.get_qvel(physics)
obs['env_state'] = self.get_env_state(physics)
obs['images'] = dict()
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
return obs
def get_reward(self, physics):
# return whether left gripper is holding the box
raise NotImplementedError
class TransferCubeTask(BimanualViperXTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
# reset qpos, control and box position
with physics.reset_context():
physics.named.data.qpos[:16] = START_ARM_POSE
np.copyto(physics.data.ctrl, START_ARM_POSE)
assert BOX_POSE[0] is not None
physics.named.data.qpos[-7:] = BOX_POSE[0]
# print(f"{BOX_POSE=}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether left gripper is holding the box
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_table = ("red_box", "table") in all_contact_pairs
reward = 0
if touch_right_gripper:
reward = 1
if touch_right_gripper and not touch_table: # lifted
reward = 2
if touch_left_gripper: # attempted transfer
reward = 3
if touch_left_gripper and not touch_table: # successful transfer
reward = 4
return reward
class InsertionTask(BimanualViperXTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
# reset qpos, control and box position
with physics.reset_context():
physics.named.data.qpos[:16] = START_ARM_POSE
np.copyto(physics.data.ctrl, START_ARM_POSE)
assert BOX_POSE[0] is not None
physics.named.data.qpos[-7*2:] = BOX_POSE[0] # two objects
# print(f"{BOX_POSE=}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether peg touches the pin
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
peg_touch_table = ("red_peg", "table") in all_contact_pairs
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
("socket-2", "table") in all_contact_pairs or \
("socket-3", "table") in all_contact_pairs or \
("socket-4", "table") in all_contact_pairs
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
("red_peg", "socket-2") in all_contact_pairs or \
("red_peg", "socket-3") in all_contact_pairs or \
("red_peg", "socket-4") in all_contact_pairs
pin_touched = ("red_peg", "pin") in all_contact_pairs
reward = 0
if touch_left_gripper and touch_right_gripper: # touch both
reward = 1
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
reward = 2
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
reward = 3
if pin_touched: # successful insertion
reward = 4
return reward
def get_action(master_bot_left, master_bot_right):
action = np.zeros(14)
# arm action
action[:6] = master_bot_left.dxl.joint_states.position[:6]
action[7:7+6] = master_bot_right.dxl.joint_states.position[:6]
# gripper action
left_gripper_pos = master_bot_left.dxl.joint_states.position[7]
right_gripper_pos = master_bot_right.dxl.joint_states.position[7]
normalized_left_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(left_gripper_pos)
normalized_right_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(right_gripper_pos)
action[6] = normalized_left_pos
action[7+6] = normalized_right_pos
return action
def test_sim_teleop():
""" Testing teleoperation in sim with ALOHA. Requires hardware and ALOHA repo to work. """
from interbotix_xs_modules.arm import InterbotixManipulatorXS
BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0]
# source of data
master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
robot_name=f'master_left', init_node=True)
master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
robot_name=f'master_right', init_node=False)
# setup the environment
env = make_sim_env('sim_transfer_cube')
ts = env.reset()
episode = [ts]
# setup plotting
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images']['angle'])
plt.ion()
for t in range(1000):
action = get_action(master_bot_left, master_bot_right)
ts = env.step(action)
episode.append(ts)
plt_img.set_data(ts.observation['images']['angle'])
plt.pause(0.02)
if __name__ == '__main__':
test_sim_teleop()

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,575 @@
import os
import time
import yaml
import torch
import pickle
import dm_env
import logging
import collections
import numpy as np
import tracemalloc
from einops import rearrange
import matplotlib.pyplot as plt
from torchvision import transforms
from shadow_rm_robot.realman_arm import RmArm
from shadow_camera.realsense import RealSenseCamera
from shadow_act.models.latent_model import Latent_Model_Transformer
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
from shadow_act.utils.utils import set_seed
# 配置logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# # 隐藏h5py的警告Creating converter from 7 to 5
# logging.getLogger("h5py").setLevel(logging.WARNING)
class RmActEvaluator:
def __init__(self, config, save_episode=True, num_rollouts=50):
"""
初始化Evaluator类
Args:
config (dict): 配置字典
checkpoint_name (str): 检查点名称
save_episode (bool): 是否保存每个episode
num_rollouts (int): 滚动次数
"""
self.config = config
self._seed = config["seed"]
self.robot_env = config["robot_env"]
self.checkpoint_dir = config["checkpoint_dir"]
self.checkpoint_name = config["checkpoint_name"]
self.save_episode = save_episode
self.num_rollouts = num_rollouts
self.state_dim = config["state_dim"]
self.real_robot = config["real_robot"]
self.policy_class = config["policy_class"]
self.onscreen_render = config["onscreen_render"]
self.camera_names = config["camera_names"]
self.max_timesteps = config["episode_len"]
self.task_name = config["task_name"]
self.temporal_agg = config["temporal_agg"]
self.onscreen_cam = "angle"
self.policy_config = config["policy_config"]
self.vq = config["policy_config"]["vq"]
# self.actuator_config = config["actuator_config"]
# self.use_actuator_net = self.actuator_config["actuator_network_dir"] is not None
self.stats = None
self.env = None
self.env_max_reward = 0
def _make_policy(self, policy_class, policy_config):
"""
根据策略类和配置创建策略对象
Args:
policy_class (str): 策略类名称
policy_config (dict): 策略配置字典
Returns:
policy: 创建的策略对象
"""
if policy_class == "ACT":
return ACTPolicy(policy_config)
elif policy_class == "CNNMLP":
return CNNMLPPolicy(policy_config)
elif policy_class == "Diffusion":
return DiffusionPolicy(policy_config)
else:
raise NotImplementedError(f"Policy class {policy_class} is not implemented")
def load_policy_and_stats(self):
"""
加载策略和统计数据
"""
checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
logging.info(f"Loading policy from: {checkpoint_path}")
self.policy = self._make_policy(self.policy_class, self.policy_config)
# 加载模型并设置为评估模式
self.policy.load_state_dict(torch.load(checkpoint_path, weights_only=True))
self.policy.cuda()
self.policy.eval()
if self.vq:
vq_dim = self.config["policy_config"]["vq_dim"]
vq_class = self.config["policy_config"]["vq_class"]
self.latent_model = Latent_Model_Transformer(vq_dim, vq_dim, vq_class)
latent_model_checkpoint_path = os.path.join(
self.checkpoint_dir, "latent_model_last.ckpt"
)
self.latent_model.deserialize(torch.load(latent_model_checkpoint_path))
self.latent_model.eval()
self.latent_model.cuda()
logging.info(
f"Loaded policy from: {checkpoint_path}, latent model from: {latent_model_checkpoint_path}"
)
else:
logging.info(f"Loaded: {checkpoint_path}")
stats_path = os.path.join(self.checkpoint_dir, "dataset_stats.pkl")
with open(stats_path, "rb") as f:
self.stats = pickle.load(f)
def pre_process(self, state_qpos):
"""
预处理状态位置
Args:
state_qpos (np.array): 状态位置数组
Returns:
np.array: 预处理后的状态位置
"""
if self.policy_class == "Diffusion":
return ((state_qpos + 1) / 2) * (
self.stats["action_max"] - self.stats["action_min"]
) + self.stats["action_min"]
# 标准化处理,均值为 0标准差为 1
return (state_qpos - self.stats["qpos_mean"]) / self.stats["qpos_std"]
def post_process(self, action):
"""
后处理动作
Args:
action (np.array): 动作数组
Returns:
np.array: 后处理后的动作
"""
# 反标准化处理
return action * self.stats["action_std"] + self.stats["action_mean"]
def get_image_torch(self, timestep, camera_names, random_crop_resize=False):
"""
获取图像
Args:
timestep (object): 时间步对象
camera_names (list): 相机名称列表
random_crop_resize (bool): 是否随机裁剪和调整大小
Returns:
torch.Tensor: 处理后的图像,归一化(num_cameras, channels, height, width)
"""
current_images = []
for cam_name in camera_names:
current_image = rearrange(
timestep.observation["images"][cam_name], "h w c -> c h w"
)
current_images.append(current_image)
current_image = np.stack(current_images, axis=0)
current_image = (
torch.from_numpy(current_image / 255.0).float().cuda().unsqueeze(0)
)
if random_crop_resize:
logging.info("Random crop resize is used!")
original_size = current_image.shape[-2:]
ratio = 0.95
current_image = current_image[
...,
int(original_size[0] * (1 - ratio) / 2) : int(
original_size[0] * (1 + ratio) / 2
),
int(original_size[1] * (1 - ratio) / 2) : int(
original_size[1] * (1 + ratio) / 2
),
]
current_image = current_image.squeeze(0)
resize_transform = transforms.Resize(original_size, antialias=True)
current_image = resize_transform(current_image)
current_image = current_image.unsqueeze(0)
return current_image
def load_environment(self):
"""
加载环境
"""
if self.real_robot:
self.env = DeviceAloha(self.robot_env)
self.env_max_reward = 0
else:
from sim_env import make_sim_env
self.env = make_sim_env(self.task_name)
self.env_max_reward = self.env.task.max_reward
def get_auto_index(self, checkpoint_dir):
max_idx = 1000
for i in range(max_idx + 1):
if not os.path.isfile(os.path.join(checkpoint_dir, f"qpos_{i}.npy")):
return i
raise Exception(f"Error getting auto index, or more than {max_idx} episodes")
def evaluate(self, checkpoint_name=None):
"""
评估策略
Returns:
tuple: 成功率和平均回报
"""
if checkpoint_name is not None:
self.checkpoint_name = checkpoint_name
set_seed(self._seed) # np与torch的随机种子
self.load_policy_and_stats()
self.load_environment()
query_frequency = self.policy_config["num_queries"]
# 时间聚合时每个时间步只有1个查询
if self.temporal_agg:
query_frequency = 1
num_queries = self.policy_config["num_queries"]
# # 真实机器人时基础延迟为13
# if self.real_robot:
# BASE_DELAY = 13
# # query_frequency -= BASE_DELAY
max_timesteps = int(self.max_timesteps * 1) # may increase for real-world tasks
episode_returns = []
highest_rewards = []
for rollout_id in range(self.num_rollouts):
timestep = self.env.reset()
if self.onscreen_render:
# TODO 画图
pass
if self.temporal_agg:
all_time_actions = torch.zeros(
[max_timesteps, max_timesteps + num_queries, self.state_dim]
).cuda()
qpos_history_raw = np.zeros((max_timesteps, self.state_dim))
rewards = []
with torch.inference_mode():
time_0 = time.time()
DT = 1 / 30
culmulated_delay = 0
for t in range(max_timesteps):
time_1 = time.time()
if self.onscreen_render:
# TODO 显示图像
pass
# process previous timestep to get qpos and image_list
obs = timestep.observation
qpos_numpy = np.array(obs["qpos"])
qpos_history_raw[t] = qpos_numpy
qpos = self.pre_process(qpos_numpy)
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
logging.info(f"t{t}")
if t % query_frequency == 0:
current_image = self.get_image_torch(
timestep,
self.camera_names,
random_crop_resize=(
self.config["policy_class"] == "Diffusion"
),
)
if t == 0:
# 网络预热
for _ in range(10):
self.policy(qpos, current_image)
logging.info("Network warm up done")
if self.config["policy_class"] == "ACT":
if t % query_frequency == 0:
if self.vq:
if rollout_id == 0:
for _ in range(10):
vq_sample = self.latent_model.generate(
1, temperature=1, x=None
)
logging.info(
torch.nonzero(vq_sample[0])[:, 1]
.cpu()
.numpy()
)
vq_sample = self.latent_model.generate(
1, temperature=1, x=None
)
all_actions = self.policy(
qpos, current_image, vq_sample=vq_sample
)
else:
all_actions = self.policy(qpos, current_image)
# if self.real_robot:
# all_actions = torch.cat(
# [
# all_actions[:, :-BASE_DELAY, :-2],
# all_actions[:, BASE_DELAY:, -2:],
# ],
# dim=2,
# )
if self.temporal_agg:
all_time_actions[[t], t : t + num_queries] = all_actions
actions_for_curr_step = all_time_actions[:, t]
actions_populated = torch.all(
actions_for_curr_step != 0, axis=1
)
actions_for_curr_step = actions_for_curr_step[
actions_populated
]
k = 0.01
exp_weights = np.exp(
-k * np.arange(len(actions_for_curr_step))
)
exp_weights = exp_weights / exp_weights.sum()
exp_weights = (
torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
)
raw_action = (actions_for_curr_step * exp_weights).sum(
dim=0, keepdim=True
)
else:
raw_action = all_actions[:, t % query_frequency]
elif self.config["policy_class"] == "Diffusion":
if t % query_frequency == 0:
all_actions = self.policy(qpos, current_image)
# if self.real_robot:
# all_actions = torch.cat(
# [
# all_actions[:, :-BASE_DELAY, :-2],
# all_actions[:, BASE_DELAY:, -2:],
# ],
# dim=2,
# )
raw_action = all_actions[:, t % query_frequency]
elif self.config["policy_class"] == "CNNMLP":
raw_action = self.policy(qpos, current_image)
all_actions = raw_action.unsqueeze(0)
else:
raise NotImplementedError
### post-process actions
raw_action = raw_action.squeeze(0).cpu().numpy()
action = self.post_process(raw_action)
### step the environment
if self.real_robot:
logging.info(f" action = {action}")
timestep = self.env.step(action)
rewards.append(timestep.reward)
duration = time.time() - time_1
sleep_time = max(0, DT - duration)
time.sleep(sleep_time)
if duration >= DT:
culmulated_delay += duration - DT
logging.warning(
f"Warning: step duration: {duration:.3f} s at step {t} longer than DT: {DT} s, culmulated delay: {culmulated_delay:.3f} s"
)
logging.info(f"Avg fps {max_timesteps / (time.time() - time_0)}")
plt.close()
if self.real_robot:
log_id = self.get_auto_index(self.checkpoint_dir)
np.save(
os.path.join(self.checkpoint_dir, f"qpos_{log_id}.npy"),
qpos_history_raw,
)
plt.figure(figsize=(10, 20))
for i in range(self.state_dim):
plt.subplot(self.state_dim, 1, i + 1)
plt.plot(qpos_history_raw[:, i])
if i != self.state_dim - 1:
plt.xticks([])
plt.tight_layout()
plt.savefig(os.path.join(self.checkpoint_dir, f"qpos_{log_id}.png"))
plt.close()
rewards = np.array(rewards)
episode_return = np.sum(rewards[rewards != None])
episode_returns.append(episode_return)
episode_highest_reward = np.max(rewards)
highest_rewards.append(episode_highest_reward)
logging.info(
f"Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {self.env_max_reward=}, Success: {episode_highest_reward == self.env_max_reward}"
)
success_rate = np.mean(np.array(highest_rewards) == self.env_max_reward)
avg_return = np.mean(episode_returns)
summary_str = (
f"\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n"
)
for r in range(self.env_max_reward + 1):
more_or_equal_r = (np.array(highest_rewards) >= r).sum()
more_or_equal_r_rate = more_or_equal_r / self.num_rollouts
summary_str += f"Reward >= {r}: {more_or_equal_r}/{self.num_rollouts} = {more_or_equal_r_rate * 100}%\n"
logging.info(summary_str)
result_file_name = "result_" + self.checkpoint_name.split(".")[0] + ".txt"
with open(os.path.join(self.checkpoint_dir, result_file_name), "w") as f:
f.write(summary_str)
f.write(repr(episode_returns))
f.write("\n\n")
f.write(repr(highest_rewards))
return success_rate, avg_return
class DeviceAloha:
def __init__(self, aloha_config):
"""
初始化设备
Args:
device_name (str): 设备名称
"""
config_left_arm = aloha_config["rm_left_arm"]
config_right_arm = aloha_config["rm_right_arm"]
config_head_camera = aloha_config["head_camera"]
config_bottom_camera = aloha_config["bottom_camera"]
config_left_camera = aloha_config["left_camera"]
config_right_camera = aloha_config["right_camera"]
self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
self.arm_axis = aloha_config["arm_axis"]
self.arm_left = RmArm(config_left_arm)
self.arm_right = RmArm(config_right_arm)
self.camera_left = RealSenseCamera(config_head_camera, False)
self.camera_right = RealSenseCamera(config_bottom_camera, False)
self.camera_bottom = RealSenseCamera(config_left_camera, False)
self.camera_top = RealSenseCamera(config_right_camera, False)
self.camera_left.start_camera()
self.camera_right.start_camera()
self.camera_bottom.start_camera()
self.camera_top.start_camera()
def close(self):
"""
关闭摄像头
"""
self.camera_left.close()
self.camera_right.close()
self.camera_bottom.close()
self.camera_top.close()
def get_qps(self):
"""
获取关节角度
Returns:
np.array: 关节角度
"""
left_slave_arm_angle = self.arm_left.get_joint_angle()
left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
right_slave_arm_angle = self.arm_right.get_joint_angle()
right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
return np.concatenate([left_joint_angles_array, right_joint_angles_array])
def get_qvel(self):
"""
获取关节速度
Returns:
np.array: 关节速度
"""
left_slave_arm_velocity = self.arm_left.get_joint_velocity()
left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
right_slave_arm_velocity = self.arm_right.get_joint_velocity()
right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
def get_effort(self):
"""
获取关节力
Returns:
np.array: 关节力
"""
left_slave_arm_effort = self.arm_left.get_joint_effort()
left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
right_slave_arm_effort = self.arm_right.get_joint_effort()
right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
return np.concatenate([left_joint_effort_array, right_joint_effort_array])
def get_images(self):
"""
获取图像
Returns:
dict: 图像字典
"""
self.top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
self.bottom_image, _, _, _ = self.camera_bottom.read_frame(
True, False, False, False
)
self.left_image, _, _, _ = self.camera_left.read_frame(
True, False, False, False
)
self.right_image, _, _, _ = self.camera_right.read_frame(
True, False, False, False
)
return {
"cam_high": self.top_image,
"cam_low": self.bottom_image,
"cam_left": self.left_image,
"cam_right": self.right_image,
}
def get_observation(self):
obs = collections.OrderedDict()
obs["qpos"] = self.get_qps()
obs["qvel"] = self.get_qvel()
obs["effort"] = self.get_effort()
obs["images"] = self.get_images()
return obs
def reset(self):
logging.info("Resetting the environment")
self.arm_left.set_joint_position(self.init_left_arm_angle[0:self.arm_axis])
self.arm_right.set_joint_position(self.init_right_arm_angle[0:self.arm_axis])
self.arm_left.set_gripper_position(0)
self.arm_right.set_gripper_position(0)
return dm_env.TimeStep(
step_type=dm_env.StepType.FIRST,
reward=0,
discount=None,
observation=self.get_observation(),
)
def step(self, target_angle):
self.arm_left.set_joint_canfd_position(target_angle[0:self.arm_axis])
self.arm_right.set_joint_canfd_position(target_angle[self.arm_axis+1:self.arm_axis*2+1])
self.arm_left.set_gripper_position(target_angle[self.arm_axis])
self.arm_right.set_gripper_position(target_angle[(self.arm_axis*2 + 1)])
return dm_env.TimeStep(
step_type=dm_env.StepType.MID,
reward=0,
discount=None,
observation=self.get_observation(),
)
if __name__ == "__main__":
# with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
# config = yaml.safe_load(f)
# aloha_config = config["robot_env"]
# device = DeviceAloha(aloha_config)
# device.reset()
# while True:
# init_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
# time_step = time.time()
# timestep = device.step(init_angle)
# logging.info(f"Time: {time.time() - time_step}")
# obs = timestep.observation
with open("/home/wang/project/shadow_rm_act/config/config.yaml", "r") as f:
config = yaml.safe_load(f)
# logging.info(f"Config: {config}")
evaluator = RmActEvaluator(config)
success_rate, avg_return = evaluator.evaluate()

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
import torch
import torchvision
from torch import nn
from typing import Dict, List
import torch.nn.functional as F
from .position_encoding import build_position_encoding
from torchvision.models import ResNet18_Weights
from torchvision.models._utils import IntermediateLayerGetter
from shadow_act.utils.misc import NestedTensor, is_main_process
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(
self,
backbone: nn.Module,
train_backbone: bool,
num_channels: int,
return_interm_layers: bool,
):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {"layer4": "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(
self,
name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool,
):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
weights=ResNet18_Weights.IMAGENET1K_V1 if is_main_process() else None,
norm_layer=FrozenBatchNorm2d,
)
# backbone = getattr(torchvision.models, name)(
# replace_stride_with_dilation=[False, False, dilation],
# pretrained=is_main_process(),
# norm_layer=FrozenBatchNorm2d,
# ) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(
hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
):
position_embedding = build_position_encoding(
hidden_dim=hidden_dim, position_embedding_type=position_embedding_type
)
train_backbone = lr_backbone > 0
return_interm_layers = masks
backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model

View File

@@ -0,0 +1,436 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from shadow_act.models.transformer import Transformer
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
import numpy as np
def reparametrize(mu, logvar):
std = logvar.div(2).exp()
eps = Variable(std.data.new(std.size()).normal_())
return mu + std * eps
def get_sinusoid_encoding_table(n_position, d_hid):
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class DETRVAE(nn.Module):
"""This is the DETR module that performs object detection"""
def __init__(
self,
backbones,
transformer,
encoder,
state_dim,
num_queries,
camera_names,
vq,
vq_class,
vq_dim,
action_dim,
):
"""Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
self.state_dim, self.action_dim = state_dim, action_dim
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(
backbones[0].num_channels, hidden_dim, kernel_size=1
)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
self.input_proj_env_state = nn.Linear(7, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(
action_dim, hidden_dim
) # project action to embedding
self.encoder_joint_proj = nn.Linear(
action_dim, hidden_dim
) # project qpos to embedding
if self.vq:
self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
else:
self.latent_proj = nn.Linear(
hidden_dim, self.latent_dim * 2
) # project hidden state to latent std, var
self.register_buffer(
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
) # [CLS], qpos, a_seq
# decoder extra parameters
if self.vq:
self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
else:
self.latent_out_proj = nn.Linear(
self.latent_dim, hidden_dim
) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(
2, hidden_dim
) # learned position embedding for proprio and latent
def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
bs, _ = qpos.shape
if self.encoder is None:
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
qpos.device
)
latent_input = self.latent_out_proj(latent_sample)
probs = binaries = mu = logvar = None
else:
# cvae encoder
is_training = actions is not None # train or val
### Obtain latent z from action sequence
if is_training:
# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(
actions
) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
bs, 1, 1
) # (bs, 1, hidden_dim)
encoder_input = torch.cat(
[cls_embed, qpos_embed, action_embed], axis=1
) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(
1, 0, 2
) # (seq+1, bs, hidden_dim)
# do not mask cls token
cls_joint_is_pad = torch.full((bs, 2), False).to(
qpos.device
) # False: not a padding
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(
encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
if self.vq:
logits = latent_info.reshape(
[*latent_info.shape[:-1], self.vq_class, self.vq_dim]
)
probs = torch.softmax(logits, dim=-1)
binaries = (
F.one_hot(
torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(
-1
),
self.vq_dim,
)
.view(-1, self.vq_class, self.vq_dim)
.float()
)
binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
straigt_through = binaries_flat - probs_flat.detach() + probs_flat
latent_input = self.latent_out_proj(straigt_through)
mu = logvar = None
else:
probs = binaries = None
mu = latent_info[:, : self.latent_dim]
logvar = latent_info[:, self.latent_dim :]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
else:
mu = logvar = binaries = probs = None
if self.vq:
latent_input = self.latent_out_proj(
vq_sample.view(-1, self.vq_class * self.vq_dim)
)
else:
latent_sample = torch.zeros(
[bs, self.latent_dim], dtype=torch.float32
).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
return latent_input, probs, binaries, mu, logvar
def forward(
self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None
):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
latent_input, probs, binaries, mu, logvar = self.encode(
qpos, actions, is_pad, vq_sample
)
# cvae decoder
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id, cam_name in enumerate(self.camera_names):
# TODO: fix this error
features, pos = self.backbones[0](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(
src,
None,
self.query_embed.weight,
pos,
latent_input,
proprio_input,
self.additional_pos_embed.weight,
)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
hs = self.transformer(
transformer_input, None, self.query_embed.weight, self.pos.weight
)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
return a_hat, is_pad_hat, [mu, logvar], probs, binaries
class CNNMLP(nn.Module):
def __init__(self, backbones, state_dim, camera_names):
"""Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.camera_names = camera_names
self.action_head = nn.Linear(1000, state_dim) # TODO add more
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
backbone_down_projs = []
for backbone in backbones:
down_proj = nn.Sequential(
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
nn.Conv2d(128, 64, kernel_size=5),
nn.Conv2d(64, 32, kernel_size=5),
)
backbone_down_projs.append(down_proj)
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
mlp_in_dim = 768 * len(backbones) + state_dim
self.mlp = mlp(
input_dim=mlp_in_dim,
hidden_dim=1024,
output_dim=self.action_dim,
hidden_depth=2,
)
else:
raise NotImplementedError
def forward(self, qpos, image, env_state, actions=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
bs, _ = qpos.shape
# Image observation features and position embeddings
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0] # not used
all_cam_features.append(self.backbone_down_projs[cam_id](features))
# flatten everything
flattened_features = []
for cam_feature in all_cam_features:
flattened_features.append(cam_feature.reshape([bs, -1]))
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
a_hat = self.mlp(features)
return a_hat
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
if hidden_depth == 0:
mods = [nn.Linear(input_dim, output_dim)]
else:
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
for i in range(hidden_depth - 1):
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
mods.append(nn.Linear(hidden_dim, output_dim))
trunk = nn.Sequential(*mods)
return trunk
def build_encoder(
hidden_dim, # 256
dropout, # 0.1
nheads, # 8
dim_feedforward,
num_encoder_layers, # 4 # TODO shared with VAE decoder
normalize_before, # False
):
activation = "relu"
encoder_layer = TransformerEncoderLayer(
hidden_dim, nheads, dim_feedforward, dropout, activation, normalize_before
)
encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
return encoder
def build_vae(
hidden_dim,
state_dim,
position_embedding_type,
lr_backbone,
masks,
backbone,
dilation,
dropout,
nheads,
dim_feedforward,
enc_layers,
dec_layers,
pre_norm,
num_queries,
camera_names,
vq,
vq_class,
vq_dim,
action_dim,
no_encoder,
):
# TODO hardcode
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
backbone = build_backbone(
hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
)
backbones.append(backbone)
transformer = build_transformer(
hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
)
if no_encoder:
encoder = None
else:
encoder = build_encoder(
hidden_dim,
dropout,
nheads,
dim_feedforward,
enc_layers,
pre_norm,
)
model = DETRVAE(
backbones,
transformer,
encoder,
state_dim,
num_queries,
camera_names,
vq,
vq_class,
vq_dim,
action_dim,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
return model
# TODO
def build_cnnmlp(args):
state_dim = 14 # TODO hardcode
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
model = CNNMLP(
backbones,
state_dim=state_dim,
camera_names=args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
return model

Some files were not shown because too many files have changed in this diff Show More