forked from tangger/lerobot
Compare commits
13 Commits
user/adil-
...
pre-commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d20ee7655 | ||
|
|
1537d0ab90 | ||
|
|
2be7f3a3ff | ||
|
|
0cf864870c | ||
|
|
1786916a16 | ||
|
|
0507ad4f68 | ||
|
|
bed90e3a41 | ||
|
|
6163daaaa4 | ||
|
|
8e2a394442 | ||
|
|
a445d9c9da | ||
|
|
f24030d4d8 | ||
|
|
7598aeaad7 | ||
|
|
4485cc0b5b |
24
.github/workflows/build-docker-images.yml
vendored
24
.github/workflows/build-docker-images.yml
vendored
@@ -40,24 +40,24 @@ jobs:
|
||||
git lfs install
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push CPU
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/lerobot-cpu/Dockerfile
|
||||
@@ -78,24 +78,24 @@ jobs:
|
||||
git lfs install
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/lerobot-gpu/Dockerfile
|
||||
@@ -110,23 +110,23 @@ jobs:
|
||||
group: aws-general-8-plus
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU dev
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/lerobot-gpu-dev/Dockerfile
|
||||
|
||||
4
.github/workflows/nightly-tests.yml
vendored
4
.github/workflows/nightly-tests.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
container:
|
||||
image: huggingface/lerobot-cpu:latest
|
||||
image: huggingface/lerobot-cpu:latest # zizmor: ignore[unpinned-images]
|
||||
options: --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
TEST_TYPE: "single_gpu"
|
||||
container:
|
||||
image: huggingface/lerobot-gpu:latest
|
||||
image: huggingface/lerobot-gpu:latest # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
|
||||
8
.github/workflows/quality.yml
vendored
8
.github/workflows/quality.yml
vendored
@@ -33,12 +33,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
@@ -64,9 +64,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.29.10
|
||||
uses: crate-ci/typos@db35ee91e80fbb447f33b0e5fbddb24d2a1a884f # v1.29.10
|
||||
|
||||
8
.github/workflows/test-docker-build.yml
vendored
8
.github/workflows/test-docker-build.yml
vendored
@@ -35,7 +35,7 @@ jobs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -64,17 +64,17 @@ jobs:
|
||||
docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
|
||||
with:
|
||||
file: ${{ matrix.docker-file }}
|
||||
context: .
|
||||
|
||||
12
.github/workflows/test.yml
vendored
12
.github/workflows/test.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
@@ -62,7 +62,7 @@ jobs:
|
||||
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
@@ -85,7 +85,7 @@ jobs:
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
run: sudo apt-get update && sudo apt-get install -y ffmpeg
|
||||
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
@@ -117,7 +117,7 @@ jobs:
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
@@ -129,7 +129,7 @@ jobs:
|
||||
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
|
||||
4
.github/workflows/trufflehog.yml
vendored
4
.github/workflows/trufflehog.yml
vendored
@@ -24,12 +24,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35
|
||||
with:
|
||||
extra_args: --only-verified
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -26,7 +26,6 @@ outputs
|
||||
|
||||
# VS Code
|
||||
.vscode
|
||||
.devcontainer
|
||||
|
||||
# HPC
|
||||
nautilus/*.yaml
|
||||
|
||||
@@ -37,19 +37,18 @@ repos:
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/adhtruong/mirrors-typos
|
||||
rev: v1.31.1
|
||||
rev: v1.32.0
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.19.1
|
||||
rev: v3.20.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
# Exclude generated protobuf files
|
||||
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.5
|
||||
rev: v0.11.12
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
@@ -58,12 +57,12 @@ repos:
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.3
|
||||
rev: v8.27.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.5.2
|
||||
rev: v1.9.0
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
|
||||
15
README.md
15
README.md
@@ -360,7 +360,7 @@ with profile(
|
||||
If you want, you can cite this work with:
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
@@ -408,19 +408,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
- [HIL-SERL](https://hil-serl.github.io/)
|
||||
```bibtex
|
||||
@Article{luo2024hilserl,
|
||||
title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
|
||||
author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine},
|
||||
year={2024},
|
||||
eprint={2410.21845},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO}
|
||||
}
|
||||
```
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
@@ -96,8 +96,8 @@ Reconnect the usb cable.
|
||||
#### Update config file
|
||||
|
||||
Now that you have your ports, update the **port** default values of [`SO101RobotConfig`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robot_devices/robots/configs.py).
|
||||
You will find something like, update the `port` values with your actual motor ports:
|
||||
```python
|
||||
You will find a class called `so101` where you can update the `port` values with your actual motor ports:
|
||||
```diff
|
||||
@RobotConfig.register_subclass("so101")
|
||||
@dataclass
|
||||
class So101RobotConfig(ManipulatorRobotConfig):
|
||||
@@ -110,7 +110,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
|
||||
- port="/dev/tty.usbmodem58760431091",
|
||||
+ port="{ADD YOUR LEADER PORT}",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
@@ -127,7 +128,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
- port="/dev/tty.usbmodem585A0076891",
|
||||
+ port="{ADD YOUR FOLLOWER PORT}",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
@@ -297,7 +299,7 @@ Remove all support material from the 3D-printed parts, the easiest way to do thi
|
||||
##### Wiring
|
||||
|
||||
- Attach the motor controller on the back.
|
||||
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themself and stay in place.
|
||||
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themselves and stay in place.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
|
||||
@@ -83,7 +83,7 @@ camera_01_frame_000047.png
|
||||
|
||||
Note: Some cameras may take a few seconds to warm up, and the first frame might be black or green.
|
||||
|
||||
Now that you have the camera indexes, you should specify the camera's in the config. TODO(pepijn): add more info about setting camera config, rotate etc..
|
||||
Now that you have the camera indexes, you should specify the camera's in the config.
|
||||
|
||||
### Use your phone
|
||||
<hfoptions id="use phone">
|
||||
@@ -152,7 +152,7 @@ If everything is set up correctly, you can proceed with the rest of the tutorial
|
||||
|
||||
## Teleoperate with cameras
|
||||
|
||||
We can now teleoperate again while at the same time visualizing the camera's and joint positions with `rerun`.
|
||||
We can now teleoperate again while at the same time visualizing the cameras and joint positions with `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
|
||||
@@ -55,7 +55,7 @@ conda install ffmpeg -c conda-forge
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
cd lerobot && pip install ".[feetech]"
|
||||
cd lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
@@ -128,7 +128,7 @@ sudo chmod 666 /dev/ttyACM1
|
||||
#### d. Update config file
|
||||
|
||||
IMPORTANTLY: Now that you have your ports, update the **port** default values of [`SO100RobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
|
||||
```python
|
||||
```diff
|
||||
@RobotConfig.register_subclass("so100")
|
||||
@dataclass
|
||||
class So100RobotConfig(ManipulatorRobotConfig):
|
||||
@@ -141,7 +141,8 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
|
||||
- port="/dev/tty.usbmodem58760431091",
|
||||
+ port="{ADD YOUR LEADER PORT}",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
@@ -158,7 +159,8 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
- port="/dev/tty.usbmodem585A0076891",
|
||||
+ port="{ADD YOUR FOLLOWER PORT}",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
|
||||
@@ -141,7 +141,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
Note: These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
|
||||
@@ -61,7 +61,7 @@ conda install ffmpeg -c conda-forge
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
cd lerobot && pip install ".[feetech]"
|
||||
cd lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
@@ -145,8 +145,8 @@ Reconnect the usb cable.
|
||||
#### Update config file
|
||||
|
||||
Now that you have your ports, update the **port** default values of [`SO101RobotConfig`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robot_devices/robots/configs.py).
|
||||
You will find something a class called `so101` where you can update the `port` values with your actual motor ports:
|
||||
```python
|
||||
You will find a class called `so101` where you can update the `port` values with your actual motor ports:
|
||||
```diff
|
||||
@RobotConfig.register_subclass("so101")
|
||||
@dataclass
|
||||
class So101RobotConfig(ManipulatorRobotConfig):
|
||||
@@ -159,7 +159,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
|
||||
- port="/dev/tty.usbmodem58760431091",
|
||||
+ port="{ADD YOUR LEADER PORT}",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
@@ -176,7 +177,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
- port="/dev/tty.usbmodem585A0076891",
|
||||
+ port="{ADD YOUR FOLLOWER PORT}",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
@@ -308,7 +310,7 @@ Remove all support material from the 3D-printed parts.
|
||||
##### Wiring
|
||||
|
||||
- Attach the motor controller on the back.
|
||||
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themself and stay in place.
|
||||
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themselves and stay in place.
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/4c2cacfd-9276-4ee4-8bf2-ba2492667b78" type="video/mp4"></video>
|
||||
|
||||
@@ -351,7 +353,7 @@ python lerobot/scripts/control_robot.py \
|
||||
```
|
||||
## Control your robot
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Next we will explain you how to train a neural network to autonomously control a real robot.
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Next we will explain to you how to train a neural network to autonomously control a real robot.
|
||||
|
||||
**You'll learn to:**
|
||||
1. How to record and visualize your dataset.
|
||||
@@ -515,7 +517,7 @@ If you have an additional camera you can add a wrist camera to the SO101. There
|
||||
|
||||
## Teleoperate with cameras
|
||||
|
||||
We can now teleoperate again while at the same time visualizing the camera's and joint positions with `rerun`.
|
||||
We can now teleoperate again while at the same time visualizing the cameras and joint positions with `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
@@ -526,7 +528,7 @@ python lerobot/scripts/control_robot.py \
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
|
||||
Once you're familiar with teleoperation, you can record your first dataset with SO-101.
|
||||
|
||||
We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens).
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
|
||||
class AsyncImageWriter:
|
||||
"""
|
||||
This class abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
save images on disk asynchronously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
When `num_processes=0`, it creates a threads pool of size `num_threads`.
|
||||
|
||||
@@ -944,7 +944,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def stop_image_writer(self) -> None:
|
||||
"""
|
||||
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
||||
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized.
|
||||
"""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.stop()
|
||||
|
||||
@@ -101,7 +101,7 @@ def decode_video_frames_torchvision(
|
||||
keyframes_only = False
|
||||
torchvision.set_video_backend(backend)
|
||||
if backend == "pyav":
|
||||
keyframes_only = True # pyav doesnt support accuracte seek
|
||||
keyframes_only = True # pyav doesn't support accurate seek
|
||||
|
||||
# set a video stream reader
|
||||
# TODO(rcadene): also load audio stream at the same time
|
||||
|
||||
@@ -14,12 +14,10 @@
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@@ -156,122 +154,3 @@ class XarmEnv(EnvConfig):
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecordConfig:
|
||||
"""Configuration for video recording in ManiSkill environments."""
|
||||
|
||||
enabled: bool = False
|
||||
record_dir: str = "videos"
|
||||
trajectory_name: str = "trajectory"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EEActionSpaceConfig:
|
||||
"""Configuration parameters for end-effector action space."""
|
||||
|
||||
x_step_size: float
|
||||
y_step_size: float
|
||||
z_step_size: float
|
||||
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
|
||||
control_mode: str = "gamepad"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvTransformConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
||||
display_cameras: bool = False
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
|
||||
resize_size: Optional[Tuple[int, int]] = None
|
||||
control_time_s: float = 20.0
|
||||
fixed_reset_joint_positions: Optional[Any] = None
|
||||
reset_time_s: float = 5.0
|
||||
use_gripper: bool = False
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@dataclass
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
"""Configuration for the HILSerlRobotEnv environment."""
|
||||
|
||||
robot: Optional[RobotConfig] = None
|
||||
wrapper: Optional[EnvTransformConfig] = None
|
||||
fps: int = 10
|
||||
name: str = "real_robot"
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
task: str = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("hil")
|
||||
@dataclass
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
type: str = "hil"
|
||||
name: str = "PandaPickCube"
|
||||
task: str = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
state_dim: int = 18
|
||||
action_dim: int = 4
|
||||
fps: int = 100
|
||||
episode_length: int = 100
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_ROBOT,
|
||||
}
|
||||
)
|
||||
################# args from hilserlrobotenv
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
robot: Optional[RobotConfig] = None
|
||||
wrapper: Optional[EnvTransformConfig] = None
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
############################
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"use_viewer": self.use_viewer,
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
elif env_type == "hil":
|
||||
return HILEnvConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
@@ -67,8 +65,5 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||
env = env_cls(
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
)
|
||||
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase
|
||||
# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/vector/vectorize_observation.py#L19,
|
||||
# env = ObservationProcessorWrapper(env=env)
|
||||
|
||||
return env
|
||||
|
||||
@@ -47,10 +47,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
|
||||
# This is the case for human-in-the-loop RL where there is only one environment.
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
@@ -66,18 +62,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
return_observations[imgkey] = img
|
||||
|
||||
if "environment_state" in observations:
|
||||
env_state = torch.from_numpy(observations["environment_state"]).float()
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
|
||||
return_observations["observation.environment_state"] = env_state
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
observations["environment_state"]
|
||||
).float()
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations["observation.state"] = agent_pos
|
||||
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
@@ -14,9 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
@@ -45,16 +44,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""
|
||||
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
|
||||
NOTE: Multiple optimizers are useful when you have different models to optimize.
|
||||
For example, you can have one optimizer for the policy and another one for the value function
|
||||
in reinforcement learning settings.
|
||||
|
||||
Returns:
|
||||
The optimizer or a dictionary of optimizers.
|
||||
"""
|
||||
def build(self) -> torch.optim.Optimizer:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -104,76 +94,7 @@ class SGDConfig(OptimizerConfig):
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("multi_adam")
|
||||
@dataclass
|
||||
class MultiAdamConfig(OptimizerConfig):
|
||||
"""Configuration for multiple Adam optimizers with different parameter groups.
|
||||
|
||||
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
|
||||
|
||||
Args:
|
||||
lr: Default learning rate (used if not specified for a group)
|
||||
weight_decay: Default weight decay (used if not specified for a group)
|
||||
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
|
||||
grad_clip_norm: Gradient clipping norm
|
||||
"""
|
||||
|
||||
lr: float = 1e-3
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Build multiple Adam optimizers.
|
||||
|
||||
Args:
|
||||
params_dict: Dictionary mapping parameter group names to lists of parameters
|
||||
The keys should match the keys in optimizer_groups
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter group names to their optimizers
|
||||
"""
|
||||
optimizers = {}
|
||||
|
||||
for name, params in params_dict.items():
|
||||
# Get group-specific hyperparameters or use defaults
|
||||
group_config = self.optimizer_groups.get(name, {})
|
||||
|
||||
# Create optimizer with merged parameters (defaults + group-specific)
|
||||
optimizer_kwargs = {
|
||||
"lr": group_config.get("lr", self.lr),
|
||||
"betas": group_config.get("betas", (0.9, 0.999)),
|
||||
"eps": group_config.get("eps", 1e-5),
|
||||
"weight_decay": group_config.get("weight_decay", self.weight_decay),
|
||||
}
|
||||
|
||||
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def save_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> None:
|
||||
"""Save optimizer state to disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to save the optimizer state.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
optimizer_dir.mkdir(exist_ok=True, parents=True)
|
||||
_save_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
# Handle single optimizer
|
||||
_save_single_optimizer_state(optimizer, save_dir)
|
||||
|
||||
|
||||
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
"""Save a single optimizer's state to disk."""
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
state = optimizer.state_dict()
|
||||
param_groups = state.pop("param_groups")
|
||||
flat_state = flatten_dict(state)
|
||||
@@ -181,44 +102,11 @@ def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""Load optimizer state from disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to load the optimizer state from.
|
||||
|
||||
Returns:
|
||||
The updated optimizer(s) with loaded state.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
loaded_optimizers = {}
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
if optimizer_dir.exists():
|
||||
loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
loaded_optimizers[name] = opt
|
||||
return loaded_optimizers
|
||||
else:
|
||||
# Handle single optimizer
|
||||
return _load_single_optimizer_state(optimizer, save_dir)
|
||||
|
||||
|
||||
def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
"""Load a single optimizer's state from disk."""
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
|
||||
# Handle case where 'state' key might not exist (for newly created optimizers)
|
||||
if "state" in state:
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
else:
|
||||
loaded_state_dict = {"state": {}}
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
|
||||
if "param_groups" in current_state_dict:
|
||||
param_groups = deserialize_json_into_object(
|
||||
|
||||
@@ -27,7 +27,6 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
@@ -60,14 +59,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy
|
||||
elif name == "reward_classifier":
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -85,8 +76,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -151,7 +151,6 @@ class Normalize(nn.Module):
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# TODO: Remove this shallow copy
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
@@ -253,168 +252,3 @@ class Unnormalize(nn.Module):
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
|
||||
module: nn.Module,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> None:
|
||||
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
||||
|
||||
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
|
||||
but is factored out so it can be reused by both classes and stay in sync.
|
||||
"""
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
shape: tuple[int, ...] = tuple(ft.shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# reduce spatial dimensions, keep channel dimension only
|
||||
c, *_ = shape
|
||||
shape = (c, 1, 1)
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
std = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||
mean_data = stats[key]["mean"]
|
||||
std_data = stats[key]["std"]
|
||||
if isinstance(mean_data, torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
mean = mean_data.clone().to(dtype=torch.float32)
|
||||
std = std_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_mean", mean)
|
||||
module.register_buffer(f"{prefix}_std", std)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||
min_data = stats[key]["min"]
|
||||
max_data = stats[key]["max"]
|
||||
if isinstance(min_data, torch.Tensor):
|
||||
min_val = min_data.clone().to(dtype=torch.float32)
|
||||
max_val = max_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_min", min_val)
|
||||
module.register_buffer(f"{prefix}_max", max_val)
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
|
||||
class NormalizeBuffer(nn.Module):
|
||||
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class UnnormalizeBuffer(nn.Module):
|
||||
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max_val - min_val) + min_val
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
@@ -357,7 +357,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
# Normalize from range [0,1] to [-1,1] as expected by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
|
||||
@@ -516,7 +516,7 @@ class PI0FAST(nn.Module):
|
||||
interpolate_like_pi=self.config.interpolate_like_pi,
|
||||
)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
# Normalize from range [0,1] to [-1,1] as expected by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||
@dataclass
|
||||
class RewardClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Reward Classifier model."""
|
||||
|
||||
name: str = "reward_classifier"
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
learning_rate: float = 1e-4
|
||||
weight_decay: float = 0.01
|
||||
grad_clip_norm: float = 1.0
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.learning_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
grad_clip_norm=self.grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate feature configurations."""
|
||||
has_image = any(key.startswith("observation.image") for key in self.input_features)
|
||||
if not has_image:
|
||||
raise ValueError(
|
||||
"You must provide an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
@@ -1,301 +0,0 @@
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_IMAGE
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logits: Tensor,
|
||||
probabilities: Optional[Tensor] = None,
|
||||
hidden_states: Optional[Tensor] = None,
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class SpatialLearnedEmbeddings(nn.Module):
|
||||
def __init__(self, height, width, channel, num_features=8):
|
||||
"""
|
||||
PyTorch implementation of learned spatial embeddings
|
||||
|
||||
Args:
|
||||
height: Spatial height of input features
|
||||
width: Spatial width of input features
|
||||
channel: Number of input channels
|
||||
num_features: Number of output embedding dimensions
|
||||
"""
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channel = channel
|
||||
self.num_features = num_features
|
||||
|
||||
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
|
||||
|
||||
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
|
||||
|
||||
def forward(self, features):
|
||||
"""
|
||||
Forward pass for spatial embedding
|
||||
|
||||
Args:
|
||||
features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch
|
||||
Returns:
|
||||
Output tensor of shape [B, C*F] or [C*F] if no batch
|
||||
"""
|
||||
|
||||
features = features.last_hidden_state
|
||||
|
||||
original_shape = features.shape
|
||||
if features.dim() == 3:
|
||||
features = features.unsqueeze(0) # Add batch dim
|
||||
|
||||
features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1]
|
||||
kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F]
|
||||
|
||||
# Element-wise multiplication and spatial reduction
|
||||
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W
|
||||
|
||||
# Reshape to combine channel and feature dimensions
|
||||
output = output.view(output.size(0), -1) # [B, C*F]
|
||||
|
||||
# Remove batch dim
|
||||
if len(original_shape) == 3:
|
||||
output = output.squeeze(0)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Classifier(PreTrainedPolicy):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "reward_classifier"
|
||||
config_class = RewardClassifierConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize normalization (standardized with the policy framework)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Set up encoder
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
|
||||
# Extract image keys from input_features
|
||||
self.image_keys = [
|
||||
key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE)
|
||||
]
|
||||
|
||||
if self.is_cnn:
|
||||
self.encoders = nn.ModuleDict()
|
||||
for image_key in self.image_keys:
|
||||
encoder = self._create_single_encoder()
|
||||
self.encoders[image_key] = encoder
|
||||
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _create_single_encoder(self):
|
||||
encoder = nn.Sequential(
|
||||
self.encoder,
|
||||
SpatialLearnedEmbeddings(
|
||||
height=4,
|
||||
width=4,
|
||||
channel=self.feature_dim,
|
||||
num_features=self.config.image_embedding_pooling_dim,
|
||||
),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
return encoder
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.config.latent_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(
|
||||
self.config.hidden_dim,
|
||||
1 if self.config.num_classes == 2 else self.config.num_classes,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoders[image_key](x)
|
||||
return outputs
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(x)
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
|
||||
"""Extract image tensors and label tensors from batch."""
|
||||
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
labels = batch["next.reward"]
|
||||
|
||||
return images, labels
|
||||
|
||||
def predict(self, xs: list) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier for inference."""
|
||||
encoder_outputs = torch.hstack(
|
||||
[self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)]
|
||||
)
|
||||
logits = self.classifier_head(encoder_outputs)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Normalize inputs if needed
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images and labels
|
||||
images, labels = self.extract_images_and_labels(batch)
|
||||
|
||||
# Get predictions
|
||||
outputs = self.predict(images)
|
||||
|
||||
# Calculate loss
|
||||
if self.config.num_classes == 2:
|
||||
# Binary classification
|
||||
loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
# Multi-class classification
|
||||
loss = nn.functional.cross_entropy(outputs.logits, labels.long())
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
|
||||
# Calculate accuracy for logging
|
||||
correct = (predictions == labels).sum().item()
|
||||
total = labels.size(0)
|
||||
accuracy = 100 * correct / total
|
||||
|
||||
# Return loss and metrics for logging
|
||||
output_dict = {
|
||||
"accuracy": accuracy,
|
||||
"correct": correct,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def predict_reward(self, batch, threshold=0.5):
|
||||
"""Eval method. Returns predicted reward with the decision threshold as argument."""
|
||||
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images from batch dict
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
probs = self.predict(images).probabilities
|
||||
logging.debug(f"Predicted reward images: {probs}")
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.predict(images).probabilities, dim=1)
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Return optimizer parameters for the policy."""
|
||||
return self.parameters()
|
||||
|
||||
def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
raise NotImplementedError("Reward classifiers do not select actions")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
pass
|
||||
@@ -1,243 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
def is_image_feature(key: str) -> bool:
|
||||
"""Check if a feature key represents an image feature.
|
||||
|
||||
Args:
|
||||
key: The feature key to check
|
||||
|
||||
Returns:
|
||||
True if the key represents an image feature, False otherwise
|
||||
"""
|
||||
return key.startswith("observation.image")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConcurrencyConfig:
|
||||
"""Configuration for the concurrency of the actor and learner.
|
||||
Possible values are:
|
||||
- "threads": Use threads for the actor and learner.
|
||||
- "processes": Use processes for the actor and learner.
|
||||
"""
|
||||
|
||||
actor: str = "threads"
|
||||
learner: str = "threads"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorLearnerConfig:
|
||||
learner_host: str = "127.0.0.1"
|
||||
learner_port: int = 50051
|
||||
policy_parameters_push_frequency: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class CriticNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
final_activation: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyConfig:
|
||||
use_tanh_squash: bool = True
|
||||
log_std_min: float = 1e-5
|
||||
log_std_max: float = 10.0
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
"""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# Statistics for normalizing different types of inputs
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture specifics
|
||||
# Device to run the model on (e.g., "cuda", "cpu")
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
# Hidden dimension size for the image encoder
|
||||
image_encoder_hidden_dim: int = 32
|
||||
# Whether to use a shared encoder for actor and critic
|
||||
shared_encoder: bool = True
|
||||
# Number of discrete actions, eg for gripper actions
|
||||
num_discrete_actions: int | None = None
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Seed for the online environment
|
||||
online_env_seed: int = 10000
|
||||
# Capacity of the online replay buffer
|
||||
online_buffer_capacity: int = 100000
|
||||
# Capacity of the offline replay buffer
|
||||
offline_buffer_capacity: int = 100000
|
||||
# Whether to use asynchronous prefetching for the buffers
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
# Discount factor for the SAC algorithm
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for actor-learner architecture
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
},
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
has_image = any(is_image_feature(key) for key in self.input_features)
|
||||
has_state = "observation.state" in self.input_features
|
||||
|
||||
if not (has_state or has_image):
|
||||
raise ValueError(
|
||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def image_features(self) -> list[str]:
|
||||
return [key for key in self.input_features if is_image_feature(key)]
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return None # SAC typically predicts one action at a time
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -87,8 +87,6 @@ class RecordControlConfig(ControlConfig):
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Reset follower arms to an initial position.
|
||||
reset_follower_arms: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
|
||||
@@ -24,7 +24,6 @@ from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
@@ -130,11 +129,9 @@ def predict_action(observation, policy, device, use_amp):
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""
|
||||
Initializes a keyboard listener to enable early termination of an episode
|
||||
or environment reset by pressing the right arrow key ('->'). This may require
|
||||
sudo permissions to allow the terminal to monitor keyboard events.
|
||||
"""
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
@@ -163,7 +160,6 @@ def init_keyboard_listener():
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
@@ -247,6 +243,11 @@ def control_loop(
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
# Controls starts, if policy is given it needs cleaning up
|
||||
if policy is not None:
|
||||
policy.reset()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
@@ -254,13 +255,11 @@ def control_loop(
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
action = None
|
||||
|
||||
if policy is not None:
|
||||
pred_action = predict_action(
|
||||
observation,
|
||||
policy,
|
||||
get_safe_torch_device(policy.config.device),
|
||||
policy.config.use_amp,
|
||||
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
)
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
@@ -273,9 +272,10 @@ def control_loop(
|
||||
|
||||
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
|
||||
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
|
||||
for k, v in action.items():
|
||||
for i, vv in enumerate(v):
|
||||
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
|
||||
if action is not None:
|
||||
for k, v in action.items():
|
||||
for i, vv in enumerate(v):
|
||||
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
|
||||
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
@@ -308,17 +308,7 @@ def reset_environment(robot, events, reset_time_s, fps):
|
||||
)
|
||||
|
||||
|
||||
def reset_follower_position(robot_arm, target_position):
|
||||
current_position = robot_arm.read("Present_Position")
|
||||
trajectory = torch.from_numpy(
|
||||
np.linspace(current_position, target_position, 50)
|
||||
) # NOTE: 30 is just an arbitrary number
|
||||
for pose in trajectory:
|
||||
robot_arm.write("Goal_Position", pose)
|
||||
busy_wait(0.015)
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
def stop_recording(robot, listener, display_data):
|
||||
robot.disconnect()
|
||||
|
||||
if not is_headless() and listener is not None:
|
||||
@@ -344,20 +334,12 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset,
|
||||
robot: Robot,
|
||||
fps: int,
|
||||
use_videos: bool,
|
||||
extra_features: dict = None,
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
|
||||
) -> None:
|
||||
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||
if extra_features is not None:
|
||||
features_from_robot.update(extra_features)
|
||||
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
("features", dataset.features, features_from_robot),
|
||||
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
||||
]
|
||||
|
||||
mismatches = []
|
||||
|
||||
@@ -18,11 +18,9 @@ import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import subprocess
|
||||
import time
|
||||
from copy import copy, deepcopy
|
||||
from copy import copy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -109,17 +107,11 @@ def is_amp_available(device: str):
|
||||
raise ValueError(f"Unknown device '{device}.")
|
||||
|
||||
|
||||
def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
||||
def init_logging():
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
|
||||
# NOTE: Display PID is useful for multi-process logging.
|
||||
if display_pid:
|
||||
pid_str = f"[PID: {os.getpid()}]"
|
||||
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
else:
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
return message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -133,12 +125,6 @@ def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
||||
console_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
if log_file is not None:
|
||||
# Additionally write logs to file
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(file_handler)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
@@ -242,114 +228,3 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
||||
except TypeError:
|
||||
# If a TypeError is raised, the string is not a valid dtype
|
||||
return False
|
||||
|
||||
|
||||
class TimerManager:
|
||||
"""
|
||||
Lightweight utility to measure elapsed time.
|
||||
|
||||
Examples
|
||||
--------
|
||||
```python
|
||||
# Example 1: Using context manager
|
||||
timer = TimerManager("Policy", log=False)
|
||||
for _ in range(3):
|
||||
with timer:
|
||||
time.sleep(0.01)
|
||||
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||
```
|
||||
|
||||
```python
|
||||
# Example 2: Using start/stop methods
|
||||
timer = TimerManager("Policy", log=False)
|
||||
timer.start()
|
||||
time.sleep(0.01)
|
||||
timer.stop()
|
||||
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: str = "Elapsed-time",
|
||||
log: bool = True,
|
||||
logger: logging.Logger | None = None,
|
||||
):
|
||||
self.label = label
|
||||
self.log = log
|
||||
self.logger = logger
|
||||
self._start: float | None = None
|
||||
self._history: list[float] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self.start()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
self._start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def stop(self) -> float:
|
||||
if self._start is None:
|
||||
raise RuntimeError("Timer was never started.")
|
||||
elapsed = time.perf_counter() - self._start
|
||||
self._history.append(elapsed)
|
||||
self._start = None
|
||||
if self.log:
|
||||
if self.logger is not None:
|
||||
self.logger.info(f"{self.label}: {elapsed:.6f} s")
|
||||
else:
|
||||
logging.info(f"{self.label}: {elapsed:.6f} s")
|
||||
return elapsed
|
||||
|
||||
def reset(self):
|
||||
self._history.clear()
|
||||
|
||||
@property
|
||||
def last(self) -> float:
|
||||
return self._history[-1] if self._history else 0.0
|
||||
|
||||
@property
|
||||
def avg(self) -> float:
|
||||
return mean(self._history) if self._history else 0.0
|
||||
|
||||
@property
|
||||
def total(self) -> float:
|
||||
return sum(self._history)
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self._history)
|
||||
|
||||
@property
|
||||
def history(self) -> list[float]:
|
||||
return deepcopy(self._history)
|
||||
|
||||
@property
|
||||
def fps_history(self) -> list[float]:
|
||||
return [1.0 / t for t in self._history]
|
||||
|
||||
@property
|
||||
def fps_last(self) -> float:
|
||||
return 0.0 if self.last == 0 else 1.0 / self.last
|
||||
|
||||
@property
|
||||
def fps_avg(self) -> float:
|
||||
return 0.0 if self.avg == 0 else 1.0 / self.avg
|
||||
|
||||
def percentile(self, p: float) -> float:
|
||||
"""
|
||||
Return the p-th percentile of recorded times.
|
||||
"""
|
||||
if not self._history:
|
||||
return 0.0
|
||||
return float(np.percentile(self._history, p))
|
||||
|
||||
def fps_percentile(self, p: float) -> float:
|
||||
"""
|
||||
FPS corresponding to the p-th percentile time.
|
||||
"""
|
||||
val = self.percentile(p)
|
||||
return 0.0 if val == 0 else 1.0 / val
|
||||
|
||||
@@ -30,10 +30,9 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"dataset:{cfg.dataset.repo_id}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.dataset is not None:
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
return lst if return_list else "-".join(lst)
|
||||
@@ -93,12 +92,6 @@ class WandBLogger:
|
||||
resume="must" if cfg.resume else None,
|
||||
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
||||
)
|
||||
run_id = wandb.run.id
|
||||
# NOTE: We will override the cfg.wandb.run_id with the wandb run id.
|
||||
# This is because we want to be able to resume the run from the wandb run id.
|
||||
cfg.wandb.run_id = run_id
|
||||
# Handle custom step key for rl asynchronous training.
|
||||
self._wandb_custom_step_key: set[str] | None = None
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
@@ -115,26 +108,9 @@ class WandBLogger:
|
||||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def log_dict(
|
||||
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
|
||||
):
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
if step is None and custom_step_key is None:
|
||||
raise ValueError("Either step or custom_step_key must be provided.")
|
||||
|
||||
# NOTE: This is not simple. Wandb step must always monotonically increase and it
|
||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||
# multiple time steps is possible. For example, the interaction step with the environment,
|
||||
# the training step, the evaluation step, etc. So we need to define a custom step key
|
||||
# to log the correct step for each metric.
|
||||
if custom_step_key is not None:
|
||||
if self._wandb_custom_step_key is None:
|
||||
self._wandb_custom_step_key = set()
|
||||
new_custom_key = f"{mode}/{custom_step_key}"
|
||||
if new_custom_key not in self._wandb_custom_step_key:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
@@ -142,18 +118,7 @@ class WandBLogger:
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
continue
|
||||
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||
self._wandb.log(data)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
|
||||
@@ -34,10 +34,11 @@ TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
||||
dataset: DatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
|
||||
@@ -106,7 +107,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||
self.output_dir = Path("outputs/train") / train_dir
|
||||
|
||||
if self.dataset is not None and isinstance(self.dataset.repo_id, list):
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
||||
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
|
||||
@@ -23,7 +23,6 @@ class FeatureType(str, Enum):
|
||||
VISUAL = "VISUAL"
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
|
||||
|
||||
class NormalizationMode(str, Enum):
|
||||
|
||||
@@ -1,722 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Actor server runner for distributed HILSerl robot policy training.
|
||||
|
||||
This script implements the actor component of the distributed HILSerl architecture.
|
||||
It executes the policy in the robot environment, collects experience,
|
||||
and sends transitions to the learner server for policy updates.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Start an actor server for real robot training with human-in-the-loop intervention:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
- Run with a specific robot type for a pick and place task:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--robot.type=so100 \
|
||||
--task=pick_and_place
|
||||
```
|
||||
|
||||
- Set a custom workspace bound for the robot's end-effector:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
|
||||
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
|
||||
```
|
||||
|
||||
- Run with specific camera crop parameters:
|
||||
```bash
|
||||
python lerobot/scripts/server/actor_server.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
|
||||
```
|
||||
|
||||
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
|
||||
server is started before launching the actor.
|
||||
|
||||
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
|
||||
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
|
||||
reduce interventions as the policy improves.
|
||||
|
||||
**WORKFLOW**:
|
||||
1. Determine robot workspace bounds using `find_joint_limits.py`
|
||||
2. Record demonstrations with `gym_manipulator.py` in record mode
|
||||
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
|
||||
4. Start the learner server with the training configuration
|
||||
5. Start this actor server with the same configuration
|
||||
6. Use human interventions to guide policy learning
|
||||
|
||||
For more details on the complete HILSerl training workflow, see:
|
||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
|
||||
from lerobot.scripts.server.buffer import Transition
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
from lerobot.scripts.server.network_utils import (
|
||||
bytes_to_state_dict,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.scripts.server.utils import (
|
||||
get_last_item_from_queue,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
setup_process_handlers,
|
||||
)
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def actor_cli(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
display_pid = False
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
display_pid = True
|
||||
|
||||
# Create logs directory to ensure it exists
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=display_pid)
|
||||
logging.info(f"Actor logging initialized, writing to {log_file}")
|
||||
|
||||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
logging.info("[ACTOR] Establishing connection with Learner")
|
||||
if not establish_learner_connection(learner_client, shutdown_event):
|
||||
logging.error("[ACTOR] Failed to establish connection with Learner")
|
||||
return
|
||||
|
||||
if not use_threads(cfg):
|
||||
# If we use multithreading, we can reuse the channel
|
||||
grpc_channel.close()
|
||||
grpc_channel = None
|
||||
|
||||
logging.info("[ACTOR] Connection with Learner established")
|
||||
|
||||
parameters_queue = Queue()
|
||||
transitions_queue = Queue()
|
||||
interactions_queue = Queue()
|
||||
|
||||
concurrency_entity = None
|
||||
if use_threads(cfg):
|
||||
from threading import Thread
|
||||
|
||||
concurrency_entity = Thread
|
||||
else:
|
||||
from multiprocessing import Process
|
||||
|
||||
concurrency_entity = Process
|
||||
|
||||
receive_policy_process = concurrency_entity(
|
||||
target=receive_policy,
|
||||
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process = concurrency_entity(
|
||||
target=send_transitions,
|
||||
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
interactions_process = concurrency_entity(
|
||||
target=send_interactions,
|
||||
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process.start()
|
||||
interactions_process.start()
|
||||
receive_policy_process.start()
|
||||
|
||||
act_with_policy(
|
||||
cfg=cfg,
|
||||
shutdown_event=shutdown_event,
|
||||
parameters_queue=parameters_queue,
|
||||
transitions_queue=transitions_queue,
|
||||
interactions_queue=interactions_queue,
|
||||
)
|
||||
logging.info("[ACTOR] Policy process joined")
|
||||
|
||||
logging.info("[ACTOR] Closing queues")
|
||||
transitions_queue.close()
|
||||
interactions_queue.close()
|
||||
parameters_queue.close()
|
||||
|
||||
transitions_process.join()
|
||||
logging.info("[ACTOR] Transitions process joined")
|
||||
interactions_process.join()
|
||||
logging.info("[ACTOR] Interactions process joined")
|
||||
receive_policy_process.join()
|
||||
logging.info("[ACTOR] Receive policy process joined")
|
||||
|
||||
logging.info("[ACTOR] join queues")
|
||||
transitions_queue.cancel_join_thread()
|
||||
interactions_queue.cancel_join_thread()
|
||||
parameters_queue.cancel_join_thread()
|
||||
|
||||
logging.info("[ACTOR] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def act_with_policy(
|
||||
cfg: TrainPipelineConfig,
|
||||
shutdown_event: any, # Event,
|
||||
parameters_queue: Queue,
|
||||
transitions_queue: Queue,
|
||||
interactions_queue: Queue,
|
||||
):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
|
||||
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
|
||||
|
||||
Args:
|
||||
cfg: Configuration settings for the interaction process.
|
||||
shutdown_event: Event to check if the process should shutdown.
|
||||
parameters_queue: Queue to receive updated network parameters from the learner.
|
||||
transitions_queue: Queue to send transitions to the learner.
|
||||
interactions_queue: Queue to send interactions to the learner.
|
||||
"""
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor policy process logging initialized")
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
list_transition_to_send_to_learner = []
|
||||
episode_intervention = False
|
||||
# Add counters for intervention rate calculation
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
policy_timer = TimerManager("Policy inference", log=False)
|
||||
|
||||
for interaction_step in range(cfg.policy.online_steps):
|
||||
start_time = time.perf_counter()
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
push_transitions_to_transport_queue(
|
||||
transitions=list_transition_to_send_to_learner,
|
||||
transitions_queue=transitions_queue,
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
stats = get_frequency_stats(policy_timer)
|
||||
policy_timer.reset()
|
||||
|
||||
# Calculate intervention rate
|
||||
intervention_rate = 0.0
|
||||
if episode_total_steps > 0:
|
||||
intervention_rate = episode_intervention_steps / episode_total_steps
|
||||
|
||||
# Send episodic reward to the learner
|
||||
interactions_queue.put(
|
||||
python_object_to_bytes(
|
||||
{
|
||||
"Episodic reward": sum_reward_episode,
|
||||
"Interaction step": interaction_step,
|
||||
"Episode intervention": int(episode_intervention),
|
||||
"Intervention rate": intervention_rate,
|
||||
**stats,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
obs, info = online_env.reset()
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
stub: hilserl_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Event, # type: ignore
|
||||
attempts: int = 30,
|
||||
):
|
||||
"""Establish a connection with the learner.
|
||||
|
||||
Args:
|
||||
stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
|
||||
shutdown_event (Event): The event to check if the connection should be established.
|
||||
attempts (int): The number of attempts to establish the connection.
|
||||
Returns:
|
||||
bool: True if the connection is established, False otherwise.
|
||||
"""
|
||||
for _ in range(attempts):
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down establish_learner_connection")
|
||||
return False
|
||||
|
||||
# Force a connection attempt and check state
|
||||
try:
|
||||
logging.info("[ACTOR] Send ready message to Learner")
|
||||
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
|
||||
return True
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
||||
time.sleep(2)
|
||||
return False
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def learner_service_client(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 50051,
|
||||
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
import json
|
||||
|
||||
"""
|
||||
Returns a client for the learner service.
|
||||
|
||||
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||
So we need to create only one client and reuse it.
|
||||
"""
|
||||
|
||||
service_config = {
|
||||
"methodConfig": [
|
||||
{
|
||||
"name": [{}], # Applies to ALL methods in ALL services
|
||||
"retryPolicy": {
|
||||
"maxAttempts": 5, # Max retries (total attempts = 5)
|
||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
||||
"maxBackoff": "2s", # Max wait time between retries
|
||||
"backoffMultiplier": 2, # Exponential backoff factor
|
||||
"retryableStatusCodes": [
|
||||
"UNAVAILABLE",
|
||||
"DEADLINE_EXCEEDED",
|
||||
], # Retries on network failures
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
service_config_json = json.dumps(service_config)
|
||||
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[
|
||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.enable_retries", 1),
|
||||
("grpc.service_config", service_config_json),
|
||||
],
|
||||
)
|
||||
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
|
||||
logging.info("[ACTOR] Learner service client created")
|
||||
return stub, channel
|
||||
|
||||
|
||||
def receive_policy(
|
||||
cfg: TrainPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
"""Receive parameters from the learner.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The configuration for the actor.
|
||||
parameters_queue (Queue): The queue to receive the parameters.
|
||||
shutdown_event (Event): The event to check if the process should shutdown.
|
||||
"""
|
||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor receive policy process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(use_threads=False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
|
||||
receive_bytes_in_chunks(
|
||||
iterator,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
log_prefix="[ACTOR] parameters",
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Received policy loop stopped")
|
||||
|
||||
|
||||
def send_transitions(
|
||||
cfg: TrainPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> hilserl_pb2.Empty:
|
||||
"""
|
||||
Sends transitions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Transition Data:
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor transitions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming transitions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Transitions process stopped")
|
||||
|
||||
|
||||
def send_interactions(
|
||||
cfg: TrainPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> hilserl_pb2.Empty:
|
||||
"""
|
||||
Sends interactions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Interaction Messages:
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor interactions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming interactions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=5)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Transition queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
||||
)
|
||||
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: Event, # type: ignore
|
||||
interactions_queue: Queue,
|
||||
) -> hilserl_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = interactions_queue.get(block=True, timeout=5)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Interaction queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message,
|
||||
hilserl_pb2.InteractionMessage,
|
||||
log_prefix="[ACTOR] Send interactions",
|
||||
)
|
||||
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
#################################################
|
||||
# Policy functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
if not parameters_queue.empty():
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue)
|
||||
state_dict = bytes_to_state_dict(bytes_state_dict)
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.load_state_dict(state_dict)
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
"""Send transitions to learner in smaller chunks to avoid network issues.
|
||||
|
||||
Args:
|
||||
transitions: List of transitions to send
|
||||
message_queue: Queue to send messages to learner
|
||||
chunk_size: Size of each chunk to send
|
||||
"""
|
||||
transition_to_send_to_learner = []
|
||||
for transition in transitions:
|
||||
tr = move_transition_to_device(transition=transition, device="cpu")
|
||||
for key, value in tr["state"].items():
|
||||
if torch.isnan(value).any():
|
||||
logging.warning(f"Found NaN values in transition {key}")
|
||||
|
||||
transition_to_send_to_learner.append(tr)
|
||||
|
||||
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
||||
|
||||
|
||||
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
|
||||
"""Get the frequency statistics of the policy.
|
||||
|
||||
Args:
|
||||
timer (TimerManager): The timer with collected metrics.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The frequency statistics of the policy.
|
||||
"""
|
||||
stats = {}
|
||||
if timer.count > 1:
|
||||
avg_fps = timer.fps_avg
|
||||
p90_fps = timer.fps_percentile(90)
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
|
||||
stats = {
|
||||
"Policy frequency [Hz]": avg_fps,
|
||||
"Policy frequency 90th-p [Hz]": p90_fps,
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
|
||||
if policy_fps < cfg.env.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
|
||||
)
|
||||
|
||||
|
||||
def use_threads(cfg: TrainPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.actor == "threads"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
actor_cli()
|
||||
@@ -1,820 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from contextlib import suppress
|
||||
from typing import Callable, Optional, Sequence, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.scripts.server.utils import Transition
|
||||
|
||||
|
||||
class BatchTransition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: torch.Tensor
|
||||
truncated: torch.Tensor
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||
|
||||
|
||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||
"""
|
||||
Perform a per-image random crop over a batch of images in a vectorized way.
|
||||
(Same as shown previously.)
|
||||
"""
|
||||
B, C, H, W = images.shape # noqa: N806
|
||||
crop_h, crop_w = output_size
|
||||
|
||||
if crop_h > H or crop_w > W:
|
||||
raise ValueError(
|
||||
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
|
||||
)
|
||||
|
||||
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
|
||||
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
|
||||
|
||||
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
|
||||
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
|
||||
|
||||
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
|
||||
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
|
||||
|
||||
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
|
||||
|
||||
# Gather pixels
|
||||
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
|
||||
# cropped_hwcn => (B, crop_h, crop_w, C)
|
||||
|
||||
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
|
||||
return cropped
|
||||
|
||||
|
||||
def random_shift(images: torch.Tensor, pad: int = 4):
|
||||
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
|
||||
_, _, h, w = images.shape
|
||||
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
|
||||
return random_crop_vectorized(images=images, output_size=(h, w))
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
optimize_memory: bool = False,
|
||||
):
|
||||
"""
|
||||
Replay buffer for storing transitions.
|
||||
It will allocate tensors on the specified device, when the first transition is added.
|
||||
NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or
|
||||
and use the `storage_device` flag to store the buffer on a different device.
|
||||
Args:
|
||||
capacity (int): Maximum number of transitions to store in the buffer.
|
||||
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
|
||||
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
|
||||
Using "cpu" can help save GPU memory.
|
||||
optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
|
||||
they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
|
||||
"""
|
||||
if capacity <= 0:
|
||||
raise ValueError("Capacity must be greater than 0.")
|
||||
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.storage_device = storage_device
|
||||
self.position = 0
|
||||
self.size = 0
|
||||
self.initialized = False
|
||||
self.optimize_memory = optimize_memory
|
||||
|
||||
# Track episode boundaries for memory optimization
|
||||
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
|
||||
|
||||
# If no state_keys provided, default to an empty list
|
||||
self.state_keys = state_keys if state_keys is not None else []
|
||||
|
||||
self.image_augmentation_function = image_augmentation_function
|
||||
|
||||
if image_augmentation_function is None:
|
||||
base_function = functools.partial(random_shift, pad=4)
|
||||
self.image_augmentation_function = torch.compile(base_function)
|
||||
self.use_drq = use_drq
|
||||
|
||||
def _initialize_storage(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Initialize the storage tensors based on the first transition."""
|
||||
# Determine shapes from the first transition
|
||||
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
|
||||
action_shape = action.squeeze(0).shape
|
||||
|
||||
# Pre-allocate tensors for storage
|
||||
self.states = {
|
||||
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||
for key, shape in state_shapes.items()
|
||||
}
|
||||
self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
|
||||
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Standard approach: store states and next_states separately
|
||||
self.next_states = {
|
||||
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||
for key, shape in state_shapes.items()
|
||||
}
|
||||
else:
|
||||
# Memory-optimized approach: don't allocate next_states buffer
|
||||
# Just create a reference to states for consistent API
|
||||
self.next_states = self.states # Just a reference for API consistency
|
||||
|
||||
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
|
||||
# Initialize storage for complementary_info
|
||||
self.has_complementary_info = complementary_info is not None
|
||||
self.complementary_info_keys = []
|
||||
self.complementary_info = {}
|
||||
|
||||
if self.has_complementary_info:
|
||||
self.complementary_info_keys = list(complementary_info.keys())
|
||||
# Pre-allocate tensors for each key in complementary_info
|
||||
for key, value in complementary_info.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
value_shape = value.squeeze(0).shape
|
||||
self.complementary_info[key] = torch.empty(
|
||||
(self.capacity, *value_shape), device=self.storage_device
|
||||
)
|
||||
elif isinstance(value, (int, float)):
|
||||
# Handle scalar values similar to reward
|
||||
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def add(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
reward: float,
|
||||
next_state: dict[str, torch.Tensor],
|
||||
done: bool,
|
||||
truncated: bool,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||
# Initialize storage if this is the first transition
|
||||
if not self.initialized:
|
||||
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
|
||||
|
||||
# Store the transition in pre-allocated tensors
|
||||
for key in self.states:
|
||||
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Only store next_states if not optimizing memory
|
||||
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
|
||||
|
||||
self.actions[self.position].copy_(action.squeeze(dim=0))
|
||||
self.rewards[self.position] = reward
|
||||
self.dones[self.position] = done
|
||||
self.truncateds[self.position] = truncated
|
||||
|
||||
# Handle complementary_info if provided and storage is initialized
|
||||
if complementary_info is not None and self.has_complementary_info:
|
||||
# Store the complementary_info
|
||||
for key in self.complementary_info_keys:
|
||||
if key in complementary_info:
|
||||
value = complementary_info[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||
elif isinstance(value, (int, float)):
|
||||
self.complementary_info[key][self.position] = value
|
||||
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
if not self.initialized:
|
||||
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
||||
|
||||
batch_size = min(batch_size, self.size)
|
||||
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
||||
|
||||
# Random indices for sampling - create on the same device as storage
|
||||
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||
|
||||
# Identify image keys that need augmentation
|
||||
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
|
||||
|
||||
# Create batched state and next_state
|
||||
batch_state = {}
|
||||
batch_next_state = {}
|
||||
|
||||
# First pass: load all state tensors to target device
|
||||
for key in self.states:
|
||||
batch_state[key] = self.states[key][idx].to(self.device)
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Standard approach - load next_states directly
|
||||
batch_next_state[key] = self.next_states[key][idx].to(self.device)
|
||||
else:
|
||||
# Memory-optimized approach - get next_state from the next index
|
||||
next_idx = (idx + 1) % self.capacity
|
||||
batch_next_state[key] = self.states[key][next_idx].to(self.device)
|
||||
|
||||
# Apply image augmentation in a batched way if needed
|
||||
if self.use_drq and image_keys:
|
||||
# Concatenate all images from state and next_state
|
||||
all_images = []
|
||||
for key in image_keys:
|
||||
all_images.append(batch_state[key])
|
||||
all_images.append(batch_next_state[key])
|
||||
|
||||
# Optimization: Batch all images and apply augmentation once
|
||||
all_images_tensor = torch.cat(all_images, dim=0)
|
||||
augmented_images = self.image_augmentation_function(all_images_tensor)
|
||||
|
||||
# Split the augmented images back to their sources
|
||||
for i, key in enumerate(image_keys):
|
||||
# Calculate offsets for the current image key:
|
||||
# For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
|
||||
# States start at index i*2*batch_size and take up batch_size slots
|
||||
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
|
||||
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
|
||||
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
||||
|
||||
# Sample other tensors
|
||||
batch_actions = self.actions[idx].to(self.device)
|
||||
batch_rewards = self.rewards[idx].to(self.device)
|
||||
batch_dones = self.dones[idx].to(self.device).float()
|
||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||
|
||||
# Sample complementary_info if available
|
||||
batch_complementary_info = None
|
||||
if self.has_complementary_info:
|
||||
batch_complementary_info = {}
|
||||
for key in self.complementary_info_keys:
|
||||
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
|
||||
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
action=batch_actions,
|
||||
reward=batch_rewards,
|
||||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
truncated=batch_truncateds,
|
||||
complementary_info=batch_complementary_info,
|
||||
)
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""
|
||||
Creates an infinite iterator that yields batches of transitions.
|
||||
Will automatically restart when internal iterator is exhausted.
|
||||
|
||||
Args:
|
||||
batch_size (int): Size of batches to sample
|
||||
async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
|
||||
queue_size (int): Number of batches to prefetch (default: 2)
|
||||
|
||||
Yields:
|
||||
BatchTransition: Batched transitions
|
||||
"""
|
||||
while True: # Create an infinite loop
|
||||
if async_prefetch:
|
||||
# Get the standard iterator
|
||||
iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
|
||||
else:
|
||||
iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
|
||||
|
||||
# Yield all items from the iterator
|
||||
with suppress(StopIteration):
|
||||
yield from iterator
|
||||
|
||||
def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
|
||||
"""
|
||||
Creates an iterator that prefetches batches in a background thread.
|
||||
|
||||
Args:
|
||||
queue_size (int): Number of batches to prefetch (default: 2)
|
||||
batch_size (int): Size of batches to sample (default: 128)
|
||||
|
||||
Yields:
|
||||
BatchTransition: Prefetched batch transitions
|
||||
"""
|
||||
import queue
|
||||
import threading
|
||||
|
||||
# Use thread-safe queue
|
||||
data_queue = queue.Queue(maxsize=queue_size)
|
||||
running = [True] # Use list to allow modification in nested function
|
||||
|
||||
def prefetch_worker():
|
||||
while running[0]:
|
||||
try:
|
||||
# Sample data and add to queue
|
||||
data = self.sample(batch_size)
|
||||
data_queue.put(data, block=True, timeout=0.5)
|
||||
except queue.Full:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"Prefetch error: {e}")
|
||||
break
|
||||
|
||||
# Start prefetching thread
|
||||
thread = threading.Thread(target=prefetch_worker, daemon=True)
|
||||
thread.start()
|
||||
|
||||
try:
|
||||
while running[0]:
|
||||
try:
|
||||
yield data_queue.get(block=True, timeout=0.5)
|
||||
except queue.Empty:
|
||||
if not thread.is_alive():
|
||||
break
|
||||
finally:
|
||||
# Clean up
|
||||
running[0] = False
|
||||
thread.join(timeout=1.0)
|
||||
|
||||
def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
|
||||
"""
|
||||
Creates a simple non-threaded iterator that yields batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): Size of batches to sample
|
||||
queue_size (int): Number of initial batches to prefetch
|
||||
|
||||
Yields:
|
||||
BatchTransition: Batch transitions
|
||||
"""
|
||||
import collections
|
||||
|
||||
queue = collections.deque()
|
||||
|
||||
def enqueue(n):
|
||||
for _ in range(n):
|
||||
data = self.sample(batch_size)
|
||||
queue.append(data)
|
||||
|
||||
enqueue(queue_size)
|
||||
while queue:
|
||||
yield queue.popleft()
|
||||
enqueue(1)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
lerobot_dataset: LeRobotDataset,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
capacity: Optional[int] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
optimize_memory: bool = False,
|
||||
) -> "ReplayBuffer":
|
||||
"""
|
||||
Convert a LeRobotDataset into a ReplayBuffer.
|
||||
|
||||
Args:
|
||||
lerobot_dataset (LeRobotDataset): The dataset to convert.
|
||||
device (str): The device for sampling tensors. Defaults to "cuda:0".
|
||||
state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
|
||||
capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
|
||||
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
|
||||
image_augmentation_function (Optional[Callable]): Function for image augmentation.
|
||||
If None, uses default random shift with pad=4.
|
||||
use_drq (bool): Whether to use DrQ image augmentation when sampling.
|
||||
storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
|
||||
optimize_memory (bool): If True, reduces memory usage by not duplicating state data.
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: The replay buffer with dataset transitions.
|
||||
"""
|
||||
if capacity is None:
|
||||
capacity = len(lerobot_dataset)
|
||||
|
||||
if capacity < len(lerobot_dataset):
|
||||
raise ValueError(
|
||||
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
|
||||
)
|
||||
|
||||
# Create replay buffer with image augmentation and DrQ settings
|
||||
replay_buffer = cls(
|
||||
capacity=capacity,
|
||||
device=device,
|
||||
state_keys=state_keys,
|
||||
image_augmentation_function=image_augmentation_function,
|
||||
use_drq=use_drq,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=optimize_memory,
|
||||
)
|
||||
|
||||
# Convert dataset to transitions
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
|
||||
# Initialize the buffer with the first transition to set up storage tensors
|
||||
if list_transition:
|
||||
first_transition = list_transition[0]
|
||||
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
||||
first_action = first_transition["action"].to(device)
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
if (
|
||||
"complementary_info" in first_transition
|
||||
and first_transition["complementary_info"] is not None
|
||||
):
|
||||
first_complementary_info = {
|
||||
k: v.to(device) for k, v in first_transition["complementary_info"].items()
|
||||
}
|
||||
|
||||
replay_buffer._initialize_storage(
|
||||
state=first_state, action=first_action, complementary_info=first_complementary_info
|
||||
)
|
||||
|
||||
# Fill the buffer with all transitions
|
||||
for data in list_transition:
|
||||
for k, v in data.items():
|
||||
if isinstance(v, dict):
|
||||
for key, tensor in v.items():
|
||||
v[key] = tensor.to(storage_device)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(storage_device)
|
||||
|
||||
action = data["action"]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=action,
|
||||
reward=data["reward"],
|
||||
next_state=data["next_state"],
|
||||
done=data["done"],
|
||||
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
|
||||
complementary_info=data.get("complementary_info", None),
|
||||
)
|
||||
|
||||
return replay_buffer
|
||||
|
||||
def to_lerobot_dataset(
|
||||
self,
|
||||
repo_id: str,
|
||||
fps=1,
|
||||
root=None,
|
||||
task_name="from_replay_buffer",
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts all transitions in this ReplayBuffer into a single LeRobotDataset object.
|
||||
"""
|
||||
if self.size == 0:
|
||||
raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
|
||||
|
||||
# Create features dictionary for the dataset
|
||||
features = {
|
||||
"index": {"dtype": "int64", "shape": [1]}, # global index across episodes
|
||||
"episode_index": {"dtype": "int64", "shape": [1]}, # which episode
|
||||
"frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode
|
||||
"timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy
|
||||
"task_index": {"dtype": "int64", "shape": [1]},
|
||||
}
|
||||
|
||||
# Add "action"
|
||||
sample_action = self.actions[0]
|
||||
act_info = guess_feature_info(t=sample_action, name="action")
|
||||
features["action"] = act_info
|
||||
|
||||
# Add "reward" and "done"
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,)}
|
||||
|
||||
# Add state keys
|
||||
for key in self.states:
|
||||
sample_val = self.states[key][0]
|
||||
f_info = guess_feature_info(t=sample_val, name=key)
|
||||
features[key] = f_info
|
||||
|
||||
# Add complementary_info keys if available
|
||||
if self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
sample_val = self.complementary_info[key][0]
|
||||
if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
|
||||
sample_val = sample_val.unsqueeze(0)
|
||||
f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
|
||||
features[f"complementary_info.{key}"] = f_info
|
||||
|
||||
# Create an empty LeRobotDataset
|
||||
lerobot_dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
root=root,
|
||||
robot=None, # TODO: (azouitine) Handle robot
|
||||
robot_type=None,
|
||||
features=features,
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
# Start writing images if needed
|
||||
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
|
||||
|
||||
# Convert transitions into episodes and frames
|
||||
episode_index = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
|
||||
|
||||
frame_idx_in_episode = 0
|
||||
for idx in range(self.size):
|
||||
actual_idx = (self.position - self.size + idx) % self.capacity
|
||||
|
||||
frame_dict = {}
|
||||
|
||||
# Fill the data for state keys
|
||||
for key in self.states:
|
||||
frame_dict[key] = self.states[key][actual_idx].cpu()
|
||||
|
||||
# Fill action, reward, done
|
||||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
|
||||
# Add complementary_info if available
|
||||
if self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
val = self.complementary_info[key][actual_idx]
|
||||
# Convert tensors to CPU
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.ndim == 0:
|
||||
val = val.unsqueeze(0)
|
||||
frame_dict[f"complementary_info.{key}"] = val.cpu()
|
||||
# Non-tensor values can be used directly
|
||||
else:
|
||||
frame_dict[f"complementary_info.{key}"] = val
|
||||
|
||||
# Add task field which is required by LeRobotDataset
|
||||
frame_dict["task"] = task_name
|
||||
|
||||
# Add to the dataset's buffer
|
||||
lerobot_dataset.add_frame(frame_dict)
|
||||
|
||||
# Move to next frame
|
||||
frame_idx_in_episode += 1
|
||||
|
||||
# If we reached an episode boundary, call save_episode, reset counters
|
||||
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
||||
lerobot_dataset.save_episode()
|
||||
episode_index += 1
|
||||
frame_idx_in_episode = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
|
||||
episode_index=episode_index
|
||||
)
|
||||
|
||||
# Save any remaining frames in the buffer
|
||||
if lerobot_dataset.episode_buffer["size"] > 0:
|
||||
lerobot_dataset.save_episode()
|
||||
|
||||
lerobot_dataset.stop_image_writer()
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
@staticmethod
|
||||
def _lerobotdataset_to_transitions(
|
||||
dataset: LeRobotDataset,
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
) -> list[Transition]:
|
||||
"""
|
||||
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
|
||||
|
||||
Args:
|
||||
dataset (LeRobotDataset):
|
||||
The dataset to convert. Each item in the dataset is expected to have
|
||||
at least the following keys:
|
||||
{
|
||||
"action": ...
|
||||
"next.reward": ...
|
||||
"next.done": ...
|
||||
"episode_index": ...
|
||||
}
|
||||
plus whatever your 'state_keys' specify.
|
||||
|
||||
state_keys (Optional[Sequence[str]]):
|
||||
The dataset keys to include in 'state' and 'next_state'. Their names
|
||||
will be kept as-is in the output transitions. E.g.
|
||||
["observation.state", "observation.environment_state"].
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
"""
|
||||
if state_keys is None:
|
||||
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
|
||||
|
||||
transitions = []
|
||||
num_frames = len(dataset)
|
||||
|
||||
# Check if the dataset has "next.done" key
|
||||
sample = dataset[0]
|
||||
has_done_key = "next.done" in sample
|
||||
|
||||
# Check for complementary_info keys
|
||||
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
|
||||
has_complementary_info = len(complementary_info_keys) > 0
|
||||
|
||||
# If not, we need to infer it from episode boundaries
|
||||
if not has_done_key:
|
||||
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
|
||||
|
||||
for i in tqdm(range(num_frames)):
|
||||
current_sample = dataset[i]
|
||||
|
||||
# ----- 1) Current state -----
|
||||
current_state: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = current_sample[key]
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
|
||||
# Determine done flag - use next.done if available, otherwise infer from episode boundaries
|
||||
if has_done_key:
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
else:
|
||||
# If this is the last frame or if next frame is in a different episode, mark as done
|
||||
done = False
|
||||
if i == num_frames - 1:
|
||||
done = True
|
||||
elif i < num_frames - 1:
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] != current_sample["episode_index"]:
|
||||
done = True
|
||||
|
||||
# TODO: (azouitine) Handle truncation (using the same value as done for now)
|
||||
truncated = done
|
||||
|
||||
# ----- 4) Next state -----
|
||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
|
||||
next_state = current_state # default
|
||||
if not done and (i < num_frames - 1):
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] == current_sample["episode_index"]:
|
||||
# Build next_state from the same keys
|
||||
next_state_data: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = next_sample[key]
|
||||
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
|
||||
next_state = next_state_data
|
||||
|
||||
# ----- 5) Complementary info (if available) -----
|
||||
complementary_info = None
|
||||
if has_complementary_info:
|
||||
complementary_info = {}
|
||||
for key in complementary_info_keys:
|
||||
# Strip the "complementary_info." prefix to get the actual key
|
||||
clean_key = key[len("complementary_info.") :]
|
||||
val = current_sample[key]
|
||||
# Handle tensor and non-tensor values differently
|
||||
if isinstance(val, torch.Tensor):
|
||||
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
|
||||
else:
|
||||
# TODO: (azouitine) Check if it's necessary to convert to tensor
|
||||
# For non-tensor values, use directly
|
||||
complementary_info[clean_key] = val
|
||||
|
||||
# ----- Construct the Transition -----
|
||||
transition = Transition(
|
||||
state=current_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
return transitions
|
||||
|
||||
|
||||
# Utility function to guess shapes/dtypes from a tensor
|
||||
def guess_feature_info(t, name: str):
|
||||
"""
|
||||
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
|
||||
If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
|
||||
Otherwise default to appropriate dtype for numeric.
|
||||
"""
|
||||
|
||||
shape = tuple(t.shape)
|
||||
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
|
||||
if len(shape) == 3 and shape[0] in [1, 3]:
|
||||
return {
|
||||
"dtype": "image",
|
||||
"shape": shape,
|
||||
}
|
||||
else:
|
||||
# Otherwise treat as numeric
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": shape,
|
||||
}
|
||||
|
||||
|
||||
def concatenate_batch_transitions(
|
||||
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
|
||||
) -> BatchTransition:
|
||||
"""NOTE: Be careful it change the left_batch_transitions in place"""
|
||||
# Concatenate state fields
|
||||
left_batch_transitions["state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["state"][key], right_batch_transition["state"][key]],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["state"]
|
||||
}
|
||||
|
||||
# Concatenate basic fields
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
|
||||
)
|
||||
left_batch_transitions["reward"] = torch.cat(
|
||||
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
|
||||
)
|
||||
|
||||
# Concatenate next_state fields
|
||||
left_batch_transitions["next_state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["next_state"]
|
||||
}
|
||||
|
||||
# Concatenate done and truncated fields
|
||||
left_batch_transitions["done"] = torch.cat(
|
||||
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
|
||||
)
|
||||
left_batch_transitions["truncated"] = torch.cat(
|
||||
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Handle complementary_info
|
||||
left_info = left_batch_transitions.get("complementary_info")
|
||||
right_info = right_batch_transition.get("complementary_info")
|
||||
|
||||
# Only process if right_info exists
|
||||
if right_info is not None:
|
||||
# Initialize left complementary_info if needed
|
||||
if left_info is None:
|
||||
left_batch_transitions["complementary_info"] = right_info
|
||||
else:
|
||||
# Concatenate each field
|
||||
for key in right_info:
|
||||
if key in left_info:
|
||||
left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
|
||||
else:
|
||||
left_info[key] = right_info[key]
|
||||
|
||||
return left_batch_transitions
|
||||
@@ -1,303 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import cv2
|
||||
|
||||
# import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def select_rect_roi(img):
|
||||
"""
|
||||
Allows the user to draw a rectangular ROI on the image.
|
||||
|
||||
The user must click and drag to draw the rectangle.
|
||||
- While dragging, the rectangle is dynamically drawn.
|
||||
- On mouse button release, the rectangle is fixed.
|
||||
- Press 'c' to confirm the selection.
|
||||
- Press 'r' to reset the selection.
|
||||
- Press ESC to cancel.
|
||||
|
||||
Returns:
|
||||
A tuple (top, left, height, width) representing the rectangular ROI,
|
||||
or None if no valid ROI is selected.
|
||||
"""
|
||||
# Create a working copy of the image
|
||||
clone = img.copy()
|
||||
working_img = clone.copy()
|
||||
|
||||
roi = None # Will store the final ROI as (top, left, height, width)
|
||||
drawing = False
|
||||
index_x, index_y = -1, -1 # Initial click coordinates
|
||||
|
||||
def mouse_callback(event, x, y, flags, param):
|
||||
nonlocal index_x, index_y, drawing, roi, working_img
|
||||
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
# Start drawing: record starting coordinates
|
||||
drawing = True
|
||||
index_x, index_y = x, y
|
||||
|
||||
elif event == cv2.EVENT_MOUSEMOVE:
|
||||
if drawing:
|
||||
# Compute the top-left and bottom-right corners regardless of drag direction
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
# Show a temporary image with the current rectangle drawn
|
||||
temp = working_img.copy()
|
||||
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", temp)
|
||||
|
||||
elif event == cv2.EVENT_LBUTTONUP:
|
||||
# Finish drawing
|
||||
drawing = False
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
height = bottom - top
|
||||
width = right - left
|
||||
roi = (top, left, height, width) # (top, left, height, width)
|
||||
# Draw the final rectangle on the working image and display it
|
||||
working_img = clone.copy()
|
||||
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
# Create the window and set the callback
|
||||
cv2.namedWindow("Select ROI")
|
||||
cv2.setMouseCallback("Select ROI", mouse_callback)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
print("Instructions for ROI selection:")
|
||||
print(" - Click and drag to draw a rectangular ROI.")
|
||||
print(" - Press 'c' to confirm the selection.")
|
||||
print(" - Press 'r' to reset and draw again.")
|
||||
print(" - Press ESC to cancel the selection.")
|
||||
|
||||
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
|
||||
while True:
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
# Confirm ROI if one has been drawn
|
||||
if key == ord("c") and roi is not None:
|
||||
break
|
||||
# Reset: clear the ROI and restore the original image
|
||||
elif key == ord("r"):
|
||||
working_img = clone.copy()
|
||||
roi = None
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
# Cancel selection for this image
|
||||
elif key == 27: # ESC key
|
||||
roi = None
|
||||
break
|
||||
|
||||
cv2.destroyWindow("Select ROI")
|
||||
return roi
|
||||
|
||||
|
||||
def select_square_roi_for_images(images: dict) -> dict:
|
||||
"""
|
||||
For each image in the provided dictionary, open a window to allow the user
|
||||
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
|
||||
(top, left, height, width) representing the ROI.
|
||||
|
||||
Parameters:
|
||||
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
|
||||
|
||||
Returns:
|
||||
dict: Mapping of image keys to the selected rectangular ROI.
|
||||
"""
|
||||
selected_rois = {}
|
||||
|
||||
for key, img in images.items():
|
||||
if img is None:
|
||||
print(f"Image for key '{key}' is None, skipping.")
|
||||
continue
|
||||
|
||||
print(f"\nSelect rectangular ROI for image with key: '{key}'")
|
||||
roi = select_rect_roi(img)
|
||||
|
||||
if roi is None:
|
||||
print(f"No valid ROI selected for '{key}'.")
|
||||
else:
|
||||
selected_rois[key] = roi
|
||||
print(f"ROI for '{key}': {roi}")
|
||||
|
||||
return selected_rois
|
||||
|
||||
|
||||
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
|
||||
"""
|
||||
Find the first row in the dataset and extract the image in order to be used for the crop.
|
||||
"""
|
||||
row = dataset[0]
|
||||
image_dict = {}
|
||||
for k in row:
|
||||
if "image" in k:
|
||||
image_dict[k] = deepcopy(row[k])
|
||||
return image_dict
|
||||
|
||||
|
||||
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset: LeRobotDataset,
|
||||
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
|
||||
new_repo_id: str,
|
||||
new_dataset_root: str,
|
||||
resize_size: Tuple[int, int] = (128, 128),
|
||||
push_to_hub: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts an existing LeRobotDataset by iterating over its episodes and frames,
|
||||
applying cropping and resizing to image observations, and saving a new dataset
|
||||
with the transformed data.
|
||||
|
||||
Args:
|
||||
original_dataset (LeRobotDataset): The source dataset.
|
||||
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
|
||||
A dictionary mapping observation keys to crop parameters (top, left, height, width).
|
||||
new_repo_id (str): Repository id for the new dataset.
|
||||
new_dataset_root (str): The root directory where the new dataset will be written.
|
||||
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
Defaults to (128, 128).
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
|
||||
and resized.
|
||||
"""
|
||||
# 1. Create a new (empty) LeRobotDataset for writing.
|
||||
new_dataset = LeRobotDataset.create(
|
||||
repo_id=new_repo_id,
|
||||
fps=original_dataset.fps,
|
||||
root=new_dataset_root,
|
||||
robot_type=original_dataset.meta.robot_type,
|
||||
features=original_dataset.meta.info["features"],
|
||||
use_videos=len(original_dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Update the metadata for every image key that will be cropped:
|
||||
# (Here we simply set the shape to be the final resize_size.)
|
||||
for key in crop_params_dict:
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||
|
||||
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
|
||||
prev_episode_index = 0
|
||||
for frame_idx in tqdm(range(len(original_dataset))):
|
||||
frame = original_dataset[frame_idx]
|
||||
|
||||
# Create a copy of the frame to add to the new dataset
|
||||
new_frame = {}
|
||||
for key, value in frame.items():
|
||||
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index"):
|
||||
continue
|
||||
if key in ("next.done", "next.reward"):
|
||||
# if not isinstance(value, str) and len(value.shape) == 0:
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
if key in crop_params_dict:
|
||||
top, left, height, width = crop_params_dict[key]
|
||||
# Apply crop then resize.
|
||||
cropped = F.crop(value, top, left, height, width)
|
||||
value = F.resize(cropped, resize_size)
|
||||
value = value.clamp(0, 1)
|
||||
|
||||
new_frame[key] = value
|
||||
|
||||
new_dataset.add_frame(new_frame)
|
||||
|
||||
if frame["episode_index"].item() != prev_episode_index:
|
||||
# Save the episode
|
||||
new_dataset.save_episode()
|
||||
prev_episode_index = frame["episode_index"].item()
|
||||
|
||||
if push_to_hub:
|
||||
new_dataset.push_to_hub()
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot",
|
||||
help="The repository id of the LeRobot dataset to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The root directory of the LeRobot dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-params-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the JSON file containing the ROIs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to push the new dataset to the hub.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
|
||||
|
||||
if args.crop_params_path is None:
|
||||
rois = select_square_roi_for_images(images)
|
||||
else:
|
||||
with open(args.crop_params_path) as f:
|
||||
rois = json.load(f)
|
||||
|
||||
# Print the selected rectangular ROIs
|
||||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||
for key, roi in rois.items():
|
||||
print(f"{key}: {roi}")
|
||||
|
||||
new_repo_id = args.repo_id + "_cropped_resized"
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
|
||||
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset=dataset,
|
||||
crop_params_dict=rois,
|
||||
new_repo_id=new_repo_id,
|
||||
new_dataset_root=new_dataset_root,
|
||||
resize_size=(128, 128),
|
||||
push_to_hub=args.push_to_hub,
|
||||
)
|
||||
|
||||
meta_dir = new_dataset_root / "meta"
|
||||
meta_dir.mkdir(exist_ok=True)
|
||||
|
||||
with open(meta_dir / "crop_params.json", "w") as f:
|
||||
json.dump(rois, f, indent=4)
|
||||
@@ -1,802 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
|
||||
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
|
||||
"""
|
||||
Initialize the controller.
|
||||
|
||||
Args:
|
||||
x_step_size: Base movement step size in meters
|
||||
y_step_size: Base movement step size in meters
|
||||
z_step_size: Base movement step size in meters
|
||||
"""
|
||||
self.x_step_size = x_step_size
|
||||
self.y_step_size = y_step_size
|
||||
self.z_step_size = z_step_size
|
||||
self.running = True
|
||||
self.episode_end_status = None # None, "success", or "failure"
|
||||
self.intervention_flag = False
|
||||
self.open_gripper_command = False
|
||||
self.close_gripper_command = False
|
||||
|
||||
def start(self):
|
||||
"""Start the controller and initialize resources."""
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""Stop the controller and release resources."""
|
||||
pass
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas (dx, dy, dz) in meters."""
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if the user has requested to quit."""
|
||||
return not self.running
|
||||
|
||||
def update(self):
|
||||
"""Update controller state - call this once per frame."""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for use in 'with' statements."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Ensure resources are released when exiting 'with' block."""
|
||||
self.stop()
|
||||
|
||||
def get_episode_end_status(self):
|
||||
"""
|
||||
Get the current episode end status.
|
||||
|
||||
Returns:
|
||||
None if episode should continue, "success" or "failure" otherwise
|
||||
"""
|
||||
status = self.episode_end_status
|
||||
self.episode_end_status = None # Reset after reading
|
||||
return status
|
||||
|
||||
def should_intervene(self):
|
||||
"""Return True if intervention flag was set."""
|
||||
return self.intervention_flag
|
||||
|
||||
def gripper_command(self):
|
||||
"""Return the current gripper command."""
|
||||
if self.open_gripper_command == self.close_gripper_command:
|
||||
return "no-op"
|
||||
elif self.open_gripper_command:
|
||||
return "open"
|
||||
elif self.close_gripper_command:
|
||||
return "close"
|
||||
|
||||
|
||||
class KeyboardController(InputController):
|
||||
"""Generate motion deltas from keyboard input."""
|
||||
|
||||
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.key_states = {
|
||||
"forward_x": False,
|
||||
"backward_x": False,
|
||||
"forward_y": False,
|
||||
"backward_y": False,
|
||||
"forward_z": False,
|
||||
"backward_z": False,
|
||||
"quit": False,
|
||||
"success": False,
|
||||
"failure": False,
|
||||
}
|
||||
self.listener = None
|
||||
|
||||
def start(self):
|
||||
"""Start the keyboard listener."""
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = True
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = True
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = True
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["quit"] = True
|
||||
self.running = False
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = "success"
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = "failure"
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def on_release(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = False
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = False
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = False
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = False
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = False
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = False
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
|
||||
self.listener.start()
|
||||
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" Enter: End episode with SUCCESS")
|
||||
print(" Backspace: End episode with FAILURE")
|
||||
print(" ESC: Exit")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
if self.listener and self.listener.is_alive():
|
||||
self.listener.stop()
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from keyboard state."""
|
||||
delta_x = delta_y = delta_z = 0.0
|
||||
|
||||
if self.key_states["forward_x"]:
|
||||
delta_x += self.x_step_size
|
||||
if self.key_states["backward_x"]:
|
||||
delta_x -= self.x_step_size
|
||||
if self.key_states["forward_y"]:
|
||||
delta_y += self.y_step_size
|
||||
if self.key_states["backward_y"]:
|
||||
delta_y -= self.y_step_size
|
||||
if self.key_states["forward_z"]:
|
||||
delta_z += self.z_step_size
|
||||
if self.key_states["backward_z"]:
|
||||
delta_z -= self.z_step_size
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if ESC was pressed."""
|
||||
return self.key_states["quit"]
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if Enter was pressed (save episode)."""
|
||||
return self.key_states["success"] or self.key_states["failure"]
|
||||
|
||||
|
||||
class GamepadController(InputController):
|
||||
"""Generate motion deltas from gamepad input."""
|
||||
|
||||
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.joystick = None
|
||||
self.intervention_flag = False
|
||||
|
||||
def start(self):
|
||||
"""Initialize pygame and the gamepad."""
|
||||
import pygame
|
||||
|
||||
pygame.init()
|
||||
pygame.joystick.init()
|
||||
|
||||
if pygame.joystick.get_count() == 0:
|
||||
logging.error("No gamepad detected. Please connect a gamepad and try again.")
|
||||
self.running = False
|
||||
return
|
||||
|
||||
self.joystick = pygame.joystick.Joystick(0)
|
||||
self.joystick.init()
|
||||
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
||||
|
||||
print("Gamepad controls:")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick (vertical): Move in Z axis")
|
||||
print(" B/Circle button: Exit")
|
||||
print(" Y/Triangle button: End episode with SUCCESS")
|
||||
print(" A/Cross button: End episode with FAILURE")
|
||||
print(" X/Square button: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Clean up pygame resources."""
|
||||
import pygame
|
||||
|
||||
if pygame.joystick.get_init():
|
||||
if self.joystick:
|
||||
self.joystick.quit()
|
||||
pygame.joystick.quit()
|
||||
pygame.quit()
|
||||
|
||||
def update(self):
|
||||
"""Process pygame events to get fresh gamepad readings."""
|
||||
import pygame
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = "success"
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = "failure"
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = True
|
||||
|
||||
# LT button (7) for opening gripper
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = True
|
||||
|
||||
# Reset episode status on button release
|
||||
elif event.type == pygame.JOYBUTTONUP:
|
||||
if event.button in [0, 2, 3]:
|
||||
self.episode_end_status = None
|
||||
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = False
|
||||
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = False
|
||||
|
||||
# Check for RB button (typically button 5) for intervention flag
|
||||
if self.joystick.get_button(5):
|
||||
self.intervention_flag = True
|
||||
else:
|
||||
self.intervention_flag = False
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
import pygame
|
||||
|
||||
try:
|
||||
# Read joystick axes
|
||||
# Left stick X and Y (typically axes 0 and 1)
|
||||
x_input = self.joystick.get_axis(0) # Left/Right
|
||||
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
|
||||
|
||||
# Right stick Y (typically axis 3 or 4)
|
||||
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
||||
|
||||
# Apply deadzone to avoid drift
|
||||
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||
|
||||
# Calculate deltas (note: may need to invert axes depending on controller)
|
||||
delta_x = -y_input * self.y_step_size # Forward/backward
|
||||
delta_y = -x_input * self.x_step_size # Left/right
|
||||
delta_z = -z_input * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
except pygame.error:
|
||||
logging.error("Error reading gamepad. Is it still connected?")
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
|
||||
class GamepadControllerHID(InputController):
|
||||
"""Generate motion deltas from gamepad input using HIDAPI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_step_size=0.01,
|
||||
y_step_size=0.01,
|
||||
z_step_size=0.01,
|
||||
deadzone=0.1,
|
||||
vendor_id=0x046D,
|
||||
product_id=0xC219,
|
||||
):
|
||||
"""
|
||||
Initialize the HID gamepad controller.
|
||||
|
||||
Args:
|
||||
step_size: Base movement step size in meters
|
||||
z_scale: Scaling factor for Z-axis movement
|
||||
deadzone: Joystick deadzone to prevent drift
|
||||
vendor_id: USB vendor ID of the gamepad (default: Logitech)
|
||||
product_id: USB product ID of the gamepad (default: RumblePad 2)
|
||||
"""
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.vendor_id = vendor_id
|
||||
self.product_id = product_id
|
||||
self.device = None
|
||||
self.device_info = None
|
||||
|
||||
# Movement values (normalized from -1.0 to 1.0)
|
||||
self.left_x = 0.0
|
||||
self.left_y = 0.0
|
||||
self.right_x = 0.0
|
||||
self.right_y = 0.0
|
||||
|
||||
# Button states
|
||||
self.buttons = {}
|
||||
self.quit_requested = False
|
||||
self.save_requested = False
|
||||
|
||||
def find_device(self):
|
||||
"""Look for the gamepad device by vendor and product ID."""
|
||||
import hid
|
||||
|
||||
devices = hid.enumerate()
|
||||
for device in devices:
|
||||
if device["vendor_id"] == self.vendor_id and device["product_id"] == self.product_id:
|
||||
logging.info(f"Found gamepad: {device.get('product_string', 'Unknown')}")
|
||||
return device
|
||||
|
||||
logging.error(
|
||||
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and product ID 0x{self.product_id:04X} found"
|
||||
)
|
||||
return None
|
||||
|
||||
def start(self):
|
||||
"""Connect to the gamepad using HIDAPI."""
|
||||
import hid
|
||||
|
||||
self.device_info = self.find_device()
|
||||
if not self.device_info:
|
||||
self.running = False
|
||||
return
|
||||
|
||||
try:
|
||||
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
|
||||
self.device = hid.device()
|
||||
self.device.open_path(self.device_info["path"])
|
||||
self.device.set_nonblocking(1)
|
||||
|
||||
manufacturer = self.device.get_manufacturer_string()
|
||||
product = self.device.get_product_string()
|
||||
logging.info(f"Connected to {manufacturer} {product}")
|
||||
|
||||
logging.info("Gamepad controls (HID mode):")
|
||||
logging.info(" Left analog stick: Move in X-Y plane")
|
||||
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
||||
logging.info(" Button 1/B/Circle: Exit")
|
||||
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
||||
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error opening gamepad: {e}")
|
||||
logging.error("You might need to run this with sudo/admin privileges on some systems")
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
"""Close the HID device connection."""
|
||||
if self.device:
|
||||
self.device.close()
|
||||
self.device = None
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
Read and process the latest gamepad data.
|
||||
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
|
||||
"""
|
||||
for _ in range(10):
|
||||
self._update()
|
||||
|
||||
def _update(self):
|
||||
"""Read and process the latest gamepad data."""
|
||||
if not self.device or not self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Read data from the gamepad
|
||||
data = self.device.read(64)
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if data and len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
# Apply deadzone
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
|
||||
# Check if RT is pressed
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
|
||||
# Check if LT is pressed
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error reading from gamepad: {e}")
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
# Calculate deltas - invert as needed based on controller orientation
|
||||
delta_x = -self.left_y * self.x_step_size # Forward/backward
|
||||
delta_y = -self.left_x * self.y_step_size # Left/right
|
||||
delta_z = -self.right_y * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if quit button was pressed."""
|
||||
return self.quit_requested
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if save button was pressed."""
|
||||
return self.save_requested
|
||||
|
||||
|
||||
def test_forward_kinematics(robot, fps=10):
|
||||
logging.info("Testing Forward Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
logging.info(f"EE Position: {ee_pos[:3, 3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def test_inverse_kinematics(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
|
||||
desired_ee_pos = ee_pos
|
||||
target_joint_state = RobotKinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Target Joint State: {target_joint_state}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = leader_ee
|
||||
target_joint_state = kinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
|
||||
)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Delta End-Effector Control")
|
||||
timestep = time.perf_counter()
|
||||
|
||||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = np.diag(np.ones(4))
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Get leader state for teleoperation
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
# Get current state
|
||||
# obs = robot.capture_observation()
|
||||
# joint_positions = obs["observation.state"].cpu().numpy()
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Calculate delta between leader and follower end-effectors
|
||||
# Scaling factor can be adjusted for sensitivity
|
||||
scaling_factor = 1.0
|
||||
ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05)
|
||||
|
||||
# Apply delta to current position
|
||||
desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0
|
||||
desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0
|
||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3]
|
||||
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = kinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
|
||||
)
|
||||
|
||||
initial_leader_ee = leader_ee.copy()
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
||||
# Logging
|
||||
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
|
||||
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
|
||||
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, fk_func=None):
|
||||
"""
|
||||
Control a robot using delta end-effector movements from any input controller.
|
||||
|
||||
Args:
|
||||
robot: Robot instance to control
|
||||
controller: InputController instance (keyboard, gamepad, etc.)
|
||||
fps: Control frequency in Hz
|
||||
bounds: Optional position limits
|
||||
fk_func: Forward kinematics function to use
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = RobotKinematics.fk_gripper_tip
|
||||
|
||||
logging.info(f"Testing Delta End-Effector Control with {controller.__class__.__name__}")
|
||||
|
||||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Initialize desired position with current position
|
||||
desired_ee_pos = np.eye(4) # Identity matrix
|
||||
|
||||
timestep = time.perf_counter()
|
||||
with controller:
|
||||
while not controller.should_quit() and time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get current robot state
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
||||
# Update desired position
|
||||
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x
|
||||
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y
|
||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z
|
||||
|
||||
# Apply bounds if provided
|
||||
if bounds is not None:
|
||||
desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])
|
||||
|
||||
# Only send commands if there's actual movement
|
||||
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_gym_env(env, controller, fps: int = 30):
|
||||
"""
|
||||
Control a robot through a gym environment using keyboard inputs.
|
||||
|
||||
Args:
|
||||
env: A gym environment created with make_robot_env
|
||||
fps: Target control frequency
|
||||
"""
|
||||
|
||||
logging.info("Testing Keyboard Control of Gym Environment")
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" ESC: Exit")
|
||||
|
||||
# Reset the environment to get initial observation
|
||||
obs, info = env.reset()
|
||||
|
||||
try:
|
||||
with controller:
|
||||
while not controller.should_quit():
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
||||
# Create the action vector
|
||||
action = np.array([delta_x, delta_y, delta_z])
|
||||
|
||||
# Skip if no movement
|
||||
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
|
||||
# Step the environment - pass action as a tensor with intervention flag
|
||||
action_tensor = torch.from_numpy(action.astype(np.float32))
|
||||
obs, reward, terminated, truncated, info = env.step((action_tensor, False))
|
||||
|
||||
# Log information
|
||||
logging.info(f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]")
|
||||
logging.info(f"Reward: {reward}")
|
||||
|
||||
# Reset if episode ended
|
||||
if terminated or truncated:
|
||||
logging.info("Episode ended, resetting environment")
|
||||
obs, info = env.reset()
|
||||
|
||||
# Maintain target frame rate
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
finally:
|
||||
# Close the environment
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvTransformConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="keyboard",
|
||||
choices=[
|
||||
"keyboard",
|
||||
"gamepad",
|
||||
"keyboard_gym",
|
||||
"gamepad_gym",
|
||||
"leader_delta",
|
||||
"leader",
|
||||
],
|
||||
help="Control mode to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-type",
|
||||
type=str,
|
||||
default="so100",
|
||||
help="Robot type (so100, koch, aloha, etc.)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
|
||||
robot = make_robot_from_config(robot_config)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
# Example bounds
|
||||
bounds = {
|
||||
"max": np.array([0.32170487, 0.201285, 0.10273342]),
|
||||
"min": np.array([0.16631757, -0.08237468, 0.03364977]),
|
||||
}
|
||||
|
||||
try:
|
||||
# Determine controller type based on mode prefix
|
||||
controller = None
|
||||
if args.mode.startswith("keyboard"):
|
||||
controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
elif args.mode.startswith("gamepad"):
|
||||
if sys.platform == "darwin":
|
||||
controller = GamepadControllerHID(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
else:
|
||||
controller = GamepadController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
|
||||
# Handle mode categories
|
||||
if args.mode in ["keyboard", "gamepad"]:
|
||||
# Direct robot control modes
|
||||
teleoperate_delta_inverse_kinematics(robot, controller, bounds=bounds, fps=10)
|
||||
|
||||
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
|
||||
# Gym environment control modes
|
||||
cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvTransformConfig())
|
||||
cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
|
||||
x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
|
||||
)
|
||||
cfg.wrapper.ee_action_space_params.use_gamepad = False
|
||||
cfg.device = "cpu"
|
||||
env = make_robot_env(cfg, robot)
|
||||
teleoperate_gym_env(env, controller, fps=cfg.fps)
|
||||
|
||||
elif args.mode == "leader_delta":
|
||||
# Leader-follower modes don't use controllers
|
||||
teleoperate_delta_inverse_kinematics_with_leader(robot)
|
||||
|
||||
elif args.mode == "leader":
|
||||
teleoperate_inverse_kinematics_with_leader(robot)
|
||||
|
||||
finally:
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
@@ -1,135 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.control_utils import is_headless
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.configs import parser
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
|
||||
|
||||
def find_joint_bounds(
|
||||
robot,
|
||||
control_time_s=30,
|
||||
display_cameras=False,
|
||||
):
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
pos_list = []
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
# Wait for 5 seconds to stabilize the robot initial position
|
||||
if time.perf_counter() - start_episode_t < 5:
|
||||
continue
|
||||
|
||||
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
if time.perf_counter() - start_episode_t > control_time_s:
|
||||
max = np.max(np.stack(pos_list), 0)
|
||||
min = np.min(np.stack(pos_list), 0)
|
||||
print(f"Max angle position per joint {max}")
|
||||
print(f"Min angle position per joint {min}")
|
||||
break
|
||||
|
||||
|
||||
def find_ee_bounds(
|
||||
robot,
|
||||
control_time_s=30,
|
||||
display_cameras=False,
|
||||
):
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
ee_list = []
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
# Wait for 5 seconds to stabilize the robot initial position
|
||||
if time.perf_counter() - start_episode_t < 5:
|
||||
continue
|
||||
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
print(f"Joint positions: {joint_positions}")
|
||||
ee_list.append(kinematics.fk_gripper_tip(joint_positions)[:3, 3])
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
if time.perf_counter() - start_episode_t > control_time_s:
|
||||
max = np.max(np.stack(ee_list), 0)
|
||||
min = np.min(np.stack(ee_list), 0)
|
||||
print(f"Max ee position {max}")
|
||||
print(f"Min ee position {min}")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create argparse for script-specific arguments
|
||||
parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="joint",
|
||||
choices=["joint", "ee"],
|
||||
help="Mode to run the script in. Can be 'joint' or 'ee'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--control-time-s",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Time step to use for control.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-type",
|
||||
type=str,
|
||||
default="so100",
|
||||
help="Robot type (so100, koch, aloha, etc.)",
|
||||
)
|
||||
|
||||
# Only parse known args, leaving robot config args for Hydra if used
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create robot with the appropriate config
|
||||
robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
|
||||
robot = make_robot_from_config(robot_config)
|
||||
|
||||
if args.mode == "joint":
|
||||
find_joint_bounds(robot, args.control_time_s)
|
||||
elif args.mode == "ee":
|
||||
find_ee_bounds(robot, args.control_time_s)
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,54 +0,0 @@
|
||||
// Copyright 2024 The HuggingFace Inc. team.
|
||||
// All rights reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package hil_serl;
|
||||
|
||||
// LearnerService: the Actor calls this to push transitions.
|
||||
// The Learner implements this service.
|
||||
service LearnerService {
|
||||
// Actor -> Learner to store transitions
|
||||
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
|
||||
rpc StreamParameters(Empty) returns (stream Parameters);
|
||||
rpc SendTransitions(stream Transition) returns (Empty);
|
||||
rpc SendInteractions(stream InteractionMessage) returns (Empty);
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
enum TransferState {
|
||||
TRANSFER_UNKNOWN = 0;
|
||||
TRANSFER_BEGIN = 1;
|
||||
TRANSFER_MIDDLE = 2;
|
||||
TRANSFER_END = 3;
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Transition {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Parameters {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message InteractionMessage {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
@@ -1,46 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: hilserl.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'hilserl.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=275
|
||||
_globals['_TRANSFERSTATE']._serialized_end=371
|
||||
_globals['_TRANSITION']._serialized_start=27
|
||||
_globals['_TRANSITION']._serialized_end=102
|
||||
_globals['_PARAMETERS']._serialized_start=104
|
||||
_globals['_PARAMETERS']._serialized_end=179
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=181
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=264
|
||||
_globals['_EMPTY']._serialized_start=266
|
||||
_globals['_EMPTY']._serialized_end=273
|
||||
_globals['_LEARNERSERVICE']._serialized_start=374
|
||||
_globals['_LEARNERSERVICE']._serialized_end=696
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -1,276 +0,0 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import hilserl_pb2 as hilserl__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.70.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in hilserl_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class LearnerServiceStub(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendInteractionMessage = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.StreamParameters = channel.unary_stream(
|
||||
'/hil_serl.LearnerService/StreamParameters',
|
||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Parameters.FromString,
|
||||
_registered_method=True)
|
||||
self.SendTransitions = channel.stream_unary(
|
||||
'/hil_serl.LearnerService/SendTransitions',
|
||||
request_serializer=hilserl__pb2.Transition.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendInteractions = channel.stream_unary(
|
||||
'/hil_serl.LearnerService/SendInteractions',
|
||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/Ready',
|
||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class LearnerServiceServicer(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def SendInteractionMessage(self, request, context):
|
||||
"""Actor -> Learner to store transitions
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def StreamParameters(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendTransitions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendInteractions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_LearnerServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendInteractionMessage,
|
||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'StreamParameters': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamParameters,
|
||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.Parameters.SerializeToString,
|
||||
),
|
||||
'SendTransitions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendTransitions,
|
||||
request_deserializer=hilserl__pb2.Transition.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendInteractions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendInteractions,
|
||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'hil_serl.LearnerService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class LearnerService(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendInteractionMessage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
||||
hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def StreamParameters(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/StreamParameters',
|
||||
hilserl__pb2.Empty.SerializeToString,
|
||||
hilserl__pb2.Parameters.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendTransitions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendTransitions',
|
||||
hilserl__pb2.Transition.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendInteractions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendInteractions',
|
||||
hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/Ready',
|
||||
hilserl__pb2.Empty.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@@ -1,546 +0,0 @@
|
||||
# ruff: noqa: N806, N815, N803
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
|
||||
def skew_symmetric(w):
|
||||
"""Creates the skew-symmetric matrix from a 3D vector."""
|
||||
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
|
||||
|
||||
|
||||
def rodrigues_rotation(w, theta):
|
||||
"""Computes the rotation matrix using Rodrigues' formula."""
|
||||
w_hat = skew_symmetric(w)
|
||||
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||
|
||||
|
||||
def screw_axis_to_transform(S, theta):
|
||||
"""Converts a screw axis to a 4x4 transformation matrix."""
|
||||
S_w = S[:3]
|
||||
S_v = S[3:]
|
||||
if np.allclose(S_w, 0) and np.linalg.norm(S_v) == 1: # Pure translation
|
||||
T = np.eye(4)
|
||||
T[:3, 3] = S_v * theta
|
||||
elif np.linalg.norm(S_w) == 1: # Rotation and translation
|
||||
w_hat = skew_symmetric(S_w)
|
||||
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||
t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v
|
||||
T = np.eye(4)
|
||||
T[:3, :3] = R
|
||||
T[:3, 3] = t
|
||||
else:
|
||||
raise ValueError("Invalid screw axis parameters")
|
||||
return T
|
||||
|
||||
|
||||
def pose_difference_se3(pose1, pose2):
|
||||
"""
|
||||
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
|
||||
SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space, combining rotation (SO(3)) and translation.
|
||||
Each 4x4 matrix has the following structure, a 3x3 rotation matrix in the top-left and a 3x1 translation vector in the top-right:
|
||||
|
||||
[R11 R12 R13 tx]
|
||||
[R21 R22 R23 ty]
|
||||
[R31 R32 R33 tz]
|
||||
[ 0 0 0 1]
|
||||
|
||||
where Rij is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector.
|
||||
|
||||
pose1 - pose2
|
||||
|
||||
Args:
|
||||
pose1: A 4x4 numpy array representing the first pose.
|
||||
pose2: A 4x4 numpy array representing the second pose.
|
||||
|
||||
Returns:
|
||||
A tuple (translation_diff, rotation_diff) where:
|
||||
- translation_diff is a 3x1 numpy array representing the translational difference.
|
||||
- rotation_diff is a 3x1 numpy array representing the rotational difference in axis-angle representation.
|
||||
"""
|
||||
|
||||
# Extract rotation matrices from poses
|
||||
R1 = pose1[:3, :3]
|
||||
R2 = pose2[:3, :3]
|
||||
|
||||
# Calculate translational difference
|
||||
translation_diff = pose1[:3, 3] - pose2[:3, 3]
|
||||
|
||||
# Calculate rotational difference using scipy's Rotation library
|
||||
R_diff = Rotation.from_matrix(R1 @ R2.T)
|
||||
rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation
|
||||
|
||||
return np.concatenate([translation_diff, rotation_diff])
|
||||
|
||||
|
||||
def se3_error(target_pose, current_pose):
|
||||
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
|
||||
R_target = target_pose[:3, :3]
|
||||
R_current = current_pose[:3, :3]
|
||||
R_error = R_target @ R_current.T
|
||||
rot_error = Rotation.from_matrix(R_error).as_rotvec()
|
||||
return np.concatenate([pos_error, rot_error])
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
"""Robot kinematics class supporting multiple robot models."""
|
||||
|
||||
# Robot measurements dictionary
|
||||
ROBOT_MEASUREMENTS = {
|
||||
"koch": {
|
||||
"gripper": [0.239, -0.001, 0.024],
|
||||
"wrist": [0.209, 0, 0.024],
|
||||
"forearm": [0.108, 0, 0.02],
|
||||
"humerus": [0, 0, 0.036],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"so100": {
|
||||
"gripper": [0.320, 0, 0.050],
|
||||
"wrist": [0.278, 0, 0.050],
|
||||
"forearm": [0.143, 0, 0.044],
|
||||
"humerus": [0.031, 0, 0.072],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"moss": {
|
||||
"gripper": [0.246, 0.013, 0.111],
|
||||
"wrist": [0.245, 0.002, 0.064],
|
||||
"forearm": [0.122, 0, 0.064],
|
||||
"humerus": [0.001, 0.001, 0.063],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, robot_type="so100"):
|
||||
"""Initialize kinematics for the specified robot type.
|
||||
|
||||
Args:
|
||||
robot_type: String specifying the robot model ("koch", "so100", or "moss")
|
||||
"""
|
||||
if robot_type not in self.ROBOT_MEASUREMENTS:
|
||||
raise ValueError(
|
||||
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
|
||||
)
|
||||
|
||||
self.robot_type = robot_type
|
||||
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
|
||||
|
||||
# Initialize all transformation matrices and screw axes
|
||||
self._setup_transforms()
|
||||
|
||||
def _create_translation_matrix(self, x=0, y=0, z=0):
|
||||
"""Create a 4x4 translation matrix."""
|
||||
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
|
||||
|
||||
def _setup_transforms(self):
|
||||
"""Setup all transformation matrices and screw axes for the robot."""
|
||||
# Set up rotation matrices (constant across robot types)
|
||||
|
||||
# Gripper orientation
|
||||
self.gripper_X0 = np.array(
|
||||
[
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, -1, 0, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Wrist orientation
|
||||
self.wrist_X0 = np.array(
|
||||
[
|
||||
[0, -1, 0, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Base orientation
|
||||
self.base_X0 = np.array(
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Gripper
|
||||
# Screw axis of gripper frame wrt base frame
|
||||
self.S_BG = np.array(
|
||||
[
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
self.measurements["gripper"][2],
|
||||
-self.measurements["gripper"][1],
|
||||
]
|
||||
)
|
||||
|
||||
# Gripper origin to centroid transform
|
||||
self.X_GoGc = self._create_translation_matrix(x=0.07)
|
||||
|
||||
# Gripper origin to tip transform
|
||||
self.X_GoGt = self._create_translation_matrix(x=0.12)
|
||||
|
||||
# 0-position gripper frame pose wrt base
|
||||
self.X_BoGo = self._create_translation_matrix(
|
||||
x=self.measurements["gripper"][0],
|
||||
y=self.measurements["gripper"][1],
|
||||
z=self.measurements["gripper"][2],
|
||||
)
|
||||
|
||||
# Wrist
|
||||
# Screw axis of wrist frame wrt base frame
|
||||
self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]])
|
||||
|
||||
# 0-position origin to centroid transform
|
||||
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
|
||||
|
||||
# 0-position wrist frame pose wrt base
|
||||
self.X_BR = self._create_translation_matrix(
|
||||
x=self.measurements["wrist"][0],
|
||||
y=self.measurements["wrist"][1],
|
||||
z=self.measurements["wrist"][2],
|
||||
)
|
||||
|
||||
# Forearm
|
||||
# Screw axis of forearm frame wrt base frame
|
||||
self.S_BF = np.array(
|
||||
[
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
-self.measurements["forearm"][2],
|
||||
0,
|
||||
self.measurements["forearm"][0],
|
||||
]
|
||||
)
|
||||
|
||||
# Forearm origin + centroid transform
|
||||
self.X_FoFc = self._create_translation_matrix(x=0.036) # spellchecker:disable-line
|
||||
|
||||
# 0-position forearm frame pose wrt base
|
||||
self.X_BF = self._create_translation_matrix(
|
||||
x=self.measurements["forearm"][0],
|
||||
y=self.measurements["forearm"][1],
|
||||
z=self.measurements["forearm"][2],
|
||||
)
|
||||
|
||||
# Humerus
|
||||
# Screw axis of humerus frame wrt base frame
|
||||
self.S_BH = np.array(
|
||||
[
|
||||
0,
|
||||
-1,
|
||||
0,
|
||||
self.measurements["humerus"][2],
|
||||
0,
|
||||
-self.measurements["humerus"][0],
|
||||
]
|
||||
)
|
||||
|
||||
# Humerus origin to centroid transform
|
||||
self.X_HoHc = self._create_translation_matrix(x=0.0475)
|
||||
|
||||
# 0-position humerus frame pose wrt base
|
||||
self.X_BH = self._create_translation_matrix(
|
||||
x=self.measurements["humerus"][0],
|
||||
y=self.measurements["humerus"][1],
|
||||
z=self.measurements["humerus"][2],
|
||||
)
|
||||
|
||||
# Shoulder
|
||||
# Screw axis of shoulder frame wrt Base frame
|
||||
self.S_BS = np.array([0, 0, -1, 0, 0, 0])
|
||||
|
||||
# Shoulder origin to centroid transform
|
||||
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
|
||||
|
||||
# 0-position shoulder frame pose wrt base
|
||||
self.X_BS = self._create_translation_matrix(
|
||||
x=self.measurements["shoulder"][0],
|
||||
y=self.measurements["shoulder"][1],
|
||||
z=self.measurements["shoulder"][2],
|
||||
)
|
||||
|
||||
# Base
|
||||
# Base origin to centroid transform
|
||||
self.X_BoBc = self._create_translation_matrix(y=0.015)
|
||||
|
||||
# World to base transform
|
||||
self.X_WoBo = self._create_translation_matrix(
|
||||
x=self.measurements["base"][0],
|
||||
y=self.measurements["base"][1],
|
||||
z=self.measurements["base"][2],
|
||||
)
|
||||
|
||||
# Pre-compute gripper post-multiplication matrix
|
||||
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
|
||||
|
||||
def fk_base(self):
|
||||
"""Forward kinematics for the base frame."""
|
||||
return self.X_WoBo @ self.X_BoBc @ self.base_X0
|
||||
|
||||
def fk_shoulder(self, robot_pos_deg):
|
||||
"""Forward kinematics for the shoulder frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return self.X_WoBo @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) @ self.X_SoSc @ self.X_BS
|
||||
|
||||
def fk_humerus(self, robot_pos_deg):
|
||||
"""Forward kinematics for the humerus frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ self.X_HoHc
|
||||
@ self.X_BH
|
||||
)
|
||||
|
||||
def fk_forearm(self, robot_pos_deg):
|
||||
"""Forward kinematics for the forearm frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ self.X_FoFc # spellchecker:disable-line
|
||||
@ self.X_BF
|
||||
)
|
||||
|
||||
def fk_wrist(self, robot_pos_deg):
|
||||
"""Forward kinematics for the wrist frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||
@ self.X_RoRc
|
||||
@ self.X_BR
|
||||
@ self.wrist_X0
|
||||
)
|
||||
|
||||
def fk_gripper(self, robot_pos_deg):
|
||||
"""Forward kinematics for the gripper frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
|
||||
@ self._fk_gripper_post
|
||||
)
|
||||
|
||||
def fk_gripper_tip(self, robot_pos_deg):
|
||||
"""Forward kinematics for the gripper tip frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
|
||||
@ self.X_GoGt
|
||||
@ self.X_BoGo
|
||||
@ self.gripper_X0
|
||||
)
|
||||
|
||||
def compute_jacobian(self, robot_pos_deg, fk_func=None):
|
||||
"""Finite differences to compute the Jacobian.
|
||||
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
|
||||
in the jth joint's velocity.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Current joint positions in degrees
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
eps = 1e-8
|
||||
jac = np.zeros(shape=(6, 5))
|
||||
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||
delta *= 0
|
||||
delta[el_ix] = eps / 2
|
||||
Sdot = (
|
||||
pose_difference_se3(
|
||||
fk_func(robot_pos_deg[:-1] + delta),
|
||||
fk_func(robot_pos_deg[:-1] - delta),
|
||||
)
|
||||
/ eps
|
||||
)
|
||||
jac[:, el_ix] = Sdot
|
||||
return jac
|
||||
|
||||
def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
|
||||
"""Finite differences to compute the positional Jacobian.
|
||||
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
|
||||
in the jth joint's velocity.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Current joint positions in degrees
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
eps = 1e-8
|
||||
jac = np.zeros(shape=(3, 5))
|
||||
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||
delta *= 0
|
||||
delta[el_ix] = eps / 2
|
||||
Sdot = (
|
||||
fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
|
||||
) / eps
|
||||
jac[:, el_ix] = Sdot
|
||||
return jac
|
||||
|
||||
def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None):
|
||||
"""Inverse kinematics using gradient descent.
|
||||
|
||||
Args:
|
||||
current_joint_state: Initial joint positions in degrees
|
||||
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
|
||||
position_only: If True, only match end-effector position, not orientation
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
|
||||
Returns:
|
||||
Joint positions in degrees that achieve the desired end-effector pose
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
# Do gradient descent.
|
||||
max_iterations = 5
|
||||
learning_rate = 1
|
||||
for _ in range(max_iterations):
|
||||
current_ee_pose = fk_func(current_joint_state)
|
||||
if not position_only:
|
||||
error = se3_error(desired_ee_pose, current_ee_pose)
|
||||
jac = self.compute_jacobian(current_joint_state, fk_func)
|
||||
else:
|
||||
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
|
||||
jac = self.compute_positional_jacobian(current_joint_state, fk_func)
|
||||
delta_angles = np.linalg.pinv(jac) @ error
|
||||
current_joint_state[:-1] += learning_rate * delta_angles
|
||||
|
||||
if np.linalg.norm(error) < 5e-3:
|
||||
return current_joint_state
|
||||
return current_joint_state
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
def run_test(robot_type):
|
||||
"""Run test suite for a specific robot type."""
|
||||
print(f"\n--- Testing {robot_type.upper()} Robot ---")
|
||||
|
||||
# Initialize kinematics for this robot
|
||||
robot = RobotKinematics(robot_type)
|
||||
|
||||
# Test 1: Forward kinematics consistency
|
||||
print("Test 1: Forward kinematics consistency")
|
||||
test_angles = np.array([30, 45, -30, 20, 10, 0]) # Example joint angles in degrees
|
||||
|
||||
# Calculate FK for different joints
|
||||
shoulder_pose = robot.fk_shoulder(test_angles)
|
||||
humerus_pose = robot.fk_humerus(test_angles)
|
||||
forearm_pose = robot.fk_forearm(test_angles)
|
||||
wrist_pose = robot.fk_wrist(test_angles)
|
||||
gripper_pose = robot.fk_gripper(test_angles)
|
||||
gripper_tip_pose = robot.fk_gripper_tip(test_angles)
|
||||
|
||||
# Check that poses form a consistent kinematic chain (positions should be progressively further from origin)
|
||||
distances = [
|
||||
np.linalg.norm(shoulder_pose[:3, 3]),
|
||||
np.linalg.norm(humerus_pose[:3, 3]),
|
||||
np.linalg.norm(forearm_pose[:3, 3]),
|
||||
np.linalg.norm(wrist_pose[:3, 3]),
|
||||
np.linalg.norm(gripper_pose[:3, 3]),
|
||||
np.linalg.norm(gripper_tip_pose[:3, 3]),
|
||||
]
|
||||
|
||||
# Check if distances generally increase along the chain
|
||||
is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1))
|
||||
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
|
||||
print(f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}")
|
||||
|
||||
# Test 2: Jacobian computation
|
||||
print("Test 2: Jacobian computation")
|
||||
jacobian = robot.compute_jacobian(test_angles)
|
||||
positional_jacobian = robot.compute_positional_jacobian(test_angles)
|
||||
|
||||
# Check shapes
|
||||
jacobian_shape_ok = jacobian.shape == (6, 5)
|
||||
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
|
||||
|
||||
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
|
||||
print(f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}")
|
||||
|
||||
# Test 3: Inverse kinematics
|
||||
print("Test 3: Inverse kinematics (position only)")
|
||||
|
||||
# Generate target pose from known joint angles
|
||||
original_angles = np.array([10, 20, 30, -10, 5, 0])
|
||||
target_pose = robot.fk_gripper(original_angles)
|
||||
|
||||
# Start IK from a different position
|
||||
initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
# Measure IK performance
|
||||
start_time = time.time()
|
||||
computed_angles = robot.ik(initial_guess.copy(), target_pose)
|
||||
ik_time = time.time() - start_time
|
||||
|
||||
# Compute resulting pose from IK solution
|
||||
result_pose = robot.fk_gripper(computed_angles)
|
||||
|
||||
# Calculate position error
|
||||
pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
|
||||
passed = pos_error < 0.01 # Accept errors less than 1cm
|
||||
|
||||
print(f" IK computation time: {ik_time:.4f} seconds")
|
||||
print(f" Position error: {pos_error:.4f}")
|
||||
print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}")
|
||||
|
||||
return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed
|
||||
|
||||
# Run tests for all robot types
|
||||
results = {}
|
||||
for robot_type in ["koch", "so100", "moss"]:
|
||||
results[robot_type] = run_test(robot_type)
|
||||
|
||||
# Print overall summary
|
||||
print("\n=== Test Summary ===")
|
||||
all_passed = all(results.values())
|
||||
for robot_type, passed in results.items():
|
||||
print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
|
||||
print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,80 +0,0 @@
|
||||
import logging
|
||||
from multiprocessing import Event, Queue
|
||||
|
||||
import hilserl_pb2 # type: ignore
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
|
||||
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||
|
||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
SHUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
def __init__(
|
||||
self,
|
||||
shutdown_event: Event, # type: ignore
|
||||
parameters_queue: Queue,
|
||||
seconds_between_pushes: float,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
):
|
||||
self.shutdown_event = shutdown_event
|
||||
self.parameters_queue = parameters_queue
|
||||
self.seconds_between_pushes = seconds_between_pushes
|
||||
self.transition_queue = transition_queue
|
||||
self.interaction_message_queue = interaction_message_queue
|
||||
|
||||
def StreamParameters(self, request, context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
logging.info("[LEARNER] Push parameters to the Actor")
|
||||
buffer = self.parameters_queue.get()
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
buffer,
|
||||
hilserl_pb2.Parameters,
|
||||
log_prefix="[LEARNER] Sending parameters",
|
||||
silent=True,
|
||||
)
|
||||
|
||||
logging.info("[LEARNER] Parameters sent")
|
||||
|
||||
self.shutdown_event.wait(self.seconds_between_pushes)
|
||||
|
||||
logging.info("[LEARNER] Stream parameters finished")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.transition_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] transitions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving transitions")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.interaction_message_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] interactions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving interactions")
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
return hilserl_pb2.Empty()
|
||||
@@ -1,142 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import pickle # nosec B403: Safe usage for internal serialization only
|
||||
from multiprocessing import Event, Queue
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.scripts.server import hilserl_pb2
|
||||
from lerobot.scripts.server.utils import Transition
|
||||
|
||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||
|
||||
|
||||
def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
||||
buffer.seek(0, io.SEEK_END)
|
||||
result = buffer.tell()
|
||||
buffer.seek(0)
|
||||
return result
|
||||
|
||||
|
||||
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
|
||||
buffer = io.BytesIO(buffer)
|
||||
size_in_bytes = bytes_buffer_size(buffer)
|
||||
|
||||
sent_bytes = 0
|
||||
|
||||
logging_method = logging.info if not silent else logging.debug
|
||||
|
||||
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
||||
|
||||
while sent_bytes < size_in_bytes:
|
||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
|
||||
|
||||
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
|
||||
elif sent_bytes == 0:
|
||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
|
||||
|
||||
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
||||
chunk = buffer.read(size_to_read)
|
||||
|
||||
yield message_class(transfer_state=transfer_state, data=chunk)
|
||||
sent_bytes += size_to_read
|
||||
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
|
||||
|
||||
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
||||
|
||||
|
||||
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
|
||||
bytes_buffer = io.BytesIO()
|
||||
step = 0
|
||||
|
||||
logging.info(f"{log_prefix} Starting receiver")
|
||||
for item in iterator:
|
||||
logging.debug(f"{log_prefix} Received item")
|
||||
if shutdown_event.is_set():
|
||||
logging.info(f"{log_prefix} Shutting down receiver")
|
||||
return
|
||||
|
||||
if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step 0")
|
||||
step = 0
|
||||
continue
|
||||
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
|
||||
bytes_buffer.write(item.data)
|
||||
step += 1
|
||||
logging.debug(f"{log_prefix} Received data at step {step}")
|
||||
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
||||
|
||||
queue.put(bytes_buffer.getvalue())
|
||||
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
step = 0
|
||||
|
||||
logging.debug(f"{log_prefix} Queue updated")
|
||||
|
||||
|
||||
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
||||
"""Convert model state dict to flat array for transmission"""
|
||||
buffer = io.BytesIO()
|
||||
|
||||
torch.save(state_dict, buffer)
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
return torch.load(buffer, weights_only=False) # nosec B614: Using weights_only=False relies on pickle which has security implications.
|
||||
# This is currently safe as we only deserialize trusted internal data.
|
||||
# TODO: Verify if weights_only=True would work for our use case (safer default in torch 2.6+)
|
||||
|
||||
|
||||
def python_object_to_bytes(python_object: Any) -> bytes:
|
||||
return pickle.dumps(python_object)
|
||||
|
||||
|
||||
def bytes_to_python_object(buffer: bytes) -> Any:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
|
||||
# Add validation checks here
|
||||
return obj
|
||||
|
||||
|
||||
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
transitions = torch.load(buffer, weights_only=False) # nosec B614: Safe usage of torch.load
|
||||
# Add validation checks here
|
||||
return transitions
|
||||
|
||||
|
||||
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
||||
buffer = io.BytesIO()
|
||||
torch.save(transitions, buffer)
|
||||
return buffer.getvalue()
|
||||
@@ -1,141 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from queue import Empty
|
||||
from typing import TypedDict
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import Queue
|
||||
|
||||
shutdown_event_counter = 0
|
||||
|
||||
|
||||
def setup_process_handlers(use_threads: bool) -> any:
|
||||
if use_threads:
|
||||
from threading import Event
|
||||
else:
|
||||
from multiprocessing import Event
|
||||
|
||||
shutdown_event = Event()
|
||||
|
||||
# Define signal handler
|
||||
def signal_handler(signum, frame):
|
||||
logging.info("Shutdown signal received. Cleaning up...")
|
||||
shutdown_event.set()
|
||||
global shutdown_event_counter
|
||||
shutdown_event_counter += 1
|
||||
|
||||
if shutdown_event_counter > 1:
|
||||
logging.info("Force shutdown")
|
||||
sys.exit(1)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
|
||||
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
|
||||
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logging.info("Shutdown signal received. Cleaning up...")
|
||||
shutdown_event.set()
|
||||
|
||||
return shutdown_event
|
||||
|
||||
|
||||
def get_last_item_from_queue(queue: Queue):
|
||||
item = queue.get()
|
||||
counter = 1
|
||||
|
||||
# Drain queue and keep only the most recent parameters
|
||||
try:
|
||||
while True:
|
||||
item = queue.get_nowait()
|
||||
counter += 1
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
logging.debug(f"Drained {counter} items from queue")
|
||||
|
||||
return item
|
||||
|
||||
|
||||
class Transition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: float
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: bool
|
||||
truncated: bool
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||
|
||||
|
||||
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
|
||||
device = torch.device(device)
|
||||
non_blocking = device.type == "cuda"
|
||||
|
||||
# Move state tensors to device
|
||||
transition["state"] = {
|
||||
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
|
||||
}
|
||||
|
||||
# Move action to device
|
||||
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
|
||||
|
||||
# Move reward and done if they are tensors
|
||||
if isinstance(transition["reward"], torch.Tensor):
|
||||
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
|
||||
|
||||
if isinstance(transition["done"], torch.Tensor):
|
||||
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
|
||||
|
||||
if isinstance(transition["truncated"], torch.Tensor):
|
||||
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
|
||||
|
||||
# Move next_state tensors to device
|
||||
transition["next_state"] = {
|
||||
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
|
||||
}
|
||||
|
||||
# Move complementary_info tensors if present
|
||||
if transition.get("complementary_info") is not None:
|
||||
for key, val in transition["complementary_info"].items():
|
||||
if isinstance(val, torch.Tensor):
|
||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||
elif isinstance(val, (int, float, bool)):
|
||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
return transition
|
||||
|
||||
|
||||
def move_state_dict_to_device(state_dict, device="cpu"):
|
||||
"""
|
||||
Recursively move all tensors in a (potentially) nested
|
||||
dict/list/tuple structure to the CPU.
|
||||
"""
|
||||
if isinstance(state_dict, torch.Tensor):
|
||||
return state_dict.to(device)
|
||||
elif isinstance(state_dict, dict):
|
||||
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
|
||||
elif isinstance(state_dict, list):
|
||||
return [move_state_dict_to_device(v, device=device) for v in state_dict]
|
||||
elif isinstance(state_dict, tuple):
|
||||
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
|
||||
else:
|
||||
return state_dict
|
||||
@@ -63,13 +63,13 @@ dependencies = [
|
||||
"opencv-python-headless>=4.9.0",
|
||||
"packaging>=24.2",
|
||||
"av>=14.2.0",
|
||||
"pymunk>=6.6.0",
|
||||
"pymunk>=6.6.0,<7.0.0",
|
||||
"pynput>=1.7.7",
|
||||
"pyzmq>=26.2.1",
|
||||
"rerun-sdk>=0.21.0",
|
||||
"termcolor>=2.4.0",
|
||||
"torch>=2.2.1,<2.7",
|
||||
"torchcodec==0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torch>=2.2.1",
|
||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torchvision>=0.21.0",
|
||||
"wandb>=0.16.3",
|
||||
"zarr>=2.17.0",
|
||||
@@ -84,7 +84,6 @@ dora = [
|
||||
]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||
hilserl = ["transformers>=4.48", "gym-hil>=0.1.3", "protobuf>=5.29.3", "grpcio>=1.70.0"]
|
||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
@@ -99,40 +98,13 @@ umi = ["imagecodecs>=2024.1.1"]
|
||||
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
|
||||
|
||||
|
||||
[tool.poetry]
|
||||
requires-poetry = ">=2.1"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
target-version = "py310"
|
||||
exclude = [
|
||||
"tests/data",
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".git-rewrite",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".pytype",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
"*_pb2.py",
|
||||
"*_pb2_grpc.py",
|
||||
]
|
||||
|
||||
exclude = ["tests/artifacts/**/*.safetensors"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0389a716d51c1c615fb2a3bfa386d89f00b0deca08c4fa21b23e020a939d0213
|
||||
oid sha256:6b1e600768a8771c5fe650e038a1193597e3810f032041b2a0d021e4496381c1
|
||||
size 3686488
|
||||
|
||||
@@ -28,7 +28,7 @@ from lerobot.common.datasets.transforms import (
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
|
||||
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
DATASET_REPO_ID = "lerobot/aloha_static_cups_open"
|
||||
|
||||
|
||||
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0dc691503e7d90b2086bb408e89a65f772ce5ee6e3562ef8c127bcb09bd90851
|
||||
oid sha256:9d4ebab73eabddc58879a4e770289d19e00a1a4cf2fa5fa33cd3a3246992bc90
|
||||
size 40551392
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
|
||||
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
|
||||
size 5104
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
|
||||
oid sha256:1a7a8b1a457149109f843c32bcbb047d09de2201847b9b79f7501b447f77ecf4
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
|
||||
oid sha256:5e6ce85296b2009e7c2060d336c0429b1c7197d9adb159e7df0ba18003067b36
|
||||
size 68
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
|
||||
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
|
||||
size 33400
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
|
||||
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
|
||||
size 515400
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
|
||||
oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
|
||||
oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd
|
||||
size 68
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
|
||||
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
|
||||
size 33400
|
||||
|
||||
@@ -14,11 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import traceback
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from serial import SerialException
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
@@ -88,19 +86,3 @@ def patch_builtins_input(monkeypatch):
|
||||
print(text)
|
||||
|
||||
monkeypatch.setattr("builtins.input", print_text)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--seed",
|
||||
action="store",
|
||||
default="42",
|
||||
help="Set random seed for reproducibility",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed(request):
|
||||
seed = int(request.config.getoption("--seed"))
|
||||
random.seed(seed) # Python random
|
||||
torch.manual_seed(seed) # PyTorch
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
@@ -253,7 +254,14 @@ def test_backward_compatibility_single_transforms(
|
||||
|
||||
|
||||
@require_x86_64_kernel
|
||||
@pytest.mark.skipif(
|
||||
version.parse(torch.__version__) < version.parse("2.7.0"),
|
||||
reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior",
|
||||
)
|
||||
def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
||||
# NOTE: PyTorch versions have different randomness, it might break this test.
|
||||
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
|
||||
|
||||
cfg = ImageTransformsConfig(enable=True)
|
||||
default_tf = ImageTransforms(cfg)
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from lerobot.common.constants import (
|
||||
from lerobot.common.optim.optimizers import (
|
||||
AdamConfig,
|
||||
AdamWConfig,
|
||||
MultiAdamConfig,
|
||||
SGDConfig,
|
||||
load_optimizer_state,
|
||||
save_optimizer_state,
|
||||
@@ -34,21 +33,13 @@ from lerobot.common.optim.optimizers import (
|
||||
(AdamConfig, torch.optim.Adam),
|
||||
(AdamWConfig, torch.optim.AdamW),
|
||||
(SGDConfig, torch.optim.SGD),
|
||||
(MultiAdamConfig, dict),
|
||||
],
|
||||
)
|
||||
def test_optimizer_build(config_cls, expected_class, model_params):
|
||||
config = config_cls()
|
||||
if config_cls == MultiAdamConfig:
|
||||
params_dict = {"default": model_params}
|
||||
optimizer = config.build(params_dict)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert isinstance(optimizer["default"], torch.optim.Adam)
|
||||
assert optimizer["default"].defaults["lr"] == config.lr
|
||||
else:
|
||||
optimizer = config.build(model_params)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert optimizer.defaults["lr"] == config.lr
|
||||
optimizer = config.build(model_params)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert optimizer.defaults["lr"] == config.lr
|
||||
|
||||
|
||||
def test_save_optimizer_state(optimizer, tmp_path):
|
||||
@@ -63,180 +54,3 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
|
||||
loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)
|
||||
|
||||
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_params_dict():
|
||||
return {
|
||||
"actor": [torch.nn.Parameter(torch.randn(10, 10))],
|
||||
"critic": [torch.nn.Parameter(torch.randn(5, 5))],
|
||||
"temperature": [torch.nn.Parameter(torch.randn(3, 3))],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_params, expected_values",
|
||||
[
|
||||
# Test 1: Basic configuration with different learning rates
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
},
|
||||
),
|
||||
# Test 2: Different weight decays and beta values
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-5},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-6},
|
||||
"temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
|
||||
},
|
||||
),
|
||||
# Test 3: Epsilon parameter customization
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4, "eps": 1e-6},
|
||||
"critic": {"lr": 5e-4, "eps": 1e-7},
|
||||
"temperature": {"lr": 2e-3, "eps": 1e-8},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
|
||||
# Create config with the given parameters
|
||||
config = MultiAdamConfig(**config_params)
|
||||
optimizers = config.build(base_params_dict)
|
||||
|
||||
# Verify optimizer count and keys
|
||||
assert len(optimizers) == len(expected_values)
|
||||
assert set(optimizers.keys()) == set(expected_values.keys())
|
||||
|
||||
# Check that all optimizers are Adam instances
|
||||
for opt in optimizers.values():
|
||||
assert isinstance(opt, torch.optim.Adam)
|
||||
|
||||
# Verify hyperparameters for each optimizer
|
||||
for name, expected in expected_values.items():
|
||||
optimizer = optimizers[name]
|
||||
for param, value in expected.items():
|
||||
assert optimizer.defaults[param] == value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multi_optimizers(base_params_dict):
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
return config.build(base_params_dict)
|
||||
|
||||
|
||||
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
|
||||
# Save optimizer states
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
# Verify that directories were created for each optimizer
|
||||
for name in multi_optimizers:
|
||||
assert (tmp_path / name).is_dir()
|
||||
assert (tmp_path / name / OPTIMIZER_STATE).is_file()
|
||||
assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
|
||||
|
||||
|
||||
def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
|
||||
# Option 1: Add a minimal backward pass to populate optimizer states
|
||||
for name, params in base_params_dict.items():
|
||||
if name in multi_optimizers:
|
||||
# Create a dummy loss and do backward
|
||||
dummy_loss = params[0].sum()
|
||||
dummy_loss.backward()
|
||||
# Perform an optimization step
|
||||
multi_optimizers[name].step()
|
||||
# Zero gradients for next steps
|
||||
multi_optimizers[name].zero_grad()
|
||||
|
||||
# Save optimizer states
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
# Create new optimizers with the same config
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
new_optimizers = config.build(base_params_dict)
|
||||
|
||||
# Load optimizer states
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
# Verify state dictionaries match
|
||||
for name in multi_optimizers:
|
||||
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
|
||||
|
||||
|
||||
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
|
||||
"""Test saving and loading optimizer states even when the state is empty (no backward pass)."""
|
||||
# Create config and build optimizers
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
optimizers = config.build(base_params_dict)
|
||||
|
||||
# Save optimizer states without any backward pass (empty state)
|
||||
save_optimizer_state(optimizers, tmp_path)
|
||||
|
||||
# Create new optimizers with the same config
|
||||
new_optimizers = config.build(base_params_dict)
|
||||
|
||||
# Load optimizer states
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
# Verify hyperparameters match even with empty state
|
||||
for name, optimizer in optimizers.items():
|
||||
assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
|
||||
assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
|
||||
assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
|
||||
|
||||
# Verify state dictionaries match (they will be empty)
|
||||
torch.testing.assert_close(
|
||||
optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
|
||||
)
|
||||
|
||||
@@ -37,7 +37,6 @@ def test_diffuser_scheduler(optimizer):
|
||||
"base_lrs": [0.001],
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
"verbose": False,
|
||||
}
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
@@ -56,7 +55,6 @@ def test_vqbet_scheduler(optimizer):
|
||||
"base_lrs": [0.001],
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
"verbose": False,
|
||||
}
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
@@ -77,7 +75,6 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
|
||||
"base_lrs": [0.001],
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
"verbose": False,
|
||||
}
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
def test_classifier_output():
|
||||
output = ClassifierOutput(
|
||||
logits=torch.tensor([1, 2, 3]),
|
||||
probabilities=torch.tensor([0.1, 0.2, 0.3]),
|
||||
hidden_states=None,
|
||||
)
|
||||
|
||||
assert (
|
||||
f"{output}"
|
||||
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
|
||||
)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
config.num_cameras = 1
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = {
|
||||
"observation.image": torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
assert len(images) == 1
|
||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
||||
assert labels.shape == torch.Size([batch_size])
|
||||
|
||||
output = classifier.predict(images)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.size() == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
num_classes = 5
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
||||
}
|
||||
config.num_cameras = 1
|
||||
config.num_classes = num_classes
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = {
|
||||
"observation.image": torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.rand((batch_size, num_classes)),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
assert len(images) == 1
|
||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
||||
assert labels.shape == torch.Size([batch_size, num_classes])
|
||||
|
||||
output = classifier.predict(images)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_default_device():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig(device="cpu")
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
@@ -20,6 +20,7 @@ from pathlib import Path
|
||||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot import available_policies
|
||||
@@ -408,7 +409,16 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
4. Check that this test now passes.
|
||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
||||
|
||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
||||
https://github.com/huggingface/lerobot/pull/1127.
|
||||
|
||||
"""
|
||||
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
|
||||
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
|
||||
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
|
||||
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
||||
|
||||
@@ -1,217 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import (
|
||||
ActorLearnerConfig,
|
||||
ActorNetworkConfig,
|
||||
ConcurrencyConfig,
|
||||
CriticNetworkConfig,
|
||||
PolicyConfig,
|
||||
SACConfig,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def test_sac_config_default_initialization():
|
||||
config = SACConfig()
|
||||
|
||||
assert config.normalization_mapping == {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
assert config.dataset_stats == {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
|
||||
# Basic parameters
|
||||
assert config.device == "cpu"
|
||||
assert config.storage_device == "cpu"
|
||||
assert config.discount == 0.99
|
||||
assert config.temperature_init == 1.0
|
||||
assert config.num_critics == 2
|
||||
|
||||
# Architecture specifics
|
||||
assert config.vision_encoder_name is None
|
||||
assert config.freeze_vision_encoder is True
|
||||
assert config.image_encoder_hidden_dim == 32
|
||||
assert config.shared_encoder is True
|
||||
assert config.num_discrete_actions is None
|
||||
assert config.image_embedding_pooling_dim == 8
|
||||
|
||||
# Training parameters
|
||||
assert config.online_steps == 1000000
|
||||
assert config.online_env_seed == 10000
|
||||
assert config.online_buffer_capacity == 100000
|
||||
assert config.offline_buffer_capacity == 100000
|
||||
assert config.async_prefetch is False
|
||||
assert config.online_step_before_learning == 100
|
||||
assert config.policy_update_freq == 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
assert config.num_subsample_critics is None
|
||||
assert config.critic_lr == 3e-4
|
||||
assert config.actor_lr == 3e-4
|
||||
assert config.temperature_lr == 3e-4
|
||||
assert config.critic_target_update_weight == 0.005
|
||||
assert config.utd_ratio == 1
|
||||
assert config.state_encoder_hidden_dim == 256
|
||||
assert config.latent_dim == 256
|
||||
assert config.target_entropy is None
|
||||
assert config.use_backup_entropy is True
|
||||
assert config.grad_clip_norm == 40.0
|
||||
|
||||
# Dataset stats defaults
|
||||
expected_dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
assert config.dataset_stats == expected_dataset_stats
|
||||
|
||||
# Critic network configuration
|
||||
assert config.critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.critic_network_kwargs.activate_final is True
|
||||
assert config.critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor network configuration
|
||||
assert config.actor_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.actor_network_kwargs.activate_final is True
|
||||
|
||||
# Policy configuration
|
||||
assert config.policy_kwargs.use_tanh_squash is True
|
||||
assert config.policy_kwargs.log_std_min == 1e-5
|
||||
assert config.policy_kwargs.log_std_max == 10.0
|
||||
assert config.policy_kwargs.init_final == 0.05
|
||||
|
||||
# Discrete critic network configuration
|
||||
assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.discrete_critic_network_kwargs.activate_final is True
|
||||
assert config.discrete_critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor learner configuration
|
||||
assert config.actor_learner_config.learner_host == "127.0.0.1"
|
||||
assert config.actor_learner_config.learner_port == 50051
|
||||
assert config.actor_learner_config.policy_parameters_push_frequency == 4
|
||||
|
||||
# Concurrency configuration
|
||||
assert config.concurrency.actor == "threads"
|
||||
assert config.concurrency.learner == "threads"
|
||||
|
||||
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
|
||||
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
|
||||
assert isinstance(config.policy_kwargs, PolicyConfig)
|
||||
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
|
||||
assert isinstance(config.concurrency, ConcurrencyConfig)
|
||||
|
||||
|
||||
def test_critic_network_kwargs():
|
||||
config = CriticNetworkConfig()
|
||||
assert config.hidden_dims == [256, 256]
|
||||
assert config.activate_final is True
|
||||
assert config.final_activation is None
|
||||
|
||||
|
||||
def test_actor_network_kwargs():
|
||||
config = ActorNetworkConfig()
|
||||
assert config.hidden_dims == [256, 256]
|
||||
assert config.activate_final is True
|
||||
|
||||
|
||||
def test_policy_kwargs():
|
||||
config = PolicyConfig()
|
||||
assert config.use_tanh_squash is True
|
||||
assert config.log_std_min == 1e-5
|
||||
assert config.log_std_max == 10.0
|
||||
assert config.init_final == 0.05
|
||||
|
||||
|
||||
def test_actor_learner_config():
|
||||
config = ActorLearnerConfig()
|
||||
assert config.learner_host == "127.0.0.1"
|
||||
assert config.learner_port == 50051
|
||||
assert config.policy_parameters_push_frequency == 4
|
||||
|
||||
|
||||
def test_concurrency_config():
|
||||
config = ConcurrencyConfig()
|
||||
assert config.actor == "threads"
|
||||
assert config.learner == "threads"
|
||||
|
||||
|
||||
def test_sac_config_custom_initialization():
|
||||
config = SACConfig(
|
||||
device="cpu",
|
||||
discount=0.95,
|
||||
temperature_init=0.5,
|
||||
num_critics=3,
|
||||
)
|
||||
|
||||
assert config.device == "cpu"
|
||||
assert config.discount == 0.95
|
||||
assert config.temperature_init == 0.5
|
||||
assert config.num_critics == 3
|
||||
|
||||
|
||||
def test_validate_features():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_observation():
|
||||
config = SACConfig(
|
||||
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="You must provide either 'observation.state' or an image observation"
|
||||
):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_action():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
|
||||
config.validate_features()
|
||||
@@ -1,519 +0,0 @@
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
|
||||
TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRANSFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
def test_mlp_with_default_args():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
|
||||
x = torch.randn(10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (256,)
|
||||
|
||||
|
||||
def test_mlp_with_batch_dim():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
x = torch.randn(2, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (2, 256)
|
||||
|
||||
|
||||
def test_forward_with_empty_hidden_dims():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[])
|
||||
x = torch.randn(1, 10)
|
||||
assert mlp(x).shape == (1, 10)
|
||||
|
||||
|
||||
def test_mlp_with_dropout():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 11)
|
||||
|
||||
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
|
||||
assert drop_out_layers_count == 2
|
||||
|
||||
|
||||
def test_mlp_with_custom_final_activation():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 256)
|
||||
assert (y >= -1).all() and (y <= 1).all()
|
||||
|
||||
|
||||
def test_sac_policy_with_default_args():
|
||||
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
|
||||
SACPolicy()
|
||||
|
||||
|
||||
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
|
||||
return torch.randn(batch_size, action_dim)
|
||||
|
||||
|
||||
def create_default_train_batch(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
"action": create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_state(batch_size, state_dim),
|
||||
"next_state": create_dummy_state(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_train_batch_with_visual_input(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
"action": create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
||||
}
|
||||
|
||||
|
||||
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Create optimizers for the SAC policy."""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(
|
||||
params=[policy.log_alpha],
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
|
||||
if has_discrete_action:
|
||||
optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def create_default_config(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
action_dim = continuous_action_dim
|
||||
if has_discrete_action:
|
||||
action_dim += 1
|
||||
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
||||
dataset_stats={
|
||||
"observation.state": {
|
||||
"min": [0.0] * state_dim,
|
||||
"max": [1.0] * state_dim,
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0] * continuous_action_dim,
|
||||
"max": [1.0] * continuous_action_dim,
|
||||
},
|
||||
},
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def create_config_with_visual_input(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
config = create_default_config(
|
||||
state_dim=state_dim,
|
||||
continuous_action_dim=continuous_action_dim,
|
||||
has_discrete_action=has_discrete_action,
|
||||
)
|
||||
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
config.dataset_stats["observation.image"] = {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
}
|
||||
|
||||
# Let make tests a little bit faster
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.latent_dim = 32
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
# Let's check best candidates for pretrained encoders
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||
)
|
||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||
def test_sac_policy_with_pretrained_encoder(
|
||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||
):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.vision_encoder_name = vision_encoder_name
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
|
||||
def test_sac_policy_with_shared_encoder():
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.shared_encoder = True
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
|
||||
def test_sac_policy_with_discrete_critic():
|
||||
batch_size = 2
|
||||
continuous_action_dim = 9
|
||||
full_action_dim = continuous_action_dim + 1 # the last action is discrete
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(
|
||||
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
|
||||
)
|
||||
|
||||
num_discrete_actions = 5
|
||||
config.num_discrete_actions = num_discrete_actions
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy, has_discrete_action=True)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
|
||||
assert discrete_critic_loss.item() is not None
|
||||
assert discrete_critic_loss.shape == ()
|
||||
discrete_critic_loss.backward()
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, full_action_dim)
|
||||
|
||||
discrete_actions = selected_action[:, -1].long()
|
||||
discrete_action_values = set(discrete_actions.tolist())
|
||||
|
||||
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
|
||||
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
|
||||
)
|
||||
|
||||
|
||||
def test_sac_policy_with_default_entropy():
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -5.0
|
||||
|
||||
|
||||
def test_sac_policy_default_target_entropy_with_discrete_action():
|
||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -3.0
|
||||
|
||||
|
||||
def test_sac_policy_with_predefined_entropy():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.target_entropy = -3.5
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == pytest.approx(-3.5)
|
||||
|
||||
|
||||
def test_sac_policy_update_temperature():
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
policy = SACPolicy(config=config)
|
||||
|
||||
assert policy.temperature == pytest.approx(1.0)
|
||||
policy.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
policy.update_temperature()
|
||||
assert policy.temperature == pytest.approx(0.1)
|
||||
|
||||
|
||||
def test_sac_policy_update_target_network():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.critic_target_update_weight = 1.0
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
for p in policy.critic_ensemble.parameters():
|
||||
p.data = torch.ones_like(p.data)
|
||||
|
||||
policy.update_target_networks()
|
||||
for p in policy.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data)), (
|
||||
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_critics", [1, 3])
|
||||
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_critics = num_critics
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
assert len(policy.critic_ensemble.critics) == num_critics
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load(tmp_path):
|
||||
root = tmp_path / "test_sac_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
batch_size = 2
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Collect policy values before saving
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
actions = policy.select_action(observation_batch)
|
||||
|
||||
with seeded_context(12):
|
||||
# Collect policy values after loading
|
||||
loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
|
||||
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
|
||||
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
# Compare values before and after saving and loading
|
||||
# They should be the same
|
||||
assert torch.allclose(cirtic_loss, loaded_cirtic_loss)
|
||||
assert torch.allclose(actor_loss, loaded_actor_loss)
|
||||
assert torch.allclose(temperature_loss, loaded_temperature_loss)
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
@@ -1,599 +0,0 @@
|
||||
import sys
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.scripts.server.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def state_dims() -> list[str]:
|
||||
return ["observation.image", "observation.state"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def replay_buffer() -> ReplayBuffer:
|
||||
return create_empty_replay_buffer()
|
||||
|
||||
|
||||
def clone_state(state: dict) -> dict:
|
||||
return {k: v.clone() for k, v in state.items()}
|
||||
|
||||
|
||||
def create_empty_replay_buffer(
|
||||
optimize_memory: bool = False,
|
||||
use_drq: bool = False,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
) -> ReplayBuffer:
|
||||
buffer_capacity = 10
|
||||
device = "cpu"
|
||||
return ReplayBuffer(
|
||||
buffer_capacity,
|
||||
device,
|
||||
state_dims(),
|
||||
optimize_memory=optimize_memory,
|
||||
use_drq=use_drq,
|
||||
image_augmentation_function=image_augmentation_function,
|
||||
)
|
||||
|
||||
|
||||
def create_random_image() -> torch.Tensor:
|
||||
return torch.rand(3, 84, 84)
|
||||
|
||||
|
||||
def create_dummy_transition() -> dict:
|
||||
return {
|
||||
"observation.image": create_random_image(),
|
||||
"action": torch.randn(4),
|
||||
"reward": torch.tensor(1.0),
|
||||
"observation.state": torch.randn(
|
||||
10,
|
||||
),
|
||||
"done": torch.tensor(False),
|
||||
"truncated": torch.tensor(False),
|
||||
"complementary_info": {},
|
||||
}
|
||||
|
||||
|
||||
def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]:
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer = create_empty_replay_buffer()
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
|
||||
|
||||
root = tmp_path / "test"
|
||||
return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer)
|
||||
|
||||
|
||||
def create_dummy_state() -> dict:
|
||||
return {
|
||||
"observation.image": create_random_image(),
|
||||
"observation.state": torch.randn(
|
||||
10,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_tensor_memory_consumption(tensor):
|
||||
return tensor.nelement() * tensor.element_size()
|
||||
|
||||
|
||||
def get_tensors_memory_consumption(obj, visited_addresses):
|
||||
total_size = 0
|
||||
|
||||
address = id(obj)
|
||||
if address in visited_addresses:
|
||||
return 0
|
||||
|
||||
visited_addresses.add(address)
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return get_tensor_memory_consumption(obj)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
total_size += get_tensors_memory_consumption(item, visited_addresses)
|
||||
elif isinstance(obj, dict):
|
||||
for value in obj.values():
|
||||
total_size += get_tensors_memory_consumption(value, visited_addresses)
|
||||
elif hasattr(obj, "__dict__"):
|
||||
# It's an object, we need to get the size of the attributes
|
||||
for _, attr in vars(obj).items():
|
||||
total_size += get_tensors_memory_consumption(attr, visited_addresses)
|
||||
|
||||
return total_size
|
||||
|
||||
|
||||
def get_object_memory(obj):
|
||||
# Track visited addresses to avoid infinite loops
|
||||
# and cases when two properties point to the same object
|
||||
visited_addresses = set()
|
||||
|
||||
# Get the size of the object in bytes
|
||||
total_size = sys.getsizeof(obj)
|
||||
|
||||
# Get the size of the tensor attributes
|
||||
total_size += get_tensors_memory_consumption(obj, visited_addresses)
|
||||
|
||||
return total_size
|
||||
|
||||
|
||||
def create_dummy_action() -> torch.Tensor:
|
||||
return torch.randn(4)
|
||||
|
||||
|
||||
def dict_properties() -> list:
|
||||
return ["state", "next_state"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_state() -> dict:
|
||||
return create_dummy_state()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def next_dummy_state() -> dict:
|
||||
return create_dummy_state()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_action() -> torch.Tensor:
|
||||
return torch.randn(4)
|
||||
|
||||
|
||||
def test_empty_buffer_sample_raises_error(replay_buffer):
|
||||
assert len(replay_buffer) == 0, "Replay buffer should be empty."
|
||||
assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10."
|
||||
with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
|
||||
replay_buffer.sample(1)
|
||||
|
||||
|
||||
def test_zero_capacity_buffer_raises_error():
|
||||
with pytest.raises(ValueError, match="Capacity must be greater than 0."):
|
||||
ReplayBuffer(0, "cpu", ["observation", "next_observation"])
|
||||
|
||||
|
||||
def test_add_transition(replay_buffer, dummy_state, dummy_action):
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
|
||||
assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding."
|
||||
assert torch.equal(replay_buffer.actions[0], dummy_action), (
|
||||
"Action should be equal to the first transition."
|
||||
)
|
||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
|
||||
assert not replay_buffer.dones[0], "Done should be False for the first transition."
|
||||
assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."
|
||||
|
||||
for dim in state_dims():
|
||||
assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
|
||||
"Observation should be equal to the first transition."
|
||||
)
|
||||
assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), (
|
||||
"Next observation should be equal to the first transition."
|
||||
)
|
||||
|
||||
|
||||
def test_add_over_capacity():
|
||||
replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"])
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
|
||||
|
||||
assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3."
|
||||
|
||||
for dim in state_dims():
|
||||
assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), (
|
||||
"Observation should be equal to the first transition."
|
||||
)
|
||||
assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), (
|
||||
"Next observation should be equal to the first transition."
|
||||
)
|
||||
|
||||
assert torch.equal(replay_buffer.actions[0], dummy_action_3), (
|
||||
"Action should be equal to the last transition."
|
||||
)
|
||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
|
||||
assert replay_buffer.dones[0], "Done should be True for the first transition."
|
||||
assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."
|
||||
|
||||
|
||||
def test_sample_from_empty_buffer(replay_buffer):
|
||||
with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
|
||||
replay_buffer.sample(1)
|
||||
|
||||
|
||||
def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action):
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
|
||||
got_batch_transition = replay_buffer.sample(1)
|
||||
|
||||
expected_batch_transition = BatchTransition(
|
||||
state=clone_state(dummy_state),
|
||||
action=dummy_action.clone(),
|
||||
reward=1.0,
|
||||
next_state=clone_state(next_dummy_state),
|
||||
done=False,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
for buffer_property in dict_properties():
|
||||
for k, v in expected_batch_transition[buffer_property].items():
|
||||
got_state = got_batch_transition[buffer_property][k]
|
||||
|
||||
assert got_state.shape[0] == 1, f"{k} should have 1 transition."
|
||||
assert got_state.device.type == "cpu", f"{k} should be on cpu."
|
||||
|
||||
assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition."
|
||||
|
||||
for key, _value in expected_batch_transition.items():
|
||||
if key in dict_properties():
|
||||
continue
|
||||
|
||||
got_value = got_batch_transition[key]
|
||||
|
||||
v_tensor = expected_batch_transition[key]
|
||||
if not isinstance(v_tensor, torch.Tensor):
|
||||
v_tensor = torch.tensor(v_tensor)
|
||||
|
||||
assert got_value.shape[0] == 1, f"{key} should have 1 transition."
|
||||
assert got_value.device.type == "cpu", f"{key} should be on cpu."
|
||||
assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition."
|
||||
|
||||
|
||||
def test_sample_with_batch_bigger_than_buffer_size(
|
||||
replay_buffer, dummy_state, next_dummy_state, dummy_action
|
||||
):
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
|
||||
got_batch_transition = replay_buffer.sample(10)
|
||||
|
||||
expected_batch_transition = BatchTransition(
|
||||
state=dummy_state,
|
||||
action=dummy_action,
|
||||
reward=1.0,
|
||||
next_state=next_dummy_state,
|
||||
done=False,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
for buffer_property in dict_properties():
|
||||
for k in expected_batch_transition[buffer_property]:
|
||||
got_state = got_batch_transition[buffer_property][k]
|
||||
|
||||
assert got_state.shape[0] == 1, f"{k} should have 1 transition."
|
||||
|
||||
for key in expected_batch_transition:
|
||||
if key in dict_properties():
|
||||
continue
|
||||
|
||||
got_value = got_batch_transition[key]
|
||||
assert got_value.shape[0] == 1, f"{key} should have 1 transition."
|
||||
|
||||
|
||||
def test_sample_batch(replay_buffer):
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True)
|
||||
|
||||
dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4]
|
||||
dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4]
|
||||
|
||||
got_batch_transition = replay_buffer.sample(3)
|
||||
|
||||
for buffer_property in dict_properties():
|
||||
for k in got_batch_transition[buffer_property]:
|
||||
got_state = got_batch_transition[buffer_property][k]
|
||||
|
||||
assert got_state.shape[0] == 3, f"{k} should have 3 transition."
|
||||
|
||||
for got_state_item in got_state:
|
||||
assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), (
|
||||
f"{k} should be equal to one of the dummy states."
|
||||
)
|
||||
|
||||
for got_action_item in got_batch_transition["action"]:
|
||||
assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), (
|
||||
"Actions should be equal to the dummy actions."
|
||||
)
|
||||
|
||||
for k in got_batch_transition:
|
||||
if k in dict_properties() or k == "complementary_info":
|
||||
continue
|
||||
|
||||
got_value = got_batch_transition[k]
|
||||
assert got_value.shape[0] == 3, f"{k} should have 3 transition."
|
||||
|
||||
|
||||
def test_to_lerobot_dataset_with_empty_buffer(replay_buffer):
|
||||
with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."):
|
||||
replay_buffer.to_lerobot_dataset("dummy_repo")
|
||||
|
||||
|
||||
def test_to_lerobot_dataset(tmp_path):
|
||||
ds, buffer = create_dataset_from_replay_buffer(tmp_path)
|
||||
|
||||
assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer"
|
||||
assert ds.fps == 1, "FPS should be 1"
|
||||
assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id"
|
||||
|
||||
for dim in state_dims():
|
||||
assert dim in ds.features
|
||||
assert ds.features[dim]["shape"] == buffer.states[dim][0].shape
|
||||
|
||||
assert ds.num_episodes == 2
|
||||
assert ds.num_frames == 4
|
||||
|
||||
for j, value in enumerate(ds):
|
||||
print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j]))
|
||||
|
||||
for i in range(len(ds)):
|
||||
for feature, value in ds[i].items():
|
||||
if feature == "action":
|
||||
assert torch.equal(value, buffer.actions[i])
|
||||
elif feature == "next.reward":
|
||||
assert torch.equal(value, buffer.rewards[i])
|
||||
elif feature == "next.done":
|
||||
assert torch.equal(value, buffer.dones[i])
|
||||
elif feature == "observation.image":
|
||||
# Tenssor -> numpy is not precise, so we have some diff there
|
||||
# TODO: Check and fix it
|
||||
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
|
||||
elif feature == "observation.state":
|
||||
assert torch.equal(value, buffer.states["observation.state"][i])
|
||||
|
||||
|
||||
def test_from_lerobot_dataset(tmp_path):
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer = create_empty_replay_buffer()
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
|
||||
|
||||
root = tmp_path / "test"
|
||||
ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root)
|
||||
|
||||
reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
|
||||
)
|
||||
|
||||
# Check only the part of the buffer that's actually filled with data
|
||||
assert torch.equal(
|
||||
reconverted_buffer.actions[: len(replay_buffer)],
|
||||
replay_buffer.actions[: len(replay_buffer)],
|
||||
), "Actions from converted buffer should be equal to the original replay buffer."
|
||||
assert torch.equal(
|
||||
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
|
||||
), "Rewards from converted buffer should be equal to the original replay buffer."
|
||||
assert torch.equal(
|
||||
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
|
||||
), "Dones from converted buffer should be equal to the original replay buffer."
|
||||
|
||||
# Lerobot DS haven't supported truncateds yet
|
||||
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
|
||||
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
|
||||
"Truncateds from converted buffer should be equal False"
|
||||
)
|
||||
|
||||
assert torch.equal(
|
||||
replay_buffer.states["observation.state"][: len(replay_buffer)],
|
||||
reconverted_buffer.states["observation.state"][: len(replay_buffer)],
|
||||
), "State should be the same after converting to dataset and return back"
|
||||
|
||||
for i in range(4):
|
||||
torch.testing.assert_close(
|
||||
replay_buffer.states["observation.image"][i],
|
||||
reconverted_buffer.states["observation.image"][i],
|
||||
rtol=0.4,
|
||||
atol=0.004,
|
||||
)
|
||||
|
||||
# The 2, 3 frames have done flag, so their values will be equal to the current state
|
||||
for i in range(2):
|
||||
# In the current implementation we take the next state from the `states` and ignore `next_states`
|
||||
next_index = (i + 1) % 4
|
||||
|
||||
torch.testing.assert_close(
|
||||
replay_buffer.states["observation.image"][next_index],
|
||||
reconverted_buffer.next_states["observation.image"][i],
|
||||
rtol=0.4,
|
||||
atol=0.004,
|
||||
)
|
||||
|
||||
for i in range(2, 4):
|
||||
assert torch.equal(
|
||||
replay_buffer.states["observation.state"][i],
|
||||
reconverted_buffer.next_states["observation.state"][i],
|
||||
)
|
||||
|
||||
|
||||
def test_buffer_sample_alignment():
|
||||
# Initialize buffer
|
||||
buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu")
|
||||
|
||||
# Fill buffer with patterned data
|
||||
for i in range(100):
|
||||
signature = float(i) / 100.0
|
||||
state = {"state_value": torch.tensor([[signature]]).float()}
|
||||
action = torch.tensor([[2.0 * signature]]).float()
|
||||
reward = 3.0 * signature
|
||||
|
||||
is_end = (i + 1) % 10 == 0
|
||||
if is_end:
|
||||
next_state = {"state_value": torch.tensor([[signature]]).float()}
|
||||
done = True
|
||||
else:
|
||||
next_signature = float(i + 1) / 100.0
|
||||
next_state = {"state_value": torch.tensor([[next_signature]]).float()}
|
||||
done = False
|
||||
|
||||
buffer.add(state, action, reward, next_state, done, False)
|
||||
|
||||
# Sample and verify
|
||||
batch = buffer.sample(50)
|
||||
|
||||
for i in range(50):
|
||||
state_sig = batch["state"]["state_value"][i].item()
|
||||
action_val = batch["action"][i].item()
|
||||
reward_val = batch["reward"][i].item()
|
||||
next_state_sig = batch["next_state"]["state_value"][i].item()
|
||||
is_done = batch["done"][i].item() > 0.5
|
||||
|
||||
# Verify relationships
|
||||
assert abs(action_val - 2.0 * state_sig) < 1e-4, (
|
||||
f"Action {action_val} should be 2x state signature {state_sig}"
|
||||
)
|
||||
|
||||
assert abs(reward_val - 3.0 * state_sig) < 1e-4, (
|
||||
f"Reward {reward_val} should be 3x state signature {state_sig}"
|
||||
)
|
||||
|
||||
if is_done:
|
||||
assert abs(next_state_sig - state_sig) < 1e-4, (
|
||||
f"For done states, next_state {next_state_sig} should equal state {state_sig}"
|
||||
)
|
||||
else:
|
||||
# Either it's the next sequential state (+0.01) or same state (for episode boundaries)
|
||||
valid_next = (
|
||||
abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
|
||||
)
|
||||
assert valid_next, (
|
||||
f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}"
|
||||
)
|
||||
|
||||
|
||||
def test_memory_optimization():
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer = create_empty_replay_buffer()
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
|
||||
|
||||
optimized_replay_buffer = create_empty_replay_buffer(True)
|
||||
optimized_replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
|
||||
optimized_replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
|
||||
optimized_replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
|
||||
optimized_replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, None, True, True)
|
||||
|
||||
assert get_object_memory(optimized_replay_buffer) < get_object_memory(replay_buffer), (
|
||||
"Optimized replay buffer should be smaller than the original replay buffer"
|
||||
)
|
||||
|
||||
|
||||
def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action):
|
||||
def dummy_image_augmentation_function(x):
|
||||
return torch.ones_like(x) * 10
|
||||
|
||||
replay_buffer = create_empty_replay_buffer(
|
||||
use_drq=True, image_augmentation_function=dummy_image_augmentation_function
|
||||
)
|
||||
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
|
||||
|
||||
sampled_transitions = replay_buffer.sample(1)
|
||||
assert torch.all(sampled_transitions["state"]["observation.image"] == 10), (
|
||||
"Image augmentations should be applied"
|
||||
)
|
||||
assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), (
|
||||
"Image augmentations should be applied"
|
||||
)
|
||||
|
||||
|
||||
def test_check_image_augmentations_with_drq_and_default_image_augmentation_function(
|
||||
dummy_state, dummy_action
|
||||
):
|
||||
replay_buffer = create_empty_replay_buffer(use_drq=True)
|
||||
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
|
||||
|
||||
# Let's check that it doesn't fail and shapes are correct
|
||||
sampled_transitions = replay_buffer.sample(1)
|
||||
assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84)
|
||||
assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84)
|
||||
|
||||
|
||||
def test_random_crop_vectorized_basic():
|
||||
# Create a batch of 2 images with known patterns
|
||||
batch_size, channels, height, width = 2, 3, 10, 8
|
||||
images = torch.zeros((batch_size, channels, height, width))
|
||||
|
||||
# Fill with unique values for testing
|
||||
for b in range(batch_size):
|
||||
images[b] = b + 1
|
||||
|
||||
crop_size = (6, 4) # Smaller than original
|
||||
cropped = random_crop_vectorized(images, crop_size)
|
||||
|
||||
# Check output shape
|
||||
assert cropped.shape == (batch_size, channels, *crop_size)
|
||||
|
||||
# Check that values are preserved (should be either 1s or 2s for respective batches)
|
||||
assert torch.all(cropped[0] == 1)
|
||||
assert torch.all(cropped[1] == 2)
|
||||
|
||||
|
||||
def test_random_crop_vectorized_invalid_size():
|
||||
images = torch.zeros((2, 3, 10, 8))
|
||||
|
||||
# Test crop size larger than image
|
||||
with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
|
||||
random_crop_vectorized(images, (12, 8))
|
||||
|
||||
with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
|
||||
random_crop_vectorized(images, (10, 10))
|
||||
Reference in New Issue
Block a user