Compare commits

..

13 Commits

Author SHA1 Message Date
pre-commit-ci[bot]
8d20ee7655 [pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.11.11 → v0.11.12](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.11...v0.11.12)
- [github.com/gitleaks/gitleaks: v8.26.0 → v8.27.0](https://github.com/gitleaks/gitleaks/compare/v8.26.0...v8.27.0)
- [github.com/woodruffw/zizmor-pre-commit: v1.8.0 → v1.9.0](https://github.com/woodruffw/zizmor-pre-commit/compare/v1.8.0...v1.9.0)
2025-06-02 18:33:11 +00:00
pre-commit-ci[bot]
1537d0ab90 [pre-commit.ci] pre-commit autoupdate (#1048)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
2025-06-02 19:30:39 +02:00
Adil Zouitine
2be7f3a3ff (hotfix): nightly CI by clipping pymunk version below 7.0.0 (#1182) 2025-06-02 13:18:02 +02:00
Adil Zouitine
0cf864870c [Fix] Unpin torch beyond 2.6.0 & torchcodec beyond 0.2.1 (#1127) 2025-05-28 16:54:20 +02:00
mshukor
1786916a16 Update README.md (#1163) 2025-05-27 11:50:43 +02:00
mshukor
0507ad4f68 Update README.md (#1160) 2025-05-27 11:45:07 +02:00
Ragnar
bed90e3a41 fix: typos and grammar (#1148) 2025-05-25 17:20:45 +02:00
Francesco Capuano
6163daaaa4 Fix: emptying action queue between resets (#1117) 2025-05-22 21:37:21 +02:00
Pepijn
8e2a394442 Add editable -e for feetech install command (#1133) 2025-05-20 18:51:21 +02:00
masato-ka
a445d9c9da bug fix for #1071 When --display_data=true, Failed running control_robot. (#1073) 2025-05-09 16:53:40 +02:00
CharlesCNorton
f24030d4d8 Update 12_use_so101.md (#1081)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2025-05-09 11:04:25 +02:00
Mishig
7598aeaad7 Update 10_use_so100.md; use diff syntax (#944)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2025-05-09 11:01:12 +02:00
Pepijn
4485cc0b5b docs: minor corrections and clean-up (#1089) 2025-05-09 11:00:25 +02:00
71 changed files with 143 additions and 11658 deletions

View File

@@ -40,24 +40,24 @@ jobs:
git lfs install
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
lfs: true
persist-credentials: false
- name: Login to DockerHub
uses: docker/login-action@v3
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push CPU
uses: docker/build-push-action@v5
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
with:
context: .
file: ./docker/lerobot-cpu/Dockerfile
@@ -78,24 +78,24 @@ jobs:
git lfs install
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
lfs: true
persist-credentials: false
- name: Login to DockerHub
uses: docker/login-action@v3
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
uses: docker/build-push-action@v5
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
with:
context: .
file: ./docker/lerobot-gpu/Dockerfile
@@ -110,23 +110,23 @@ jobs:
group: aws-general-8-plus
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Login to DockerHub
uses: docker/login-action@v3
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU dev
uses: docker/build-push-action@v5
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
with:
context: .
file: ./docker/lerobot-gpu-dev/Dockerfile

View File

@@ -33,7 +33,7 @@ jobs:
runs-on:
group: aws-general-8-plus
container:
image: huggingface/lerobot-cpu:latest
image: huggingface/lerobot-cpu:latest # zizmor: ignore[unpinned-images]
options: --shm-size "16gb"
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -60,7 +60,7 @@ jobs:
CUDA_VISIBLE_DEVICES: "0"
TEST_TYPE: "single_gpu"
container:
image: huggingface/lerobot-gpu:latest
image: huggingface/lerobot-gpu:latest # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}

View File

@@ -33,12 +33,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1
with:
python-version: ${{ env.PYTHON_VERSION }}
@@ -64,9 +64,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: typos-action
uses: crate-ci/typos@v1.29.10
uses: crate-ci/typos@db35ee91e80fbb447f33b0e5fbddb24d2a1a884f # v1.29.10

View File

@@ -35,7 +35,7 @@ jobs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
@@ -64,17 +64,17 @@ jobs:
docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Build Docker image
uses: docker/build-push-action@v5
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
with:
file: ${{ matrix.docker-file }}
context: .

View File

@@ -50,7 +50,7 @@ jobs:
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
lfs: true # Ensure LFS files are pulled
persist-credentials: false
@@ -62,7 +62,7 @@ jobs:
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
- name: Install uv and python
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
with:
enable-cache: true
version: ${{ env.UV_VERSION }}
@@ -85,7 +85,7 @@ jobs:
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
lfs: true # Ensure LFS files are pulled
persist-credentials: false
@@ -94,7 +94,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y ffmpeg
- name: Install uv and python
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
with:
enable-cache: true
version: ${{ env.UV_VERSION }}
@@ -117,7 +117,7 @@ jobs:
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
lfs: true # Ensure LFS files are pulled
persist-credentials: false
@@ -129,7 +129,7 @@ jobs:
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
- name: Install uv and python
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
with:
enable-cache: true
version: ${{ env.UV_VERSION }}

View File

@@ -24,12 +24,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
persist-credentials: false
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main
uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35
with:
extra_args: --only-verified

1
.gitignore vendored
View File

@@ -26,7 +26,6 @@ outputs
# VS Code
.vscode
.devcontainer
# HPC
nautilus/*.yaml

View File

@@ -37,19 +37,18 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/adhtruong/mirrors-typos
rev: v1.31.1
rev: v1.32.0
hooks:
- id: typos
args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
rev: v3.20.0
hooks:
- id: pyupgrade
# Exclude generated protobuf files
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.5
rev: v0.11.12
hooks:
- id: ruff
args: [--fix]
@@ -58,12 +57,12 @@ repos:
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.3
rev: v8.27.0
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.5.2
rev: v1.9.0
hooks:
- id: zizmor

View File

@@ -360,7 +360,7 @@ with profile(
If you want, you can cite this work with:
```bibtex
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}
@@ -408,19 +408,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
year={2024}
}
```
- [HIL-SERL](https://hil-serl.github.io/)
```bibtex
@Article{luo2024hilserl,
title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine},
year={2024},
eprint={2410.21845},
archivePrefix={arXiv},
primaryClass={cs.RO}
}
```
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline)

View File

@@ -96,8 +96,8 @@ Reconnect the usb cable.
#### Update config file
Now that you have your ports, update the **port** default values of [`SO101RobotConfig`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robot_devices/robots/configs.py).
You will find something like, update the `port` values with your actual motor ports:
```python
You will find a class called `so101` where you can update the `port` values with your actual motor ports:
```diff
@RobotConfig.register_subclass("so101")
@dataclass
class So101RobotConfig(ManipulatorRobotConfig):
@@ -110,7 +110,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
- port="/dev/tty.usbmodem58760431091",
+ port="{ADD YOUR LEADER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -127,7 +128,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
- port="/dev/tty.usbmodem585A0076891",
+ port="{ADD YOUR FOLLOWER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -297,7 +299,7 @@ Remove all support material from the 3D-printed parts, the easiest way to do thi
##### Wiring
- Attach the motor controller on the back.
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themself and stay in place.
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themselves and stay in place.
<div class="video-container">
<video controls width="600">

View File

@@ -83,7 +83,7 @@ camera_01_frame_000047.png
Note: Some cameras may take a few seconds to warm up, and the first frame might be black or green.
Now that you have the camera indexes, you should specify the camera's in the config. TODO(pepijn): add more info about setting camera config, rotate etc..
Now that you have the camera indexes, you should specify the camera's in the config.
### Use your phone
<hfoptions id="use phone">
@@ -152,7 +152,7 @@ If everything is set up correctly, you can proceed with the rest of the tutorial
## Teleoperate with cameras
We can now teleoperate again while at the same time visualizing the camera's and joint positions with `rerun`.
We can now teleoperate again while at the same time visualizing the cameras and joint positions with `rerun`.
```bash
python lerobot/scripts/control_robot.py \

View File

@@ -55,7 +55,7 @@ conda install ffmpeg -c conda-forge
Install 🤗 LeRobot:
```bash
cd lerobot && pip install ".[feetech]"
cd lerobot && pip install -e ".[feetech]"
```
## Troubleshooting

View File

@@ -128,7 +128,7 @@ sudo chmod 666 /dev/ttyACM1
#### d. Update config file
IMPORTANTLY: Now that you have your ports, update the **port** default values of [`SO100RobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
```python
```diff
@RobotConfig.register_subclass("so100")
@dataclass
class So100RobotConfig(ManipulatorRobotConfig):
@@ -141,7 +141,8 @@ class So100RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
- port="/dev/tty.usbmodem58760431091",
+ port="{ADD YOUR LEADER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -158,7 +159,8 @@ class So100RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
- port="/dev/tty.usbmodem585A0076891",
+ port="{ADD YOUR FOLLOWER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],

View File

@@ -141,7 +141,7 @@ python lerobot/scripts/configure_motor.py \
--ID 1
```
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
Note: These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
Then unplug your motor and plug the second motor and set its ID to 2.
```bash

View File

@@ -61,7 +61,7 @@ conda install ffmpeg -c conda-forge
Install 🤗 LeRobot:
```bash
cd lerobot && pip install ".[feetech]"
cd lerobot && pip install -e ".[feetech]"
```
> [!NOTE]
@@ -145,8 +145,8 @@ Reconnect the usb cable.
#### Update config file
Now that you have your ports, update the **port** default values of [`SO101RobotConfig`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robot_devices/robots/configs.py).
You will find something a class called `so101` where you can update the `port` values with your actual motor ports:
```python
You will find a class called `so101` where you can update the `port` values with your actual motor ports:
```diff
@RobotConfig.register_subclass("so101")
@dataclass
class So101RobotConfig(ManipulatorRobotConfig):
@@ -159,7 +159,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
- port="/dev/tty.usbmodem58760431091",
+ port="{ADD YOUR LEADER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -176,7 +177,8 @@ class So101RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
- port="/dev/tty.usbmodem585A0076891",
+ port="{ADD YOUR FOLLOWER PORT}",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -308,7 +310,7 @@ Remove all support material from the 3D-printed parts.
##### Wiring
- Attach the motor controller on the back.
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themself and stay in place.
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themselves and stay in place.
<video controls width="640" src="https://github.com/user-attachments/assets/4c2cacfd-9276-4ee4-8bf2-ba2492667b78" type="video/mp4"></video>
@@ -351,7 +353,7 @@ python lerobot/scripts/control_robot.py \
```
## Control your robot
Congrats 🎉, your robot is all set to learn a task on its own. Next we will explain you how to train a neural network to autonomously control a real robot.
Congrats 🎉, your robot is all set to learn a task on its own. Next we will explain to you how to train a neural network to autonomously control a real robot.
**You'll learn to:**
1. How to record and visualize your dataset.
@@ -515,7 +517,7 @@ If you have an additional camera you can add a wrist camera to the SO101. There
## Teleoperate with cameras
We can now teleoperate again while at the same time visualizing the camera's and joint positions with `rerun`.
We can now teleoperate again while at the same time visualizing the cameras and joint positions with `rerun`.
```bash
python lerobot/scripts/control_robot.py \
@@ -526,7 +528,7 @@ python lerobot/scripts/control_robot.py \
## Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
Once you're familiar with teleoperation, you can record your first dataset with SO-101.
We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens).

View File

@@ -106,7 +106,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
class AsyncImageWriter:
"""
This class abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
save images on disk asynchronously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it creates a threads pool of size `num_threads`.

View File

@@ -944,7 +944,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def stop_image_writer(self) -> None:
"""
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized.
"""
if self.image_writer is not None:
self.image_writer.stop()

View File

@@ -101,7 +101,7 @@ def decode_video_frames_torchvision(
keyframes_only = False
torchvision.set_video_backend(backend)
if backend == "pyav":
keyframes_only = True # pyav doesnt support accuracte seek
keyframes_only = True # pyav doesn't support accurate seek
# set a video stream reader
# TODO(rcadene): also load audio stream at the same time

View File

@@ -14,12 +14,10 @@
import abc
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import draccus
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -156,122 +154,3 @@ class XarmEnv(EnvConfig):
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}
@dataclass
class VideoRecordConfig:
"""Configuration for video recording in ManiSkill environments."""
enabled: bool = False
record_dir: str = "videos"
trajectory_name: str = "trajectory"
@dataclass
class EEActionSpaceConfig:
"""Configuration parameters for end-effector action space."""
x_step_size: float
y_step_size: float
z_step_size: float
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
control_mode: str = "gamepad"
@dataclass
class EnvTransformConfig:
"""Configuration for environment wrappers."""
ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
display_cameras: bool = False
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
resize_size: Optional[Tuple[int, int]] = None
control_time_s: float = 20.0
fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0
use_gripper: bool = False
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):
"""Configuration for the HILSerlRobotEnv environment."""
robot: Optional[RobotConfig] = None
wrapper: Optional[EnvTransformConfig] = None
fps: int = 10
name: str = "real_robot"
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
task: str = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
reward_classifier_pretrained_path: Optional[str] = None
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
def gym_kwargs(self) -> dict:
return {}
@EnvConfig.register_subclass("hil")
@dataclass
class HILEnvConfig(EnvConfig):
"""Configuration for the HIL environment."""
type: str = "hil"
name: str = "PandaPickCube"
task: str = "PandaPickCubeKeyboard-v0"
use_viewer: bool = True
gripper_penalty: float = 0.0
use_gamepad: bool = True
state_dim: int = 18
action_dim: int = 4
fps: int = 100
episode_length: int = 100
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"observation.image": OBS_IMAGE,
"observation.state": OBS_ROBOT,
}
)
################# args from hilserlrobotenv
reward_classifier_pretrained_path: Optional[str] = None
robot: Optional[RobotConfig] = None
wrapper: Optional[EnvTransformConfig] = None
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
############################
@property
def gym_kwargs(self) -> dict:
return {
"use_viewer": self.use_viewer,
"use_gamepad": self.use_gamepad,
"gripper_penalty": self.gripper_penalty,
}

View File

@@ -17,7 +17,7 @@ import importlib
import gymnasium as gym
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
elif env_type == "hil":
return HILEnvConfig(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
@@ -67,8 +65,5 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
env = env_cls(
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase
# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/vector/vectorize_observation.py#L19,
# env = ObservationProcessorWrapper(env=env)
return env

View File

@@ -47,10 +47,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img)
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
# This is the case for human-in-the-loop RL where there is only one environment.
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
@@ -66,18 +62,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[imgkey] = img
if "environment_state" in observations:
env_state = torch.from_numpy(observations["environment_state"]).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
return_observations["observation.environment_state"] = env_state
return_observations["observation.environment_state"] = torch.from_numpy(
observations["environment_state"]
).float()
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations["observation.state"] = agent_pos
# requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return return_observations

View File

@@ -14,9 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
import draccus
import torch
@@ -45,16 +44,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
return "adam"
@abc.abstractmethod
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
"""
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
NOTE: Multiple optimizers are useful when you have different models to optimize.
For example, you can have one optimizer for the policy and another one for the value function
in reinforcement learning settings.
Returns:
The optimizer or a dictionary of optimizers.
"""
def build(self) -> torch.optim.Optimizer:
raise NotImplementedError
@@ -104,76 +94,7 @@ class SGDConfig(OptimizerConfig):
return torch.optim.SGD(params, **kwargs)
@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):
"""Configuration for multiple Adam optimizers with different parameter groups.
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
Args:
lr: Default learning rate (used if not specified for a group)
weight_decay: Default weight decay (used if not specified for a group)
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
grad_clip_norm: Gradient clipping norm
"""
lr: float = 1e-3
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
"""Build multiple Adam optimizers.
Args:
params_dict: Dictionary mapping parameter group names to lists of parameters
The keys should match the keys in optimizer_groups
Returns:
Dictionary mapping parameter group names to their optimizers
"""
optimizers = {}
for name, params in params_dict.items():
# Get group-specific hyperparameters or use defaults
group_config = self.optimizer_groups.get(name, {})
# Create optimizer with merged parameters (defaults + group-specific)
optimizer_kwargs = {
"lr": group_config.get("lr", self.lr),
"betas": group_config.get("betas", (0.9, 0.999)),
"eps": group_config.get("eps", 1e-5),
"weight_decay": group_config.get("weight_decay", self.weight_decay),
}
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
return optimizers
def save_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> None:
"""Save optimizer state to disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to save the optimizer state.
"""
if isinstance(optimizer, dict):
# Handle dictionary of optimizers
for name, opt in optimizer.items():
optimizer_dir = save_dir / name
optimizer_dir.mkdir(exist_ok=True, parents=True)
_save_single_optimizer_state(opt, optimizer_dir)
else:
# Handle single optimizer
_save_single_optimizer_state(optimizer, save_dir)
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
"""Save a single optimizer's state to disk."""
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
state = optimizer.state_dict()
param_groups = state.pop("param_groups")
flat_state = flatten_dict(state)
@@ -181,44 +102,11 @@ def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
def load_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
"""Load optimizer state from disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to load the optimizer state from.
Returns:
The updated optimizer(s) with loaded state.
"""
if isinstance(optimizer, dict):
# Handle dictionary of optimizers
loaded_optimizers = {}
for name, opt in optimizer.items():
optimizer_dir = save_dir / name
if optimizer_dir.exists():
loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
else:
loaded_optimizers[name] = opt
return loaded_optimizers
else:
# Handle single optimizer
return _load_single_optimizer_state(optimizer, save_dir)
def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
"""Load a single optimizer's state from disk."""
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
# Handle case where 'state' key might not exist (for newly created optimizers)
if "state" in state:
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
else:
loaded_state_dict = {"state": {}}
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
if "param_groups" in current_state_dict:
param_groups = deserialize_json_into_object(

View File

@@ -27,7 +27,6 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.configs.policies import PreTrainedConfig
@@ -60,14 +59,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "sac":
from lerobot.common.policies.sac.modeling_sac import SACPolicy
return SACPolicy
elif name == "reward_classifier":
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
return Classifier
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -85,8 +76,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")

View File

@@ -151,7 +151,6 @@ class Normalize(nn.Module):
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# TODO: Remove this shallow copy
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
@@ -253,168 +252,3 @@ class Unnormalize(nn.Module):
else:
raise ValueError(norm_mode)
return batch
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
module: nn.Module,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> None:
"""Register statistics buffers (mean/std or min/max) on the given *module*.
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
but is factored out so it can be reused by both classes and stay in sync.
"""
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
shape: tuple[int, ...] = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# reduce spatial dimensions, keep channel dimension only
c, *_ = shape
shape = (c, 1, 1)
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.full(shape, torch.inf, dtype=torch.float32)
std = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
mean_data = stats[key]["mean"]
std_data = stats[key]["std"]
if isinstance(mean_data, torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
mean = mean_data.clone().to(dtype=torch.float32)
std = std_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_mean", mean)
module.register_buffer(f"{prefix}_std", std)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
min_data = stats[key]["min"]
max_data = stats[key]["max"]
if isinstance(min_data, torch.Tensor):
min_val = min_data.clone().to(dtype=torch.float32)
max_val = max_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_min", min_val)
module.register_buffer(f"{prefix}_max", max_val)
continue
raise ValueError(norm_mode)
class NormalizeBuffer(nn.Module):
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
batch[key] = batch[key] * 2 - 1
continue
raise ValueError(norm_mode)
return batch
class UnnormalizeBuffer(nn.Module):
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max_val - min_val) + min_val
continue
raise ValueError(norm_mode)
return batch

View File

@@ -357,7 +357,7 @@ class PI0Policy(PreTrainedPolicy):
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]

View File

@@ -516,7 +516,7 @@ class PI0FAST(nn.Module):
interpolate_like_pi=self.config.interpolate_like_pi,
)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]

View File

@@ -1,62 +0,0 @@
from dataclasses import dataclass, field
from typing import List
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass(name="reward_classifier")
@dataclass
class RewardClassifierConfig(PreTrainedConfig):
"""Configuration for the Reward Classifier model."""
name: str = "reward_classifier"
num_classes: int = 2
hidden_dim: int = 256
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
learning_rate: float = 1e-4
weight_decay: float = 0.01
grad_clip_norm: float = 1.0
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
}
)
@property
def observation_delta_indices(self) -> List | None:
return None
@property
def action_delta_indices(self) -> List | None:
return None
@property
def reward_delta_indices(self) -> List | None:
return None
def get_optimizer_preset(self) -> OptimizerConfig:
return AdamWConfig(
lr=self.learning_rate,
weight_decay=self.weight_decay,
grad_clip_norm=self.grad_clip_norm,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return None
def validate_features(self) -> None:
"""Validate feature configurations."""
has_image = any(key.startswith("observation.image") for key in self.input_features)
if not has_image:
raise ValueError(
"You must provide an image observation (key starting with 'observation.image') in the input features"
)

View File

@@ -1,301 +0,0 @@
import logging
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
from lerobot.common.constants import OBS_IMAGE
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self,
logits: Tensor,
probabilities: Optional[Tensor] = None,
hidden_states: Optional[Tensor] = None,
):
self.logits = logits
self.probabilities = probabilities
self.hidden_states = hidden_states
def __repr__(self):
return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})"
)
class SpatialLearnedEmbeddings(nn.Module):
def __init__(self, height, width, channel, num_features=8):
"""
PyTorch implementation of learned spatial embeddings
Args:
height: Spatial height of input features
width: Spatial width of input features
channel: Number of input channels
num_features: Number of output embedding dimensions
"""
super().__init__()
self.height = height
self.width = width
self.channel = channel
self.num_features = num_features
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
def forward(self, features):
"""
Forward pass for spatial embedding
Args:
features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch
Returns:
Output tensor of shape [B, C*F] or [C*F] if no batch
"""
features = features.last_hidden_state
original_shape = features.shape
if features.dim() == 3:
features = features.unsqueeze(0) # Add batch dim
features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1]
kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F]
# Element-wise multiplication and spatial reduction
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W
# Reshape to combine channel and feature dimensions
output = output.view(output.size(0), -1) # [B, C*F]
# Remove batch dim
if len(original_shape) == 3:
output = output.squeeze(0)
return output
class Classifier(PreTrainedPolicy):
"""Image classifier built on top of a pre-trained encoder."""
name = "reward_classifier"
config_class = RewardClassifierConfig
def __init__(
self,
config: RewardClassifierConfig,
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
):
from transformers import AutoModel
super().__init__(config)
self.config = config
# Initialize normalization (standardized with the policy framework)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# Set up encoder
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
self.encoder = encoder.vision_model
self.vision_config = encoder.config.vision_config
else:
self.encoder = encoder
self.vision_config = getattr(encoder, "config", None)
# Model type from config
self.is_cnn = self.config.model_type == "cnn"
# For CNNs, initialize backbone
if self.is_cnn:
self._setup_cnn_backbone()
self._freeze_encoder()
# Extract image keys from input_features
self.image_keys = [
key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE)
]
if self.is_cnn:
self.encoders = nn.ModuleDict()
for image_key in self.image_keys:
encoder = self._create_single_encoder()
self.encoders[image_key] = encoder
self._build_classifier_head()
def _setup_cnn_backbone(self):
"""Set up CNN encoder"""
if hasattr(self.encoder, "fc"):
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
for param in self.encoder.parameters():
param.requires_grad = False
def _create_single_encoder(self):
encoder = nn.Sequential(
self.encoder,
SpatialLearnedEmbeddings(
height=4,
width=4,
channel=self.feature_dim,
num_features=self.config.image_embedding_pooling_dim,
),
nn.Dropout(self.config.dropout_rate),
nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
nn.Tanh(),
)
return encoder
def _build_classifier_head(self) -> None:
"""Initialize the classifier head architecture."""
# Get input dimension based on model type
if self.is_cnn:
input_dim = self.config.latent_dim
else: # Transformer models
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(
self.config.hidden_dim,
1 if self.config.num_classes == 2 else self.config.num_classes,
),
)
def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""
with torch.no_grad():
if self.is_cnn:
# The HF ResNet applies pooling internally
outputs = self.encoders[image_key](x)
return outputs
else: # Transformer models
outputs = self.encoder(x)
return outputs.last_hidden_state[:, 0, :]
def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
"""Extract image tensors and label tensors from batch."""
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
labels = batch["next.reward"]
return images, labels
def predict(self, xs: list) -> ClassifierOutput:
"""Forward pass of the classifier for inference."""
encoder_outputs = torch.hstack(
[self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)]
)
logits = self.classifier_head(encoder_outputs)
if self.config.num_classes == 2:
logits = logits.squeeze(-1)
probabilities = torch.sigmoid(logits)
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
"""Standard forward pass for training compatible with train.py."""
# Normalize inputs if needed
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images and labels
images, labels = self.extract_images_and_labels(batch)
# Get predictions
outputs = self.predict(images)
# Calculate loss
if self.config.num_classes == 2:
# Binary classification
loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
else:
# Multi-class classification
loss = nn.functional.cross_entropy(outputs.logits, labels.long())
predictions = torch.argmax(outputs.logits, dim=1)
# Calculate accuracy for logging
correct = (predictions == labels).sum().item()
total = labels.size(0)
accuracy = 100 * correct / total
# Return loss and metrics for logging
output_dict = {
"accuracy": accuracy,
"correct": correct,
"total": total,
}
return loss, output_dict
def predict_reward(self, batch, threshold=0.5):
"""Eval method. Returns predicted reward with the decision threshold as argument."""
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images from batch dict
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
if self.config.num_classes == 2:
probs = self.predict(images).probabilities
logging.debug(f"Predicted reward images: {probs}")
return (probs > threshold).float()
else:
return torch.argmax(self.predict(images).probabilities, dim=1)
def get_optim_params(self):
"""Return optimizer parameters for the policy."""
return self.parameters()
def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not select actions.
"""
raise NotImplementedError("Reward classifiers do not select actions")
def reset(self):
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not select actions.
"""
pass

View File

@@ -1,243 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
def is_image_feature(key: str) -> bool:
"""Check if a feature key represents an image feature.
Args:
key: The feature key to check
Returns:
True if the key represents an image feature, False otherwise
"""
return key.startswith("observation.image")
@dataclass
class ConcurrencyConfig:
"""Configuration for the concurrency of the actor and learner.
Possible values are:
- "threads": Use threads for the actor and learner.
- "processes": Use processes for the actor and learner.
"""
actor: str = "threads"
learner: str = "threads"
@dataclass
class ActorLearnerConfig:
learner_host: str = "127.0.0.1"
learner_port: int = 50051
policy_parameters_push_frequency: int = 4
@dataclass
class CriticNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
final_activation: str | None = None
@dataclass
class ActorNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
@dataclass
class PolicyConfig:
use_tanh_squash: bool = True
log_std_min: float = 1e-5
log_std_max: float = 10.0
init_final: float = 0.05
@PreTrainedConfig.register_subclass("sac")
@dataclass
class SACConfig(PreTrainedConfig):
"""Soft Actor-Critic (SAC) configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
reinforcement learning framework. It learns a policy and a Q-function simultaneously
using experience collected from the environment.
This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific
hyperparameters.
"""
# Mapping of feature types to normalization modes
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ENV": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Statistics for normalizing different types of inputs
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
"observation.image": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
}
)
# Architecture specifics
# Device to run the model on (e.g., "cuda", "cpu")
device: str = "cpu"
# Device to store the model on
storage_device: str = "cpu"
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
vision_encoder_name: str | None = None
# Whether to freeze the vision encoder during training
freeze_vision_encoder: bool = True
# Hidden dimension size for the image encoder
image_encoder_hidden_dim: int = 32
# Whether to use a shared encoder for actor and critic
shared_encoder: bool = True
# Number of discrete actions, eg for gripper actions
num_discrete_actions: int | None = None
# Dimension of the image embedding pooling
image_embedding_pooling_dim: int = 8
# Training parameter
# Number of steps for online training
online_steps: int = 1000000
# Seed for the online environment
online_env_seed: int = 10000
# Capacity of the online replay buffer
online_buffer_capacity: int = 100000
# Capacity of the offline replay buffer
offline_buffer_capacity: int = 100000
# Whether to use asynchronous prefetching for the buffers
async_prefetch: bool = False
# Number of steps before learning starts
online_step_before_learning: int = 100
# Frequency of policy updates
policy_update_freq: int = 1
# SAC algorithm parameters
# Discount factor for the SAC algorithm
discount: float = 0.99
# Initial temperature value
temperature_init: float = 1.0
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Weight for the critic target update
critic_target_update_weight: float = 0.005
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
utd_ratio: int = 1
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256
# Target entropy for the SAC algorithm
target_entropy: float | None = None
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Network configuration
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for actor-learner architecture
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations
use_torch_compile: bool = True
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration
def get_optimizer_preset(self) -> MultiAdamConfig:
return MultiAdamConfig(
weight_decay=0.0,
optimizer_groups={
"actor": {"lr": self.actor_lr},
"critic": {"lr": self.critic_lr},
"temperature": {"lr": self.temperature_lr},
},
)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
has_image = any(is_image_feature(key) for key in self.input_features)
has_state = "observation.state" in self.input_features
if not (has_state or has_image):
raise ValueError(
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
)
if "action" not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
def image_features(self) -> list[str]:
return [key for key in self.input_features if is_image_feature(key)]
@property
def observation_delta_indices(self) -> list:
return None
@property
def action_delta_indices(self) -> list:
return None # SAC typically predicts one action at a time
@property
def reward_delta_indices(self) -> None:
return None

File diff suppressed because it is too large Load Diff

View File

@@ -87,8 +87,6 @@ class RecordControlConfig(ControlConfig):
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Reset follower arms to an initial position.
reset_follower_arms: bool = False
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.

View File

@@ -24,7 +24,6 @@ from contextlib import nullcontext
from copy import copy
from functools import cache
import numpy as np
import rerun as rr
import torch
from deepdiff import DeepDiff
@@ -130,11 +129,9 @@ def predict_action(observation, policy, device, use_amp):
def init_keyboard_listener():
"""
Initializes a keyboard listener to enable early termination of an episode
or environment reset by pressing the right arrow key ('->'). This may require
sudo permissions to allow the terminal to monitor keyboard events.
"""
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
@@ -163,7 +160,6 @@ def init_keyboard_listener():
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Error handling key press: {e}")
@@ -247,6 +243,11 @@ def control_loop(
timestamp = 0
start_episode_t = time.perf_counter()
# Controls starts, if policy is given it needs cleaning up
if policy is not None:
policy.reset()
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
@@ -254,13 +255,11 @@ def control_loop(
observation, action = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
action = None
if policy is not None:
pred_action = predict_action(
observation,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
@@ -273,9 +272,10 @@ def control_loop(
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
for k, v in action.items():
for i, vv in enumerate(v):
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
if action is not None:
for k, v in action.items():
for i, vv in enumerate(v):
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
@@ -308,17 +308,7 @@ def reset_environment(robot, events, reset_time_s, fps):
)
def reset_follower_position(robot_arm, target_position):
current_position = robot_arm.read("Present_Position")
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 50)
) # NOTE: 30 is just an arbitrary number
for pose in trajectory:
robot_arm.write("Goal_Position", pose)
busy_wait(0.015)
def stop_recording(robot, listener, display_cameras):
def stop_recording(robot, listener, display_data):
robot.disconnect()
if not is_headless() and listener is not None:
@@ -344,20 +334,12 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset,
robot: Robot,
fps: int,
use_videos: bool,
extra_features: dict = None,
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
) -> None:
features_from_robot = get_features_from_robot(robot, use_videos)
if extra_features is not None:
features_from_robot.update(extra_features)
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, features_from_robot),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
]
mismatches = []

View File

@@ -18,11 +18,9 @@ import os
import os.path as osp
import platform
import subprocess
import time
from copy import copy, deepcopy
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean
import numpy as np
import torch
@@ -109,17 +107,11 @@ def is_amp_available(device: str):
raise ValueError(f"Unknown device '{device}.")
def init_logging(log_file: Path | None = None, display_pid: bool = False):
def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
# NOTE: Display PID is useful for multi-process logging.
if display_pid:
pid_str = f"[PID: {os.getpid()}]"
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
else:
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
return message
logging.basicConfig(level=logging.INFO)
@@ -133,12 +125,6 @@ def init_logging(log_file: Path | None = None, display_pid: bool = False):
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)
if log_file is not None:
# Additionally write logs to file
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
logging.getLogger().addHandler(file_handler)
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]
@@ -242,114 +228,3 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
except TypeError:
# If a TypeError is raised, the string is not a valid dtype
return False
class TimerManager:
"""
Lightweight utility to measure elapsed time.
Examples
--------
```python
# Example 1: Using context manager
timer = TimerManager("Policy", log=False)
for _ in range(3):
with timer:
time.sleep(0.01)
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
```
```python
# Example 2: Using start/stop methods
timer = TimerManager("Policy", log=False)
timer.start()
time.sleep(0.01)
timer.stop()
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
```
"""
def __init__(
self,
label: str = "Elapsed-time",
log: bool = True,
logger: logging.Logger | None = None,
):
self.label = label
self.log = log
self.logger = logger
self._start: float | None = None
self._history: list[float] = []
def __enter__(self):
return self.start()
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def start(self):
self._start = time.perf_counter()
return self
def stop(self) -> float:
if self._start is None:
raise RuntimeError("Timer was never started.")
elapsed = time.perf_counter() - self._start
self._history.append(elapsed)
self._start = None
if self.log:
if self.logger is not None:
self.logger.info(f"{self.label}: {elapsed:.6f} s")
else:
logging.info(f"{self.label}: {elapsed:.6f} s")
return elapsed
def reset(self):
self._history.clear()
@property
def last(self) -> float:
return self._history[-1] if self._history else 0.0
@property
def avg(self) -> float:
return mean(self._history) if self._history else 0.0
@property
def total(self) -> float:
return sum(self._history)
@property
def count(self) -> int:
return len(self._history)
@property
def history(self) -> list[float]:
return deepcopy(self._history)
@property
def fps_history(self) -> list[float]:
return [1.0 / t for t in self._history]
@property
def fps_last(self) -> float:
return 0.0 if self.last == 0 else 1.0 / self.last
@property
def fps_avg(self) -> float:
return 0.0 if self.avg == 0 else 1.0 / self.avg
def percentile(self, p: float) -> float:
"""
Return the p-th percentile of recorded times.
"""
if not self._history:
return 0.0
return float(np.percentile(self._history, p))
def fps_percentile(self, p: float) -> float:
"""
FPS corresponding to the p-th percentile time.
"""
val = self.percentile(p)
return 0.0 if val == 0 else 1.0 / val

View File

@@ -30,10 +30,9 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
f"dataset:{cfg.dataset.repo_id}",
f"seed:{cfg.seed}",
]
if cfg.dataset is not None:
lst.append(f"dataset:{cfg.dataset.repo_id}")
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
return lst if return_list else "-".join(lst)
@@ -93,12 +92,6 @@ class WandBLogger:
resume="must" if cfg.resume else None,
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
)
run_id = wandb.run.id
# NOTE: We will override the cfg.wandb.run_id with the wandb run id.
# This is because we want to be able to resume the run from the wandb run id.
cfg.wandb.run_id = run_id
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@@ -115,26 +108,9 @@ class WandBLogger:
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
def log_dict(
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
):
def log_dict(self, d: dict, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
raise ValueError(mode)
if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.")
# NOTE: This is not simple. Wandb step must always monotonically increase and it
# increases with each wandb.log call, but in the case of asynchronous RL for example,
# multiple time steps is possible. For example, the interaction step with the environment,
# the training step, the evaluation step, etc. So we need to define a custom step key
# to log the correct step for each metric.
if custom_step_key is not None:
if self._wandb_custom_step_key is None:
self._wandb_custom_step_key = set()
new_custom_key = f"{mode}/{custom_step_key}"
if new_custom_key not in self._wandb_custom_step_key:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items():
if not isinstance(v, (int, float, str)):
@@ -142,18 +118,7 @@ class WandBLogger:
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
# Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
self._wandb.log(data)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:

View File

@@ -34,10 +34,11 @@ TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
dataset: DatasetConfig
env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
output_dir: Path | None = None
job_name: str | None = None
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
@@ -106,7 +107,7 @@ class TrainPipelineConfig(HubMixin):
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir
if self.dataset is not None and isinstance(self.dataset.repo_id, list):
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):

View File

@@ -23,7 +23,6 @@ class FeatureType(str, Enum):
VISUAL = "VISUAL"
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"
class NormalizationMode(str, Enum):

View File

@@ -1,722 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Actor server runner for distributed HILSerl robot policy training.
This script implements the actor component of the distributed HILSerl architecture.
It executes the policy in the robot environment, collects experience,
and sends transitions to the learner server for policy updates.
Examples of usage:
- Start an actor server for real robot training with human-in-the-loop intervention:
```bash
python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
```
- Run with a specific robot type for a pick and place task:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--robot.type=so100 \
--task=pick_and_place
```
- Set a custom workspace bound for the robot's end-effector:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
```
- Run with specific camera crop parameters:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
```
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
server is started before launching the actor.
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
reduce interventions as the policy improves.
**WORKFLOW**:
1. Determine robot workspace bounds using `find_joint_limits.py`
2. Record demonstrations with `gym_manipulator.py` in record mode
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
4. Start the learner server with the training configuration
5. Start this actor server with the same configuration
6. Use human interventions to guide policy learning
For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
import logging
import os
import time
from functools import lru_cache
from queue import Empty
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
TimerManager,
get_safe_torch_device,
init_logging,
)
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import Transition
from lerobot.scripts.server.gym_manipulator import make_robot_env
from lerobot.scripts.server.network_utils import (
bytes_to_state_dict,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.scripts.server.utils import (
get_last_item_from_queue,
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
)
ACTOR_SHUTDOWN_TIMEOUT = 30
#################################################
# Main entry point #
#################################################
@parser.wrap()
def actor_cli(cfg: TrainPipelineConfig):
cfg.validate()
display_pid = False
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
display_pid = True
# Create logs directory to ensure it exists
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=display_pid)
logging.info(f"Actor logging initialized, writing to {log_file}")
shutdown_event = setup_process_handlers(use_threads(cfg))
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
logging.info("[ACTOR] Establishing connection with Learner")
if not establish_learner_connection(learner_client, shutdown_event):
logging.error("[ACTOR] Failed to establish connection with Learner")
return
if not use_threads(cfg):
# If we use multithreading, we can reuse the channel
grpc_channel.close()
grpc_channel = None
logging.info("[ACTOR] Connection with Learner established")
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
concurrency_entity = None
if use_threads(cfg):
from threading import Thread
concurrency_entity = Thread
else:
from multiprocessing import Process
concurrency_entity = Process
receive_policy_process = concurrency_entity(
target=receive_policy,
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process = concurrency_entity(
target=send_transitions,
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
daemon=True,
)
interactions_process = concurrency_entity(
target=send_interactions,
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process.start()
interactions_process.start()
receive_policy_process.start()
act_with_policy(
cfg=cfg,
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
transitions_queue=transitions_queue,
interactions_queue=interactions_queue,
)
logging.info("[ACTOR] Policy process joined")
logging.info("[ACTOR] Closing queues")
transitions_queue.close()
interactions_queue.close()
parameters_queue.close()
transitions_process.join()
logging.info("[ACTOR] Transitions process joined")
interactions_process.join()
logging.info("[ACTOR] Interactions process joined")
receive_policy_process.join()
logging.info("[ACTOR] Receive policy process joined")
logging.info("[ACTOR] join queues")
transitions_queue.cancel_join_thread()
interactions_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[ACTOR] queues closed")
#################################################
# Core algorithm functions #
#################################################
def act_with_policy(
cfg: TrainPipelineConfig,
shutdown_event: any, # Event,
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
):
"""
Executes policy interaction within the environment.
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
Args:
cfg: Configuration settings for the interaction process.
shutdown_event: Event to check if the process should shutdown.
parameters_queue: Queue to receive updated network parameters from the learner.
transitions_queue: Queue to send transitions to the learner.
interactions_queue: Queue to send interactions to the learner.
"""
# Initialize logging for multiprocessing
if not use_threads(cfg):
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor policy process logging initialized")
logging.info("make_env online")
online_env = make_robot_env(cfg=cfg.env)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
policy: SACPolicy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
list_transition_to_send_to_learner = []
episode_intervention = False
# Add counters for intervention rate calculation
episode_intervention_steps = 0
episode_total_steps = 0
policy_timer = TimerManager("Policy inference", log=False)
for interaction_step in range(cfg.policy.online_steps):
start_time = time.perf_counter()
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down act_with_policy")
return
if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with policy_timer:
action = policy.select_action(batch=obs)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate
episode_total_steps += 1
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
if "is_intervention" in info and info["is_intervention"]:
# NOTE: The action space for demonstration before hand is with the full action space
# but sometimes for example we want to deactivate the gripper
action = info["action_intervention"]
episode_intervention = True
# Increment intervention steps counter
episode_intervention_steps += 1
list_transition_to_send_to_learner.append(
Transition(
state=obs,
action=action,
reward=reward,
next_state=next_obs,
done=done,
truncated=truncated, # TODO: (azouitine) Handle truncation properly
complementary_info=info,
)
)
# assign obs to the next obs and continue the rollout
obs = next_obs
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
transitions=list_transition_to_send_to_learner,
transitions_queue=transitions_queue,
)
list_transition_to_send_to_learner = []
stats = get_frequency_stats(policy_timer)
policy_timer.reset()
# Calculate intervention rate
intervention_rate = 0.0
if episode_total_steps > 0:
intervention_rate = episode_intervention_steps / episode_total_steps
# Send episodic reward to the learner
interactions_queue.put(
python_object_to_bytes(
{
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
"Episode intervention": int(episode_intervention),
"Intervention rate": intervention_rate,
**stats,
}
)
)
# Reset intervention counters
sum_reward_episode = 0.0
episode_intervention = False
episode_intervention_steps = 0
episode_total_steps = 0
obs, info = online_env.reset()
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
busy_wait(1 / cfg.env.fps - dt_time)
#################################################
# Communication Functions - Group all gRPC/messaging functions #
#################################################
def establish_learner_connection(
stub: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
attempts: int = 30,
):
"""Establish a connection with the learner.
Args:
stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
shutdown_event (Event): The event to check if the connection should be established.
attempts (int): The number of attempts to establish the connection.
Returns:
bool: True if the connection is established, False otherwise.
"""
for _ in range(attempts):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down establish_learner_connection")
return False
# Force a connection attempt and check state
try:
logging.info("[ACTOR] Send ready message to Learner")
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
return True
except grpc.RpcError as e:
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
time.sleep(2)
return False
@lru_cache(maxsize=1)
def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
import json
"""
Returns a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it.
"""
service_config = {
"methodConfig": [
{
"name": [{}], # Applies to ALL methods in ALL services
"retryPolicy": {
"maxAttempts": 5, # Max retries (total attempts = 5)
"initialBackoff": "0.1s", # First retry after 0.1s
"maxBackoff": "2s", # Max wait time between retries
"backoffMultiplier": 2, # Exponential backoff factor
"retryableStatusCodes": [
"UNAVAILABLE",
"DEADLINE_EXCEEDED",
], # Retries on network failures
},
}
]
}
service_config_json = json.dumps(service_config)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.enable_retries", 1),
("grpc.service_config", service_config_json),
],
)
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
logging.info("[ACTOR] Learner service client created")
return stub, channel
def receive_policy(
cfg: TrainPipelineConfig,
parameters_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
"""Receive parameters from the learner.
Args:
cfg (TrainPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shutdown.
"""
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor receive policy process logging initialized")
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(use_threads=False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
receive_bytes_in_chunks(
iterator,
parameters_queue,
shutdown_event,
log_prefix="[ACTOR] parameters",
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Received policy loop stopped")
def send_transitions(
cfg: TrainPipelineConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
"""
Sends transitions to the learner.
This function continuously retrieves messages from the queue and processes:
- Transition Data:
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
"""
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor transitions process logging initialized")
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming transitions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Transitions process stopped")
def send_interactions(
cfg: TrainPipelineConfig,
interactions_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
"""
Sends interactions to the learner.
This function continuously retrieves messages from the queue and processes:
- Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
"""
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor interactions process logging initialized")
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming interactions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=5)
except Empty:
logging.debug("[ACTOR] Transition queue is empty")
continue
yield from send_bytes_in_chunks(
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
)
return hilserl_pb2.Empty()
def interactions_stream(
shutdown_event: Event, # type: ignore
interactions_queue: Queue,
) -> hilserl_pb2.Empty:
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=5)
except Empty:
logging.debug("[ACTOR] Interaction queue is empty")
continue
yield from send_bytes_in_chunks(
message,
hilserl_pb2.InteractionMessage,
log_prefix="[ACTOR] Send interactions",
)
return hilserl_pb2.Empty()
#################################################
# Policy functions #
#################################################
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
if not parameters_queue.empty():
logging.info("[ACTOR] Load new parameters from Learner.")
bytes_state_dict = get_last_item_from_queue(parameters_queue)
state_dict = bytes_to_state_dict(bytes_state_dict)
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.load_state_dict(state_dict)
#################################################
# Utilities functions #
#################################################
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
"""Send transitions to learner in smaller chunks to avoid network issues.
Args:
transitions: List of transitions to send
message_queue: Queue to send messages to learner
chunk_size: Size of each chunk to send
"""
transition_to_send_to_learner = []
for transition in transitions:
tr = move_transition_to_device(transition=transition, device="cpu")
for key, value in tr["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
transition_to_send_to_learner.append(tr)
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
"""Get the frequency statistics of the policy.
Args:
timer (TimerManager): The timer with collected metrics.
Returns:
dict[str, float]: The frequency statistics of the policy.
"""
stats = {}
if timer.count > 1:
avg_fps = timer.fps_avg
p90_fps = timer.fps_percentile(90)
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
stats = {
"Policy frequency [Hz]": avg_fps,
"Policy frequency 90th-p [Hz]": p90_fps,
}
return stats
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
if policy_fps < cfg.env.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
)
def use_threads(cfg: TrainPipelineConfig) -> bool:
return cfg.policy.concurrency.actor == "threads"
if __name__ == "__main__":
actor_cli()

View File

@@ -1,820 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from contextlib import suppress
from typing import Callable, Optional, Sequence, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.utils import Transition
class BatchTransition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: torch.Tensor
next_state: dict[str, torch.Tensor]
done: torch.Tensor
truncated: torch.Tensor
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
"""
Perform a per-image random crop over a batch of images in a vectorized way.
(Same as shown previously.)
"""
B, C, H, W = images.shape # noqa: N806
crop_h, crop_w = output_size
if crop_h > H or crop_w > W:
raise ValueError(
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
)
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
return cropped
def random_shift(images: torch.Tensor, pad: int = 4):
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
_, _, h, w = images.shape
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
return random_crop_vectorized(images=images, output_size=(h, w))
class ReplayBuffer:
def __init__(
self,
capacity: int,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
optimize_memory: bool = False,
):
"""
Replay buffer for storing transitions.
It will allocate tensors on the specified device, when the first transition is added.
NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or
and use the `storage_device` flag to store the buffer on a different device.
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
and returns a batch of augmented images. If None, a default augmentation function is used.
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
Using "cpu" can help save GPU memory.
optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
"""
if capacity <= 0:
raise ValueError("Capacity must be greater than 0.")
self.capacity = capacity
self.device = device
self.storage_device = storage_device
self.position = 0
self.size = 0
self.initialized = False
self.optimize_memory = optimize_memory
# Track episode boundaries for memory optimization
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
# If no state_keys provided, default to an empty list
self.state_keys = state_keys if state_keys is not None else []
self.image_augmentation_function = image_augmentation_function
if image_augmentation_function is None:
base_function = functools.partial(random_shift, pad=4)
self.image_augmentation_function = torch.compile(base_function)
self.use_drq = use_drq
def _initialize_storage(
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Initialize the storage tensors based on the first transition."""
# Determine shapes from the first transition
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
action_shape = action.squeeze(0).shape
# Pre-allocate tensors for storage
self.states = {
key: torch.empty((self.capacity, *shape), device=self.storage_device)
for key, shape in state_shapes.items()
}
self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
if not self.optimize_memory:
# Standard approach: store states and next_states separately
self.next_states = {
key: torch.empty((self.capacity, *shape), device=self.storage_device)
for key, shape in state_shapes.items()
}
else:
# Memory-optimized approach: don't allocate next_states buffer
# Just create a reference to states for consistent API
self.next_states = self.states # Just a reference for API consistency
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
# Initialize storage for complementary_info
self.has_complementary_info = complementary_info is not None
self.complementary_info_keys = []
self.complementary_info = {}
if self.has_complementary_info:
self.complementary_info_keys = list(complementary_info.keys())
# Pre-allocate tensors for each key in complementary_info
for key, value in complementary_info.items():
if isinstance(value, torch.Tensor):
value_shape = value.squeeze(0).shape
self.complementary_info[key] = torch.empty(
(self.capacity, *value_shape), device=self.storage_device
)
elif isinstance(value, (int, float)):
# Handle scalar values similar to reward
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else:
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
self.initialized = True
def __len__(self):
return self.size
def add(
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
reward: float,
next_state: dict[str, torch.Tensor],
done: bool,
truncated: bool,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
# Initialize storage if this is the first transition
if not self.initialized:
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
# Store the transition in pre-allocated tensors
for key in self.states:
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
if not self.optimize_memory:
# Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward
self.dones[self.position] = done
self.truncateds[self.position] = truncated
# Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info:
# Store the complementary_info
for key in self.complementary_info_keys:
if key in complementary_info:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int, float)):
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
if not self.initialized:
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
batch_size = min(batch_size, self.size)
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
# Random indices for sampling - create on the same device as storage
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
# Create batched state and next_state
batch_state = {}
batch_next_state = {}
# First pass: load all state tensors to target device
for key in self.states:
batch_state[key] = self.states[key][idx].to(self.device)
if not self.optimize_memory:
# Standard approach - load next_states directly
batch_next_state[key] = self.next_states[key][idx].to(self.device)
else:
# Memory-optimized approach - get next_state from the next index
next_idx = (idx + 1) % self.capacity
batch_next_state[key] = self.states[key][next_idx].to(self.device)
# Apply image augmentation in a batched way if needed
if self.use_drq and image_keys:
# Concatenate all images from state and next_state
all_images = []
for key in image_keys:
all_images.append(batch_state[key])
all_images.append(batch_next_state[key])
# Optimization: Batch all images and apply augmentation once
all_images_tensor = torch.cat(all_images, dim=0)
augmented_images = self.image_augmentation_function(all_images_tensor)
# Split the augmented images back to their sources
for i, key in enumerate(image_keys):
# Calculate offsets for the current image key:
# For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
# States start at index i*2*batch_size and take up batch_size slots
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
return BatchTransition(
state=batch_state,
action=batch_actions,
reward=batch_rewards,
next_state=batch_next_state,
done=batch_dones,
truncated=batch_truncateds,
complementary_info=batch_complementary_info,
)
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""
Creates an infinite iterator that yields batches of transitions.
Will automatically restart when internal iterator is exhausted.
Args:
batch_size (int): Size of batches to sample
async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
queue_size (int): Number of batches to prefetch (default: 2)
Yields:
BatchTransition: Batched transitions
"""
while True: # Create an infinite loop
if async_prefetch:
# Get the standard iterator
iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
else:
iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
# Yield all items from the iterator
with suppress(StopIteration):
yield from iterator
def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
"""
Creates an iterator that prefetches batches in a background thread.
Args:
queue_size (int): Number of batches to prefetch (default: 2)
batch_size (int): Size of batches to sample (default: 128)
Yields:
BatchTransition: Prefetched batch transitions
"""
import queue
import threading
# Use thread-safe queue
data_queue = queue.Queue(maxsize=queue_size)
running = [True] # Use list to allow modification in nested function
def prefetch_worker():
while running[0]:
try:
# Sample data and add to queue
data = self.sample(batch_size)
data_queue.put(data, block=True, timeout=0.5)
except queue.Full:
continue
except Exception as e:
print(f"Prefetch error: {e}")
break
# Start prefetching thread
thread = threading.Thread(target=prefetch_worker, daemon=True)
thread.start()
try:
while running[0]:
try:
yield data_queue.get(block=True, timeout=0.5)
except queue.Empty:
if not thread.is_alive():
break
finally:
# Clean up
running[0] = False
thread.join(timeout=1.0)
def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
"""
Creates a simple non-threaded iterator that yields batches.
Args:
batch_size (int): Size of batches to sample
queue_size (int): Number of initial batches to prefetch
Yields:
BatchTransition: Batch transitions
"""
import collections
queue = collections.deque()
def enqueue(n):
for _ in range(n):
data = self.sample(batch_size)
queue.append(data)
enqueue(queue_size)
while queue:
yield queue.popleft()
enqueue(1)
@classmethod
def from_lerobot_dataset(
cls,
lerobot_dataset: LeRobotDataset,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
optimize_memory: bool = False,
) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
Args:
lerobot_dataset (LeRobotDataset): The dataset to convert.
device (str): The device for sampling tensors. Defaults to "cuda:0".
state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
image_augmentation_function (Optional[Callable]): Function for image augmentation.
If None, uses default random shift with pad=4.
use_drq (bool): Whether to use DrQ image augmentation when sampling.
storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
optimize_memory (bool): If True, reduces memory usage by not duplicating state data.
Returns:
ReplayBuffer: The replay buffer with dataset transitions.
"""
if capacity is None:
capacity = len(lerobot_dataset)
if capacity < len(lerobot_dataset):
raise ValueError(
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
)
# Create replay buffer with image augmentation and DrQ settings
replay_buffer = cls(
capacity=capacity,
device=device,
state_keys=state_keys,
image_augmentation_function=image_augmentation_function,
use_drq=use_drq,
storage_device=storage_device,
optimize_memory=optimize_memory,
)
# Convert dataset to transitions
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
# Initialize the buffer with the first transition to set up storage tensors
if list_transition:
first_transition = list_transition[0]
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition["action"].to(device)
# Get complementary info if available
first_complementary_info = None
if (
"complementary_info" in first_transition
and first_transition["complementary_info"] is not None
):
first_complementary_info = {
k: v.to(device) for k, v in first_transition["complementary_info"].items()
}
replay_buffer._initialize_storage(
state=first_state, action=first_action, complementary_info=first_complementary_info
)
# Fill the buffer with all transitions
for data in list_transition:
for k, v in data.items():
if isinstance(v, dict):
for key, tensor in v.items():
v[key] = tensor.to(storage_device)
elif isinstance(v, torch.Tensor):
data[k] = v.to(storage_device)
action = data["action"]
replay_buffer.add(
state=data["state"],
action=action,
reward=data["reward"],
next_state=data["next_state"],
done=data["done"],
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
complementary_info=data.get("complementary_info", None),
)
return replay_buffer
def to_lerobot_dataset(
self,
repo_id: str,
fps=1,
root=None,
task_name="from_replay_buffer",
) -> LeRobotDataset:
"""
Converts all transitions in this ReplayBuffer into a single LeRobotDataset object.
"""
if self.size == 0:
raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
# Create features dictionary for the dataset
features = {
"index": {"dtype": "int64", "shape": [1]}, # global index across episodes
"episode_index": {"dtype": "int64", "shape": [1]}, # which episode
"frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode
"timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy
"task_index": {"dtype": "int64", "shape": [1]},
}
# Add "action"
sample_action = self.actions[0]
act_info = guess_feature_info(t=sample_action, name="action")
features["action"] = act_info
# Add "reward" and "done"
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
features["next.done"] = {"dtype": "bool", "shape": (1,)}
# Add state keys
for key in self.states:
sample_val = self.states[key][0]
f_info = guess_feature_info(t=sample_val, name=key)
features[key] = f_info
# Add complementary_info keys if available
if self.has_complementary_info:
for key in self.complementary_info_keys:
sample_val = self.complementary_info[key][0]
if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
sample_val = sample_val.unsqueeze(0)
f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
features[f"complementary_info.{key}"] = f_info
# Create an empty LeRobotDataset
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=fps,
root=root,
robot=None, # TODO: (azouitine) Handle robot
robot_type=None,
features=features,
use_videos=True,
)
# Start writing images if needed
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
# Convert transitions into episodes and frames
episode_index = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
frame_idx_in_episode = 0
for idx in range(self.size):
actual_idx = (self.position - self.size + idx) % self.capacity
frame_dict = {}
# Fill the data for state keys
for key in self.states:
frame_dict[key] = self.states[key][actual_idx].cpu()
# Fill action, reward, done
frame_dict["action"] = self.actions[actual_idx].cpu()
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
# Add complementary_info if available
if self.has_complementary_info:
for key in self.complementary_info_keys:
val = self.complementary_info[key][actual_idx]
# Convert tensors to CPU
if isinstance(val, torch.Tensor):
if val.ndim == 0:
val = val.unsqueeze(0)
frame_dict[f"complementary_info.{key}"] = val.cpu()
# Non-tensor values can be used directly
else:
frame_dict[f"complementary_info.{key}"] = val
# Add task field which is required by LeRobotDataset
frame_dict["task"] = task_name
# Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict)
# Move to next frame
frame_idx_in_episode += 1
# If we reached an episode boundary, call save_episode, reset counters
if self.dones[actual_idx] or self.truncateds[actual_idx]:
lerobot_dataset.save_episode()
episode_index += 1
frame_idx_in_episode = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index=episode_index
)
# Save any remaining frames in the buffer
if lerobot_dataset.episode_buffer["size"] > 0:
lerobot_dataset.save_episode()
lerobot_dataset.stop_image_writer()
return lerobot_dataset
@staticmethod
def _lerobotdataset_to_transitions(
dataset: LeRobotDataset,
state_keys: Optional[Sequence[str]] = None,
) -> list[Transition]:
"""
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
Args:
dataset (LeRobotDataset):
The dataset to convert. Each item in the dataset is expected to have
at least the following keys:
{
"action": ...
"next.reward": ...
"next.done": ...
"episode_index": ...
}
plus whatever your 'state_keys' specify.
state_keys (Optional[Sequence[str]]):
The dataset keys to include in 'state' and 'next_state'. Their names
will be kept as-is in the output transitions. E.g.
["observation.state", "observation.environment_state"].
If None, you must handle or define default keys.
Returns:
transitions (List[Transition]):
A list of Transition dictionaries with the same length as `dataset`.
"""
if state_keys is None:
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
transitions = []
num_frames = len(dataset)
# Check if the dataset has "next.done" key
sample = dataset[0]
has_done_key = "next.done" in sample
# Check for complementary_info keys
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
has_complementary_info = len(complementary_info_keys) > 0
# If not, we need to infer it from episode boundaries
if not has_done_key:
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
for i in tqdm(range(num_frames)):
current_sample = dataset[i]
# ----- 1) Current state -----
current_state: dict[str, torch.Tensor] = {}
for key in state_keys:
val = current_sample[key]
current_state[key] = val.unsqueeze(0) # Add batch dimension
# ----- 2) Action -----
action = current_sample["action"].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
# Determine done flag - use next.done if available, otherwise infer from episode boundaries
if has_done_key:
done = bool(current_sample["next.done"].item()) # ensure bool
else:
# If this is the last frame or if next frame is in a different episode, mark as done
done = False
if i == num_frames - 1:
done = True
elif i < num_frames - 1:
next_sample = dataset[i + 1]
if next_sample["episode_index"] != current_sample["episode_index"]:
done = True
# TODO: (azouitine) Handle truncation (using the same value as done for now)
truncated = done
# ----- 4) Next state -----
# If not done and the next sample is in the same episode, we pull the next sample's state.
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
next_state = current_state # default
if not done and (i < num_frames - 1):
next_sample = dataset[i + 1]
if next_sample["episode_index"] == current_sample["episode_index"]:
# Build next_state from the same keys
next_state_data: dict[str, torch.Tensor] = {}
for key in state_keys:
val = next_sample[key]
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
next_state = next_state_data
# ----- 5) Complementary info (if available) -----
complementary_info = None
if has_complementary_info:
complementary_info = {}
for key in complementary_info_keys:
# Strip the "complementary_info." prefix to get the actual key
clean_key = key[len("complementary_info.") :]
val = current_sample[key]
# Handle tensor and non-tensor values differently
if isinstance(val, torch.Tensor):
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
else:
# TODO: (azouitine) Check if it's necessary to convert to tensor
# For non-tensor values, use directly
complementary_info[clean_key] = val
# ----- Construct the Transition -----
transition = Transition(
state=current_state,
action=action,
reward=reward,
next_state=next_state,
done=done,
truncated=truncated,
complementary_info=complementary_info,
)
transitions.append(transition)
return transitions
# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t, name: str):
"""
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
Otherwise default to appropriate dtype for numeric.
"""
shape = tuple(t.shape)
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
if len(shape) == 3 and shape[0] in [1, 3]:
return {
"dtype": "image",
"shape": shape,
}
else:
# Otherwise treat as numeric
return {
"dtype": "float32",
"shape": shape,
}
def concatenate_batch_transitions(
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
# Concatenate state fields
left_batch_transitions["state"] = {
key: torch.cat(
[left_batch_transitions["state"][key], right_batch_transition["state"][key]],
dim=0,
)
for key in left_batch_transitions["state"]
}
# Concatenate basic fields
left_batch_transitions["action"] = torch.cat(
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
)
left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
)
# Concatenate next_state fields
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
dim=0,
)
for key in left_batch_transitions["next_state"]
}
# Concatenate done and truncated fields
left_batch_transitions["done"] = torch.cat(
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
)
left_batch_transitions["truncated"] = torch.cat(
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
dim=0,
)
# Handle complementary_info
left_info = left_batch_transitions.get("complementary_info")
right_info = right_batch_transition.get("complementary_info")
# Only process if right_info exists
if right_info is not None:
# Initialize left complementary_info if needed
if left_info is None:
left_batch_transitions["complementary_info"] = right_info
else:
# Concatenate each field
for key in right_info:
if key in left_info:
left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
else:
left_info[key] = right_info[key]
return left_batch_transitions

View File

@@ -1,303 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
from copy import deepcopy
from pathlib import Path
from typing import Dict, Tuple
import cv2
# import torch.nn.functional as F # noqa: N812
import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def select_rect_roi(img):
"""
Allows the user to draw a rectangular ROI on the image.
The user must click and drag to draw the rectangle.
- While dragging, the rectangle is dynamically drawn.
- On mouse button release, the rectangle is fixed.
- Press 'c' to confirm the selection.
- Press 'r' to reset the selection.
- Press ESC to cancel.
Returns:
A tuple (top, left, height, width) representing the rectangular ROI,
or None if no valid ROI is selected.
"""
# Create a working copy of the image
clone = img.copy()
working_img = clone.copy()
roi = None # Will store the final ROI as (top, left, height, width)
drawing = False
index_x, index_y = -1, -1 # Initial click coordinates
def mouse_callback(event, x, y, flags, param):
nonlocal index_x, index_y, drawing, roi, working_img
if event == cv2.EVENT_LBUTTONDOWN:
# Start drawing: record starting coordinates
drawing = True
index_x, index_y = x, y
elif event == cv2.EVENT_MOUSEMOVE:
if drawing:
# Compute the top-left and bottom-right corners regardless of drag direction
top = min(index_y, y)
left = min(index_x, x)
bottom = max(index_y, y)
right = max(index_x, x)
# Show a temporary image with the current rectangle drawn
temp = working_img.copy()
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", temp)
elif event == cv2.EVENT_LBUTTONUP:
# Finish drawing
drawing = False
top = min(index_y, y)
left = min(index_x, x)
bottom = max(index_y, y)
right = max(index_x, x)
height = bottom - top
width = right - left
roi = (top, left, height, width) # (top, left, height, width)
# Draw the final rectangle on the working image and display it
working_img = clone.copy()
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", working_img)
# Create the window and set the callback
cv2.namedWindow("Select ROI")
cv2.setMouseCallback("Select ROI", mouse_callback)
cv2.imshow("Select ROI", working_img)
print("Instructions for ROI selection:")
print(" - Click and drag to draw a rectangular ROI.")
print(" - Press 'c' to confirm the selection.")
print(" - Press 'r' to reset and draw again.")
print(" - Press ESC to cancel the selection.")
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
while True:
key = cv2.waitKey(1) & 0xFF
# Confirm ROI if one has been drawn
if key == ord("c") and roi is not None:
break
# Reset: clear the ROI and restore the original image
elif key == ord("r"):
working_img = clone.copy()
roi = None
cv2.imshow("Select ROI", working_img)
# Cancel selection for this image
elif key == 27: # ESC key
roi = None
break
cv2.destroyWindow("Select ROI")
return roi
def select_square_roi_for_images(images: dict) -> dict:
"""
For each image in the provided dictionary, open a window to allow the user
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
(top, left, height, width) representing the ROI.
Parameters:
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
Returns:
dict: Mapping of image keys to the selected rectangular ROI.
"""
selected_rois = {}
for key, img in images.items():
if img is None:
print(f"Image for key '{key}' is None, skipping.")
continue
print(f"\nSelect rectangular ROI for image with key: '{key}'")
roi = select_rect_roi(img)
if roi is None:
print(f"No valid ROI selected for '{key}'.")
else:
selected_rois[key] = roi
print(f"ROI for '{key}': {roi}")
return selected_rois
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
"""
Find the first row in the dataset and extract the image in order to be used for the crop.
"""
row = dataset[0]
image_dict = {}
for k in row:
if "image" in k:
image_dict[k] = deepcopy(row[k])
return image_dict
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset: LeRobotDataset,
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
new_repo_id: str,
new_dataset_root: str,
resize_size: Tuple[int, int] = (128, 128),
push_to_hub: bool = False,
) -> LeRobotDataset:
"""
Converts an existing LeRobotDataset by iterating over its episodes and frames,
applying cropping and resizing to image observations, and saving a new dataset
with the transformed data.
Args:
original_dataset (LeRobotDataset): The source dataset.
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
A dictionary mapping observation keys to crop parameters (top, left, height, width).
new_repo_id (str): Repository id for the new dataset.
new_dataset_root (str): The root directory where the new dataset will be written.
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
Defaults to (128, 128).
Returns:
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
and resized.
"""
# 1. Create a new (empty) LeRobotDataset for writing.
new_dataset = LeRobotDataset.create(
repo_id=new_repo_id,
fps=original_dataset.fps,
root=new_dataset_root,
robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"],
use_videos=len(original_dataset.meta.video_keys) > 0,
)
# Update the metadata for every image key that will be cropped:
# (Here we simply set the shape to be the final resize_size.)
for key in crop_params_dict:
if key in new_dataset.meta.info["features"]:
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
prev_episode_index = 0
for frame_idx in tqdm(range(len(original_dataset))):
frame = original_dataset[frame_idx]
# Create a copy of the frame to add to the new dataset
new_frame = {}
for key, value in frame.items():
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index"):
continue
if key in ("next.done", "next.reward"):
# if not isinstance(value, str) and len(value.shape) == 0:
value = value.unsqueeze(0)
if key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
# Apply crop then resize.
cropped = F.crop(value, top, left, height, width)
value = F.resize(cropped, resize_size)
value = value.clamp(0, 1)
new_frame[key] = value
new_dataset.add_frame(new_frame)
if frame["episode_index"].item() != prev_episode_index:
# Save the episode
new_dataset.save_episode()
prev_episode_index = frame["episode_index"].item()
if push_to_hub:
new_dataset.push_to_hub()
return new_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser.add_argument(
"--repo-id",
type=str,
default="lerobot",
help="The repository id of the LeRobot dataset to process.",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="The root directory of the LeRobot dataset.",
)
parser.add_argument(
"--crop-params-path",
type=str,
default=None,
help="The path to the JSON file containing the ROIs.",
)
parser.add_argument(
"--push-to-hub",
type=bool,
default=False,
help="Whether to push the new dataset to the hub.",
)
args = parser.parse_args()
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path) as f:
rois = json.load(f)
# Print the selected rectangular ROIs
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
for key, roi in rois.items():
print(f"{key}: {roi}")
new_repo_id = args.repo_id + "_cropped_resized"
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=new_repo_id,
new_dataset_root=new_dataset_root,
resize_size=(128, 128),
push_to_hub=args.push_to_hub,
)
meta_dir = new_dataset_root / "meta"
meta_dir.mkdir(exist_ok=True)
with open(meta_dir / "crop_params.json", "w") as f:
json.dump(rois, f, indent=4)

View File

@@ -1,802 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import sys
import time
import numpy as np
import torch
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.server.kinematics import RobotKinematics
class InputController:
"""Base class for input controllers that generate motion deltas."""
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
"""
Initialize the controller.
Args:
x_step_size: Base movement step size in meters
y_step_size: Base movement step size in meters
z_step_size: Base movement step size in meters
"""
self.x_step_size = x_step_size
self.y_step_size = y_step_size
self.z_step_size = z_step_size
self.running = True
self.episode_end_status = None # None, "success", or "failure"
self.intervention_flag = False
self.open_gripper_command = False
self.close_gripper_command = False
def start(self):
"""Start the controller and initialize resources."""
pass
def stop(self):
"""Stop the controller and release resources."""
pass
def get_deltas(self):
"""Get the current movement deltas (dx, dy, dz) in meters."""
return 0.0, 0.0, 0.0
def should_quit(self):
"""Return True if the user has requested to quit."""
return not self.running
def update(self):
"""Update controller state - call this once per frame."""
pass
def __enter__(self):
"""Support for use in 'with' statements."""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Ensure resources are released when exiting 'with' block."""
self.stop()
def get_episode_end_status(self):
"""
Get the current episode end status.
Returns:
None if episode should continue, "success" or "failure" otherwise
"""
status = self.episode_end_status
self.episode_end_status = None # Reset after reading
return status
def should_intervene(self):
"""Return True if intervention flag was set."""
return self.intervention_flag
def gripper_command(self):
"""Return the current gripper command."""
if self.open_gripper_command == self.close_gripper_command:
return "no-op"
elif self.open_gripper_command:
return "open"
elif self.close_gripper_command:
return "close"
class KeyboardController(InputController):
"""Generate motion deltas from keyboard input."""
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
super().__init__(x_step_size, y_step_size, z_step_size)
self.key_states = {
"forward_x": False,
"backward_x": False,
"forward_y": False,
"backward_y": False,
"forward_z": False,
"backward_z": False,
"quit": False,
"success": False,
"failure": False,
}
self.listener = None
def start(self):
"""Start the keyboard listener."""
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.up:
self.key_states["forward_x"] = True
elif key == keyboard.Key.down:
self.key_states["backward_x"] = True
elif key == keyboard.Key.left:
self.key_states["forward_y"] = True
elif key == keyboard.Key.right:
self.key_states["backward_y"] = True
elif key == keyboard.Key.shift:
self.key_states["backward_z"] = True
elif key == keyboard.Key.shift_r:
self.key_states["forward_z"] = True
elif key == keyboard.Key.esc:
self.key_states["quit"] = True
self.running = False
return False
elif key == keyboard.Key.enter:
self.key_states["success"] = True
self.episode_end_status = "success"
elif key == keyboard.Key.backspace:
self.key_states["failure"] = True
self.episode_end_status = "failure"
except AttributeError:
pass
def on_release(key):
try:
if key == keyboard.Key.up:
self.key_states["forward_x"] = False
elif key == keyboard.Key.down:
self.key_states["backward_x"] = False
elif key == keyboard.Key.left:
self.key_states["forward_y"] = False
elif key == keyboard.Key.right:
self.key_states["backward_y"] = False
elif key == keyboard.Key.shift:
self.key_states["backward_z"] = False
elif key == keyboard.Key.shift_r:
self.key_states["forward_z"] = False
elif key == keyboard.Key.enter:
self.key_states["success"] = False
elif key == keyboard.Key.backspace:
self.key_states["failure"] = False
except AttributeError:
pass
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
self.listener.start()
print("Keyboard controls:")
print(" Arrow keys: Move in X-Y plane")
print(" Shift and Shift_R: Move in Z axis")
print(" Enter: End episode with SUCCESS")
print(" Backspace: End episode with FAILURE")
print(" ESC: Exit")
def stop(self):
"""Stop the keyboard listener."""
if self.listener and self.listener.is_alive():
self.listener.stop()
def get_deltas(self):
"""Get the current movement deltas from keyboard state."""
delta_x = delta_y = delta_z = 0.0
if self.key_states["forward_x"]:
delta_x += self.x_step_size
if self.key_states["backward_x"]:
delta_x -= self.x_step_size
if self.key_states["forward_y"]:
delta_y += self.y_step_size
if self.key_states["backward_y"]:
delta_y -= self.y_step_size
if self.key_states["forward_z"]:
delta_z += self.z_step_size
if self.key_states["backward_z"]:
delta_z -= self.z_step_size
return delta_x, delta_y, delta_z
def should_quit(self):
"""Return True if ESC was pressed."""
return self.key_states["quit"]
def should_save(self):
"""Return True if Enter was pressed (save episode)."""
return self.key_states["success"] or self.key_states["failure"]
class GamepadController(InputController):
"""Generate motion deltas from gamepad input."""
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1):
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.joystick = None
self.intervention_flag = False
def start(self):
"""Initialize pygame and the gamepad."""
import pygame
pygame.init()
pygame.joystick.init()
if pygame.joystick.get_count() == 0:
logging.error("No gamepad detected. Please connect a gamepad and try again.")
self.running = False
return
self.joystick = pygame.joystick.Joystick(0)
self.joystick.init()
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
print("Gamepad controls:")
print(" Left analog stick: Move in X-Y plane")
print(" Right analog stick (vertical): Move in Z axis")
print(" B/Circle button: Exit")
print(" Y/Triangle button: End episode with SUCCESS")
print(" A/Cross button: End episode with FAILURE")
print(" X/Square button: Rerecord episode")
def stop(self):
"""Clean up pygame resources."""
import pygame
if pygame.joystick.get_init():
if self.joystick:
self.joystick.quit()
pygame.joystick.quit()
pygame.quit()
def update(self):
"""Process pygame events to get fresh gamepad readings."""
import pygame
for event in pygame.event.get():
if event.type == pygame.JOYBUTTONDOWN:
if event.button == 3:
self.episode_end_status = "success"
# A button (1) for failure
elif event.button == 1:
self.episode_end_status = "failure"
# X button (0) for rerecord
elif event.button == 0:
self.episode_end_status = "rerecord_episode"
# RB button (6) for closing gripper
elif event.button == 6:
self.close_gripper_command = True
# LT button (7) for opening gripper
elif event.button == 7:
self.open_gripper_command = True
# Reset episode status on button release
elif event.type == pygame.JOYBUTTONUP:
if event.button in [0, 2, 3]:
self.episode_end_status = None
elif event.button == 6:
self.close_gripper_command = False
elif event.button == 7:
self.open_gripper_command = False
# Check for RB button (typically button 5) for intervention flag
if self.joystick.get_button(5):
self.intervention_flag = True
else:
self.intervention_flag = False
def get_deltas(self):
"""Get the current movement deltas from gamepad state."""
import pygame
try:
# Read joystick axes
# Left stick X and Y (typically axes 0 and 1)
x_input = self.joystick.get_axis(0) # Left/Right
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
# Right stick Y (typically axis 3 or 4)
z_input = self.joystick.get_axis(3) # Up/Down for Z
# Apply deadzone to avoid drift
x_input = 0 if abs(x_input) < self.deadzone else x_input
y_input = 0 if abs(y_input) < self.deadzone else y_input
z_input = 0 if abs(z_input) < self.deadzone else z_input
# Calculate deltas (note: may need to invert axes depending on controller)
delta_x = -y_input * self.y_step_size # Forward/backward
delta_y = -x_input * self.x_step_size # Left/right
delta_z = -z_input * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
except pygame.error:
logging.error("Error reading gamepad. Is it still connected?")
return 0.0, 0.0, 0.0
class GamepadControllerHID(InputController):
"""Generate motion deltas from gamepad input using HIDAPI."""
def __init__(
self,
x_step_size=0.01,
y_step_size=0.01,
z_step_size=0.01,
deadzone=0.1,
vendor_id=0x046D,
product_id=0xC219,
):
"""
Initialize the HID gamepad controller.
Args:
step_size: Base movement step size in meters
z_scale: Scaling factor for Z-axis movement
deadzone: Joystick deadzone to prevent drift
vendor_id: USB vendor ID of the gamepad (default: Logitech)
product_id: USB product ID of the gamepad (default: RumblePad 2)
"""
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.vendor_id = vendor_id
self.product_id = product_id
self.device = None
self.device_info = None
# Movement values (normalized from -1.0 to 1.0)
self.left_x = 0.0
self.left_y = 0.0
self.right_x = 0.0
self.right_y = 0.0
# Button states
self.buttons = {}
self.quit_requested = False
self.save_requested = False
def find_device(self):
"""Look for the gamepad device by vendor and product ID."""
import hid
devices = hid.enumerate()
for device in devices:
if device["vendor_id"] == self.vendor_id and device["product_id"] == self.product_id:
logging.info(f"Found gamepad: {device.get('product_string', 'Unknown')}")
return device
logging.error(
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and product ID 0x{self.product_id:04X} found"
)
return None
def start(self):
"""Connect to the gamepad using HIDAPI."""
import hid
self.device_info = self.find_device()
if not self.device_info:
self.running = False
return
try:
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
self.device = hid.device()
self.device.open_path(self.device_info["path"])
self.device.set_nonblocking(1)
manufacturer = self.device.get_manufacturer_string()
product = self.device.get_product_string()
logging.info(f"Connected to {manufacturer} {product}")
logging.info("Gamepad controls (HID mode):")
logging.info(" Left analog stick: Move in X-Y plane")
logging.info(" Right analog stick: Move in Z axis (vertical)")
logging.info(" Button 1/B/Circle: Exit")
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
logging.info(" Button 3/X/Square: End episode with FAILURE")
except OSError as e:
logging.error(f"Error opening gamepad: {e}")
logging.error("You might need to run this with sudo/admin privileges on some systems")
self.running = False
def stop(self):
"""Close the HID device connection."""
if self.device:
self.device.close()
self.device = None
def update(self):
"""
Read and process the latest gamepad data.
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
"""
for _ in range(10):
self._update()
def _update(self):
"""Read and process the latest gamepad data."""
if not self.device or not self.running:
return
try:
# Read data from the gamepad
data = self.device.read(64)
# Interpret gamepad data - this will vary by controller model
# These offsets are for the Logitech RumblePad 2
if data and len(data) >= 8:
# Normalize joystick values from 0-255 to -1.0-1.0
self.left_x = (data[1] - 128) / 128.0
self.left_y = (data[2] - 128) / 128.0
self.right_x = (data[3] - 128) / 128.0
self.right_y = (data[4] - 128) / 128.0
# Apply deadzone
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
# Parse button states (byte 5 in the Logitech RumblePad 2)
buttons = data[5]
# Check if RB is pressed then the intervention flag should be set
self.intervention_flag = data[6] in [2, 6, 10, 14]
# Check if RT is pressed
self.open_gripper_command = data[6] in [8, 10, 12]
# Check if LT is pressed
self.close_gripper_command = data[6] in [4, 6, 12]
# Check if Y/Triangle button (bit 7) is pressed for saving
# Check if X/Square button (bit 5) is pressed for failure
# Check if A/Cross button (bit 4) is pressed for rerecording
if buttons & 1 << 7:
self.episode_end_status = "success"
elif buttons & 1 << 5:
self.episode_end_status = "failure"
elif buttons & 1 << 4:
self.episode_end_status = "rerecord_episode"
else:
self.episode_end_status = None
except OSError as e:
logging.error(f"Error reading from gamepad: {e}")
def get_deltas(self):
"""Get the current movement deltas from gamepad state."""
# Calculate deltas - invert as needed based on controller orientation
delta_x = -self.left_y * self.x_step_size # Forward/backward
delta_y = -self.left_x * self.y_step_size # Left/right
delta_z = -self.right_y * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
def should_quit(self):
"""Return True if quit button was pressed."""
return self.quit_requested
def should_save(self):
"""Return True if save button was pressed."""
return self.save_requested
def test_forward_kinematics(robot, fps=10):
logging.info("Testing Forward Kinematics")
timestep = time.perf_counter()
kinematics = RobotKinematics(robot.robot_type)
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
robot.teleop_step()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = kinematics.fk_gripper_tip(joint_positions)
logging.info(f"EE Position: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def test_inverse_kinematics(robot, fps=10):
logging.info("Testing Inverse Kinematics")
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
desired_ee_pos = ee_pos
target_joint_state = RobotKinematics.ik(joint_positions, desired_ee_pos, position_only=True)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Target Joint State: {target_joint_state}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
logging.info("Testing Inverse Kinematics")
kinematics = RobotKinematics(robot.robot_type)
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = kinematics.fk_gripper_tip(joint_positions)
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
desired_ee_pos = leader_ee
target_joint_state = kinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
logging.info("Testing Delta End-Effector Control")
timestep = time.perf_counter()
# Initial position capture
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
kinematics = RobotKinematics(robot.robot_type)
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
desired_ee_pos = np.diag(np.ones(4))
joint_positions = robot.follower_arms["main"].read("Present_Position")
fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions)
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
# Get leader state for teleoperation
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
# Get current state
# obs = robot.capture_observation()
# joint_positions = obs["observation.state"].cpu().numpy()
joint_positions = robot.follower_arms["main"].read("Present_Position")
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
# Calculate delta between leader and follower end-effectors
# Scaling factor can be adjusted for sensitivity
scaling_factor = 1.0
ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05)
# Apply delta to current position
desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0
desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0
desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3]
# Compute joint targets via inverse kinematics
target_joint_state = kinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
)
initial_leader_ee = leader_ee.copy()
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
# Logging
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, fk_func=None):
"""
Control a robot using delta end-effector movements from any input controller.
Args:
robot: Robot instance to control
controller: InputController instance (keyboard, gamepad, etc.)
fps: Control frequency in Hz
bounds: Optional position limits
fk_func: Forward kinematics function to use
"""
if fk_func is None:
fk_func = RobotKinematics.fk_gripper_tip
logging.info(f"Testing Delta End-Effector Control with {controller.__class__.__name__}")
# Initial position capture
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
kinematics = RobotKinematics(robot.robot_type)
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
# Initialize desired position with current position
desired_ee_pos = np.eye(4) # Identity matrix
timestep = time.perf_counter()
with controller:
while not controller.should_quit() and time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
# Process input events
controller.update()
# Get current robot state
joint_positions = robot.follower_arms["main"].read("Present_Position")
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
# Get movement deltas from the controller
delta_x, delta_y, delta_z = controller.get_deltas()
# Update desired position
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z
# Apply bounds if provided
if bounds is not None:
desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])
# Only send commands if there's actual movement
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
# Compute joint targets via inverse kinematics
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_gym_env(env, controller, fps: int = 30):
"""
Control a robot through a gym environment using keyboard inputs.
Args:
env: A gym environment created with make_robot_env
fps: Target control frequency
"""
logging.info("Testing Keyboard Control of Gym Environment")
print("Keyboard controls:")
print(" Arrow keys: Move in X-Y plane")
print(" Shift and Shift_R: Move in Z axis")
print(" ESC: Exit")
# Reset the environment to get initial observation
obs, info = env.reset()
try:
with controller:
while not controller.should_quit():
loop_start_time = time.perf_counter()
# Process input events
controller.update()
# Get movement deltas from the controller
delta_x, delta_y, delta_z = controller.get_deltas()
# Create the action vector
action = np.array([delta_x, delta_y, delta_z])
# Skip if no movement
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
# Step the environment - pass action as a tensor with intervention flag
action_tensor = torch.from_numpy(action.astype(np.float32))
obs, reward, terminated, truncated, info = env.step((action_tensor, False))
# Log information
logging.info(f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]")
logging.info(f"Reward: {reward}")
# Reset if episode ended
if terminated or truncated:
logging.info("Episode ended, resetting environment")
obs, info = env.reset()
# Maintain target frame rate
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
finally:
# Close the environment
env.close()
if __name__ == "__main__":
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvTransformConfig, HILSerlRobotEnvConfig
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import make_robot_env
init_logging()
parser = argparse.ArgumentParser(description="Test end-effector control")
parser.add_argument(
"--mode",
type=str,
default="keyboard",
choices=[
"keyboard",
"gamepad",
"keyboard_gym",
"gamepad_gym",
"leader_delta",
"leader",
],
help="Control mode to use",
)
parser.add_argument(
"--robot-type",
type=str,
default="so100",
help="Robot type (so100, koch, aloha, etc.)",
)
args = parser.parse_args()
robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
robot = make_robot_from_config(robot_config)
if not robot.is_connected:
robot.connect()
# Example bounds
bounds = {
"max": np.array([0.32170487, 0.201285, 0.10273342]),
"min": np.array([0.16631757, -0.08237468, 0.03364977]),
}
try:
# Determine controller type based on mode prefix
controller = None
if args.mode.startswith("keyboard"):
controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
elif args.mode.startswith("gamepad"):
if sys.platform == "darwin":
controller = GamepadControllerHID(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
else:
controller = GamepadController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
# Handle mode categories
if args.mode in ["keyboard", "gamepad"]:
# Direct robot control modes
teleoperate_delta_inverse_kinematics(robot, controller, bounds=bounds, fps=10)
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
# Gym environment control modes
cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvTransformConfig())
cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
)
cfg.wrapper.ee_action_space_params.use_gamepad = False
cfg.device = "cpu"
env = make_robot_env(cfg, robot)
teleoperate_gym_env(env, controller, fps=cfg.fps)
elif args.mode == "leader_delta":
# Leader-follower modes don't use controllers
teleoperate_delta_inverse_kinematics_with_leader(robot)
elif args.mode == "leader":
teleoperate_inverse_kinematics_with_leader(robot)
finally:
if robot.is_connected:
robot.disconnect()

View File

@@ -1,135 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import time
import cv2
import numpy as np
from lerobot.common.robot_devices.control_utils import is_headless
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics
def find_joint_bounds(
robot,
control_time_s=30,
display_cameras=False,
):
if not robot.is_connected:
robot.connect()
start_episode_t = time.perf_counter()
pos_list = []
while True:
observation, action = robot.teleop_step(record_data=True)
# Wait for 5 seconds to stabilize the robot initial position
if time.perf_counter() - start_episode_t < 5:
continue
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s:
max = np.max(np.stack(pos_list), 0)
min = np.min(np.stack(pos_list), 0)
print(f"Max angle position per joint {max}")
print(f"Min angle position per joint {min}")
break
def find_ee_bounds(
robot,
control_time_s=30,
display_cameras=False,
):
if not robot.is_connected:
robot.connect()
start_episode_t = time.perf_counter()
ee_list = []
while True:
observation, action = robot.teleop_step(record_data=True)
# Wait for 5 seconds to stabilize the robot initial position
if time.perf_counter() - start_episode_t < 5:
continue
kinematics = RobotKinematics(robot.robot_type)
joint_positions = robot.follower_arms["main"].read("Present_Position")
print(f"Joint positions: {joint_positions}")
ee_list.append(kinematics.fk_gripper_tip(joint_positions)[:3, 3])
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s:
max = np.max(np.stack(ee_list), 0)
min = np.min(np.stack(ee_list), 0)
print(f"Max ee position {max}")
print(f"Min ee position {min}")
break
if __name__ == "__main__":
# Create argparse for script-specific arguments
parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict
parser.add_argument(
"--mode",
type=str,
default="joint",
choices=["joint", "ee"],
help="Mode to run the script in. Can be 'joint' or 'ee'.",
)
parser.add_argument(
"--control-time-s",
type=int,
default=30,
help="Time step to use for control.",
)
parser.add_argument(
"--robot-type",
type=str,
default="so100",
help="Robot type (so100, koch, aloha, etc.)",
)
# Only parse known args, leaving robot config args for Hydra if used
args = parser.parse_args()
# Create robot with the appropriate config
robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
robot = make_robot_from_config(robot_config)
if args.mode == "joint":
find_joint_bounds(robot, args.control_time_s)
elif args.mode == "ee":
find_ee_bounds(robot, args.control_time_s)
if robot.is_connected:
robot.disconnect()

File diff suppressed because it is too large Load Diff

View File

@@ -1,54 +0,0 @@
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package hil_serl;
// LearnerService: the Actor calls this to push transitions.
// The Learner implements this service.
service LearnerService {
// Actor -> Learner to store transitions
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
rpc StreamParameters(Empty) returns (stream Parameters);
rpc SendTransitions(stream Transition) returns (Empty);
rpc SendInteractions(stream InteractionMessage) returns (Empty);
rpc Ready(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Transition {
TransferState transfer_state = 1;
bytes data = 2;
}
message Parameters {
TransferState transfer_state = 1;
bytes data = 2;
}
message InteractionMessage {
TransferState transfer_state = 1;
bytes data = 2;
}
message Empty {}

View File

@@ -1,46 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: hilserl.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'hilserl.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=275
_globals['_TRANSFERSTATE']._serialized_end=371
_globals['_TRANSITION']._serialized_start=27
_globals['_TRANSITION']._serialized_end=102
_globals['_PARAMETERS']._serialized_start=104
_globals['_PARAMETERS']._serialized_end=179
_globals['_INTERACTIONMESSAGE']._serialized_start=181
_globals['_INTERACTIONMESSAGE']._serialized_end=264
_globals['_EMPTY']._serialized_start=266
_globals['_EMPTY']._serialized_end=273
_globals['_LEARNERSERVICE']._serialized_start=374
_globals['_LEARNERSERVICE']._serialized_end=696
# @@protoc_insertion_point(module_scope)

View File

@@ -1,276 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import hilserl_pb2 as hilserl__pb2
GRPC_GENERATED_VERSION = '1.70.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in hilserl_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class LearnerServiceStub(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendInteractionMessage = channel.unary_unary(
'/hil_serl.LearnerService/SendInteractionMessage',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.StreamParameters = channel.unary_stream(
'/hil_serl.LearnerService/StreamParameters',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Parameters.FromString,
_registered_method=True)
self.SendTransitions = channel.stream_unary(
'/hil_serl.LearnerService/SendTransitions',
request_serializer=hilserl__pb2.Transition.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.SendInteractions = channel.stream_unary(
'/hil_serl.LearnerService/SendInteractions',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/hil_serl.LearnerService/Ready',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
class LearnerServiceServicer(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
def SendInteractionMessage(self, request, context):
"""Actor -> Learner to store transitions
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StreamParameters(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendTransitions(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendInteractions(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_LearnerServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
servicer.SendInteractionMessage,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'StreamParameters': grpc.unary_stream_rpc_method_handler(
servicer.StreamParameters,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Parameters.SerializeToString,
),
'SendTransitions': grpc.stream_unary_rpc_method_handler(
servicer.SendTransitions,
request_deserializer=hilserl__pb2.Transition.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'SendInteractions': grpc.stream_unary_rpc_method_handler(
servicer.SendInteractions,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'hil_serl.LearnerService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class LearnerService(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
@staticmethod
def SendInteractionMessage(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/SendInteractionMessage',
hilserl__pb2.InteractionMessage.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def StreamParameters(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/hil_serl.LearnerService/StreamParameters',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.Parameters.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendTransitions(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.LearnerService/SendTransitions',
hilserl__pb2.Transition.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendInteractions(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.LearnerService/SendInteractions',
hilserl__pb2.InteractionMessage.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/Ready',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -1,546 +0,0 @@
# ruff: noqa: N806, N815, N803
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy.spatial.transform import Rotation
def skew_symmetric(w):
"""Creates the skew-symmetric matrix from a 3D vector."""
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
def rodrigues_rotation(w, theta):
"""Computes the rotation matrix using Rodrigues' formula."""
w_hat = skew_symmetric(w)
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
def screw_axis_to_transform(S, theta):
"""Converts a screw axis to a 4x4 transformation matrix."""
S_w = S[:3]
S_v = S[3:]
if np.allclose(S_w, 0) and np.linalg.norm(S_v) == 1: # Pure translation
T = np.eye(4)
T[:3, 3] = S_v * theta
elif np.linalg.norm(S_w) == 1: # Rotation and translation
w_hat = skew_symmetric(S_w)
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v
T = np.eye(4)
T[:3, :3] = R
T[:3, 3] = t
else:
raise ValueError("Invalid screw axis parameters")
return T
def pose_difference_se3(pose1, pose2):
"""
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space, combining rotation (SO(3)) and translation.
Each 4x4 matrix has the following structure, a 3x3 rotation matrix in the top-left and a 3x1 translation vector in the top-right:
[R11 R12 R13 tx]
[R21 R22 R23 ty]
[R31 R32 R33 tz]
[ 0 0 0 1]
where Rij is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector.
pose1 - pose2
Args:
pose1: A 4x4 numpy array representing the first pose.
pose2: A 4x4 numpy array representing the second pose.
Returns:
A tuple (translation_diff, rotation_diff) where:
- translation_diff is a 3x1 numpy array representing the translational difference.
- rotation_diff is a 3x1 numpy array representing the rotational difference in axis-angle representation.
"""
# Extract rotation matrices from poses
R1 = pose1[:3, :3]
R2 = pose2[:3, :3]
# Calculate translational difference
translation_diff = pose1[:3, 3] - pose2[:3, 3]
# Calculate rotational difference using scipy's Rotation library
R_diff = Rotation.from_matrix(R1 @ R2.T)
rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation
return np.concatenate([translation_diff, rotation_diff])
def se3_error(target_pose, current_pose):
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
R_target = target_pose[:3, :3]
R_current = current_pose[:3, :3]
R_error = R_target @ R_current.T
rot_error = Rotation.from_matrix(R_error).as_rotvec()
return np.concatenate([pos_error, rot_error])
class RobotKinematics:
"""Robot kinematics class supporting multiple robot models."""
# Robot measurements dictionary
ROBOT_MEASUREMENTS = {
"koch": {
"gripper": [0.239, -0.001, 0.024],
"wrist": [0.209, 0, 0.024],
"forearm": [0.108, 0, 0.02],
"humerus": [0, 0, 0.036],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"so100": {
"gripper": [0.320, 0, 0.050],
"wrist": [0.278, 0, 0.050],
"forearm": [0.143, 0, 0.044],
"humerus": [0.031, 0, 0.072],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"moss": {
"gripper": [0.246, 0.013, 0.111],
"wrist": [0.245, 0.002, 0.064],
"forearm": [0.122, 0, 0.064],
"humerus": [0.001, 0.001, 0.063],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
}
def __init__(self, robot_type="so100"):
"""Initialize kinematics for the specified robot type.
Args:
robot_type: String specifying the robot model ("koch", "so100", or "moss")
"""
if robot_type not in self.ROBOT_MEASUREMENTS:
raise ValueError(
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
)
self.robot_type = robot_type
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
# Initialize all transformation matrices and screw axes
self._setup_transforms()
def _create_translation_matrix(self, x=0, y=0, z=0):
"""Create a 4x4 translation matrix."""
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
def _setup_transforms(self):
"""Setup all transformation matrices and screw axes for the robot."""
# Set up rotation matrices (constant across robot types)
# Gripper orientation
self.gripper_X0 = np.array(
[
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, -1, 0, 0],
[0, 0, 0, 1],
]
)
# Wrist orientation
self.wrist_X0 = np.array(
[
[0, -1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
)
# Base orientation
self.base_X0 = np.array(
[
[0, 0, 1, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
]
)
# Gripper
# Screw axis of gripper frame wrt base frame
self.S_BG = np.array(
[
1,
0,
0,
0,
self.measurements["gripper"][2],
-self.measurements["gripper"][1],
]
)
# Gripper origin to centroid transform
self.X_GoGc = self._create_translation_matrix(x=0.07)
# Gripper origin to tip transform
self.X_GoGt = self._create_translation_matrix(x=0.12)
# 0-position gripper frame pose wrt base
self.X_BoGo = self._create_translation_matrix(
x=self.measurements["gripper"][0],
y=self.measurements["gripper"][1],
z=self.measurements["gripper"][2],
)
# Wrist
# Screw axis of wrist frame wrt base frame
self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]])
# 0-position origin to centroid transform
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
# 0-position wrist frame pose wrt base
self.X_BR = self._create_translation_matrix(
x=self.measurements["wrist"][0],
y=self.measurements["wrist"][1],
z=self.measurements["wrist"][2],
)
# Forearm
# Screw axis of forearm frame wrt base frame
self.S_BF = np.array(
[
0,
1,
0,
-self.measurements["forearm"][2],
0,
self.measurements["forearm"][0],
]
)
# Forearm origin + centroid transform
self.X_FoFc = self._create_translation_matrix(x=0.036) # spellchecker:disable-line
# 0-position forearm frame pose wrt base
self.X_BF = self._create_translation_matrix(
x=self.measurements["forearm"][0],
y=self.measurements["forearm"][1],
z=self.measurements["forearm"][2],
)
# Humerus
# Screw axis of humerus frame wrt base frame
self.S_BH = np.array(
[
0,
-1,
0,
self.measurements["humerus"][2],
0,
-self.measurements["humerus"][0],
]
)
# Humerus origin to centroid transform
self.X_HoHc = self._create_translation_matrix(x=0.0475)
# 0-position humerus frame pose wrt base
self.X_BH = self._create_translation_matrix(
x=self.measurements["humerus"][0],
y=self.measurements["humerus"][1],
z=self.measurements["humerus"][2],
)
# Shoulder
# Screw axis of shoulder frame wrt Base frame
self.S_BS = np.array([0, 0, -1, 0, 0, 0])
# Shoulder origin to centroid transform
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
# 0-position shoulder frame pose wrt base
self.X_BS = self._create_translation_matrix(
x=self.measurements["shoulder"][0],
y=self.measurements["shoulder"][1],
z=self.measurements["shoulder"][2],
)
# Base
# Base origin to centroid transform
self.X_BoBc = self._create_translation_matrix(y=0.015)
# World to base transform
self.X_WoBo = self._create_translation_matrix(
x=self.measurements["base"][0],
y=self.measurements["base"][1],
z=self.measurements["base"][2],
)
# Pre-compute gripper post-multiplication matrix
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
def fk_base(self):
"""Forward kinematics for the base frame."""
return self.X_WoBo @ self.X_BoBc @ self.base_X0
def fk_shoulder(self, robot_pos_deg):
"""Forward kinematics for the shoulder frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return self.X_WoBo @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) @ self.X_SoSc @ self.X_BS
def fk_humerus(self, robot_pos_deg):
"""Forward kinematics for the humerus frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ self.X_HoHc
@ self.X_BH
)
def fk_forearm(self, robot_pos_deg):
"""Forward kinematics for the forearm frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ self.X_FoFc # spellchecker:disable-line
@ self.X_BF
)
def fk_wrist(self, robot_pos_deg):
"""Forward kinematics for the wrist frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ self.X_RoRc
@ self.X_BR
@ self.wrist_X0
)
def fk_gripper(self, robot_pos_deg):
"""Forward kinematics for the gripper frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
@ self._fk_gripper_post
)
def fk_gripper_tip(self, robot_pos_deg):
"""Forward kinematics for the gripper tip frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
@ self.X_GoGt
@ self.X_BoGo
@ self.gripper_X0
)
def compute_jacobian(self, robot_pos_deg, fk_func=None):
"""Finite differences to compute the Jacobian.
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
in the jth joint's velocity.
Args:
robot_pos_deg: Current joint positions in degrees
fk_func: Forward kinematics function to use (defaults to fk_gripper)
"""
if fk_func is None:
fk_func = self.fk_gripper
eps = 1e-8
jac = np.zeros(shape=(6, 5))
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
for el_ix in range(len(robot_pos_deg[:-1])):
delta *= 0
delta[el_ix] = eps / 2
Sdot = (
pose_difference_se3(
fk_func(robot_pos_deg[:-1] + delta),
fk_func(robot_pos_deg[:-1] - delta),
)
/ eps
)
jac[:, el_ix] = Sdot
return jac
def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
"""Finite differences to compute the positional Jacobian.
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
in the jth joint's velocity.
Args:
robot_pos_deg: Current joint positions in degrees
fk_func: Forward kinematics function to use (defaults to fk_gripper)
"""
if fk_func is None:
fk_func = self.fk_gripper
eps = 1e-8
jac = np.zeros(shape=(3, 5))
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
for el_ix in range(len(robot_pos_deg[:-1])):
delta *= 0
delta[el_ix] = eps / 2
Sdot = (
fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
) / eps
jac[:, el_ix] = Sdot
return jac
def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None):
"""Inverse kinematics using gradient descent.
Args:
current_joint_state: Initial joint positions in degrees
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
position_only: If True, only match end-effector position, not orientation
fk_func: Forward kinematics function to use (defaults to fk_gripper)
Returns:
Joint positions in degrees that achieve the desired end-effector pose
"""
if fk_func is None:
fk_func = self.fk_gripper
# Do gradient descent.
max_iterations = 5
learning_rate = 1
for _ in range(max_iterations):
current_ee_pose = fk_func(current_joint_state)
if not position_only:
error = se3_error(desired_ee_pose, current_ee_pose)
jac = self.compute_jacobian(current_joint_state, fk_func)
else:
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
jac = self.compute_positional_jacobian(current_joint_state, fk_func)
delta_angles = np.linalg.pinv(jac) @ error
current_joint_state[:-1] += learning_rate * delta_angles
if np.linalg.norm(error) < 5e-3:
return current_joint_state
return current_joint_state
if __name__ == "__main__":
import time
def run_test(robot_type):
"""Run test suite for a specific robot type."""
print(f"\n--- Testing {robot_type.upper()} Robot ---")
# Initialize kinematics for this robot
robot = RobotKinematics(robot_type)
# Test 1: Forward kinematics consistency
print("Test 1: Forward kinematics consistency")
test_angles = np.array([30, 45, -30, 20, 10, 0]) # Example joint angles in degrees
# Calculate FK for different joints
shoulder_pose = robot.fk_shoulder(test_angles)
humerus_pose = robot.fk_humerus(test_angles)
forearm_pose = robot.fk_forearm(test_angles)
wrist_pose = robot.fk_wrist(test_angles)
gripper_pose = robot.fk_gripper(test_angles)
gripper_tip_pose = robot.fk_gripper_tip(test_angles)
# Check that poses form a consistent kinematic chain (positions should be progressively further from origin)
distances = [
np.linalg.norm(shoulder_pose[:3, 3]),
np.linalg.norm(humerus_pose[:3, 3]),
np.linalg.norm(forearm_pose[:3, 3]),
np.linalg.norm(wrist_pose[:3, 3]),
np.linalg.norm(gripper_pose[:3, 3]),
np.linalg.norm(gripper_tip_pose[:3, 3]),
]
# Check if distances generally increase along the chain
is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1))
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
print(f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}")
# Test 2: Jacobian computation
print("Test 2: Jacobian computation")
jacobian = robot.compute_jacobian(test_angles)
positional_jacobian = robot.compute_positional_jacobian(test_angles)
# Check shapes
jacobian_shape_ok = jacobian.shape == (6, 5)
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
print(f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}")
# Test 3: Inverse kinematics
print("Test 3: Inverse kinematics (position only)")
# Generate target pose from known joint angles
original_angles = np.array([10, 20, 30, -10, 5, 0])
target_pose = robot.fk_gripper(original_angles)
# Start IK from a different position
initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
# Measure IK performance
start_time = time.time()
computed_angles = robot.ik(initial_guess.copy(), target_pose)
ik_time = time.time() - start_time
# Compute resulting pose from IK solution
result_pose = robot.fk_gripper(computed_angles)
# Calculate position error
pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
passed = pos_error < 0.01 # Accept errors less than 1cm
print(f" IK computation time: {ik_time:.4f} seconds")
print(f" Position error: {pos_error:.4f}")
print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}")
return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed
# Run tests for all robot types
results = {}
for robot_type in ["koch", "so100", "moss"]:
results[robot_type] = run_test(robot_type)
# Print overall summary
print("\n=== Test Summary ===")
all_passed = all(results.values())
for robot_type, passed in results.items():
print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")

File diff suppressed because it is too large Load Diff

View File

@@ -1,80 +0,0 @@
import logging
from multiprocessing import Event, Queue
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def __init__(
self,
shutdown_event: Event, # type: ignore
parameters_queue: Queue,
seconds_between_pushes: float,
transition_queue: Queue,
interaction_message_queue: Queue,
):
self.shutdown_event = shutdown_event
self.parameters_queue = parameters_queue
self.seconds_between_pushes = seconds_between_pushes
self.transition_queue = transition_queue
self.interaction_message_queue = interaction_message_queue
def StreamParameters(self, request, context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
while not self.shutdown_event.is_set():
logging.info("[LEARNER] Push parameters to the Actor")
buffer = self.parameters_queue.get()
yield from send_bytes_in_chunks(
buffer,
hilserl_pb2.Parameters,
log_prefix="[LEARNER] Sending parameters",
silent=True,
)
logging.info("[LEARNER] Parameters sent")
self.shutdown_event.wait(self.seconds_between_pushes)
logging.info("[LEARNER] Stream parameters finished")
return hilserl_pb2.Empty()
def SendTransitions(self, request_iterator, _context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
receive_bytes_in_chunks(
request_iterator,
self.transition_queue,
self.shutdown_event,
log_prefix="[LEARNER] transitions",
)
logging.debug("[LEARNER] Finished receiving transitions")
return hilserl_pb2.Empty()
def SendInteractions(self, request_iterator, _context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive interactions from the Actor")
receive_bytes_in_chunks(
request_iterator,
self.interaction_message_queue,
self.shutdown_event,
log_prefix="[LEARNER] interactions",
)
logging.debug("[LEARNER] Finished receiving interactions")
return hilserl_pb2.Empty()
def Ready(self, request, context): # noqa: N802
return hilserl_pb2.Empty()

View File

@@ -1,142 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import pickle # nosec B403: Safe usage for internal serialization only
from multiprocessing import Event, Queue
from typing import Any
import torch
from lerobot.scripts.server import hilserl_pb2
from lerobot.scripts.server.utils import Transition
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
def bytes_buffer_size(buffer: io.BytesIO) -> int:
buffer.seek(0, io.SEEK_END)
result = buffer.tell()
buffer.seek(0)
return result
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging_method = logging.info if not silent else logging.debug
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
bytes_buffer = io.BytesIO()
step = 0
logging.info(f"{log_prefix} Starting receiver")
for item in iterator:
logging.debug(f"{log_prefix} Received item")
if shutdown_event.is_set():
logging.info(f"{log_prefix} Shutting down receiver")
return
if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step 0")
step = 0
continue
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data)
step += 1
logging.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
queue.put(bytes_buffer.getvalue())
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
step = 0
logging.debug(f"{log_prefix} Queue updated")
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer.getvalue()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer, weights_only=False) # nosec B614: Using weights_only=False relies on pickle which has security implications.
# This is currently safe as we only deserialize trusted internal data.
# TODO: Verify if weights_only=True would work for our use case (safer default in torch 2.6+)
def python_object_to_bytes(python_object: Any) -> bytes:
return pickle.dumps(python_object)
def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
# Add validation checks here
return obj
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
transitions = torch.load(buffer, weights_only=False) # nosec B614: Safe usage of torch.load
# Add validation checks here
return transitions
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()

View File

@@ -1,141 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import signal
import sys
from queue import Empty
from typing import TypedDict
import torch
from torch.multiprocessing import Queue
shutdown_event_counter = 0
def setup_process_handlers(use_threads: bool) -> any:
if use_threads:
from threading import Event
else:
from multiprocessing import Event
shutdown_event = Event()
# Define signal handler
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
global shutdown_event_counter
shutdown_event_counter += 1
if shutdown_event_counter > 1:
logging.info("Force shutdown")
sys.exit(1)
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
return shutdown_event
def get_last_item_from_queue(queue: Queue):
item = queue.get()
counter = 1
# Drain queue and keep only the most recent parameters
try:
while True:
item = queue.get_nowait()
counter += 1
except Empty:
pass
logging.debug(f"Drained {counter} items from queue")
return item
class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
device = torch.device(device)
non_blocking = device.type == "cuda"
# Move state tensors to device
transition["state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
}
# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
# Move next_state tensors to device
transition["next_state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
}
# Move complementary_info tensors if present
if transition.get("complementary_info") is not None:
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
return transition
def move_state_dict_to_device(state_dict, device="cpu"):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the CPU.
"""
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
else:
return state_dict

View File

@@ -63,13 +63,13 @@ dependencies = [
"opencv-python-headless>=4.9.0",
"packaging>=24.2",
"av>=14.2.0",
"pymunk>=6.6.0",
"pymunk>=6.6.0,<7.0.0",
"pynput>=1.7.7",
"pyzmq>=26.2.1",
"rerun-sdk>=0.21.0",
"termcolor>=2.4.0",
"torch>=2.2.1,<2.7",
"torchcodec==0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torch>=2.2.1",
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchvision>=0.21.0",
"wandb>=0.16.3",
"zarr>=2.17.0",
@@ -84,7 +84,6 @@ dora = [
]
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
hilserl = ["transformers>=4.48", "gym-hil>=0.1.3", "protobuf>=5.29.3", "grpcio>=1.70.0"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
pi0 = ["transformers>=4.48.0"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
@@ -99,40 +98,13 @@ umi = ["imagecodecs>=2024.1.1"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
[tool.poetry]
requires-poetry = ">=2.1"
[tool.ruff]
line-length = 110
target-version = "py310"
exclude = [
"tests/data",
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
"*_pb2.py",
"*_pb2_grpc.py",
]
exclude = ["tests/artifacts/**/*.safetensors"]
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0389a716d51c1c615fb2a3bfa386d89f00b0deca08c4fa21b23e020a939d0213
oid sha256:6b1e600768a8771c5fe650e038a1193597e3810f032041b2a0d021e4496381c1
size 3686488

View File

@@ -28,7 +28,7 @@ from lerobot.common.datasets.transforms import (
from lerobot.common.utils.random_utils import seeded_context
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
DATASET_REPO_ID = "lerobot/aloha_static_cups_open"
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0dc691503e7d90b2086bb408e89a65f772ce5ee6e3562ef8c127bcb09bd90851
oid sha256:9d4ebab73eabddc58879a4e770289d19e00a1a4cf2fa5fa33cd3a3246992bc90
size 40551392

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
size 5104

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
oid sha256:1a7a8b1a457149109f843c32bcbb047d09de2201847b9b79f7501b447f77ecf4
size 31672

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
oid sha256:5e6ce85296b2009e7c2060d336c0429b1c7197d9adb159e7df0ba18003067b36
size 68

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
size 33400

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
size 515400

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6
size 31672

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd
size 68

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
size 33400

View File

@@ -14,11 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import traceback
import pytest
import torch
from serial import SerialException
from lerobot import available_cameras, available_motors, available_robots
@@ -88,19 +86,3 @@ def patch_builtins_input(monkeypatch):
print(text)
monkeypatch.setattr("builtins.input", print_text)
def pytest_addoption(parser):
parser.addoption(
"--seed",
action="store",
default="42",
help="Set random seed for reproducibility",
)
@pytest.fixture(autouse=True)
def set_random_seed(request):
seed = int(request.config.getoption("--seed"))
random.seed(seed) # Python random
torch.manual_seed(seed) # PyTorch

View File

@@ -16,6 +16,7 @@
import pytest
import torch
from packaging import version
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
@@ -253,7 +254,14 @@ def test_backward_compatibility_single_transforms(
@require_x86_64_kernel
@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse("2.7.0"),
reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior",
)
def test_backward_compatibility_default_config(img_tensor, default_transforms):
# NOTE: PyTorch versions have different randomness, it might break this test.
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
cfg = ImageTransformsConfig(enable=True)
default_tf = ImageTransforms(cfg)

View File

@@ -21,7 +21,6 @@ from lerobot.common.constants import (
from lerobot.common.optim.optimizers import (
AdamConfig,
AdamWConfig,
MultiAdamConfig,
SGDConfig,
load_optimizer_state,
save_optimizer_state,
@@ -34,21 +33,13 @@ from lerobot.common.optim.optimizers import (
(AdamConfig, torch.optim.Adam),
(AdamWConfig, torch.optim.AdamW),
(SGDConfig, torch.optim.SGD),
(MultiAdamConfig, dict),
],
)
def test_optimizer_build(config_cls, expected_class, model_params):
config = config_cls()
if config_cls == MultiAdamConfig:
params_dict = {"default": model_params}
optimizer = config.build(params_dict)
assert isinstance(optimizer, expected_class)
assert isinstance(optimizer["default"], torch.optim.Adam)
assert optimizer["default"].defaults["lr"] == config.lr
else:
optimizer = config.build(model_params)
assert isinstance(optimizer, expected_class)
assert optimizer.defaults["lr"] == config.lr
optimizer = config.build(model_params)
assert isinstance(optimizer, expected_class)
assert optimizer.defaults["lr"] == config.lr
def test_save_optimizer_state(optimizer, tmp_path):
@@ -63,180 +54,3 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
@pytest.fixture
def base_params_dict():
return {
"actor": [torch.nn.Parameter(torch.randn(10, 10))],
"critic": [torch.nn.Parameter(torch.randn(5, 5))],
"temperature": [torch.nn.Parameter(torch.randn(3, 3))],
}
@pytest.mark.parametrize(
"config_params, expected_values",
[
# Test 1: Basic configuration with different learning rates
(
{
"lr": 1e-3,
"weight_decay": 1e-4,
"optimizer_groups": {
"actor": {"lr": 1e-4},
"critic": {"lr": 5e-4},
"temperature": {"lr": 2e-3},
},
},
{
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
},
),
# Test 2: Different weight decays and beta values
(
{
"lr": 1e-3,
"weight_decay": 1e-4,
"optimizer_groups": {
"actor": {"lr": 1e-4, "weight_decay": 1e-5},
"critic": {"lr": 5e-4, "weight_decay": 1e-6},
"temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
},
},
{
"actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
"critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
},
),
# Test 3: Epsilon parameter customization
(
{
"lr": 1e-3,
"weight_decay": 1e-4,
"optimizer_groups": {
"actor": {"lr": 1e-4, "eps": 1e-6},
"critic": {"lr": 5e-4, "eps": 1e-7},
"temperature": {"lr": 2e-3, "eps": 1e-8},
},
},
{
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
},
),
],
)
def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
# Create config with the given parameters
config = MultiAdamConfig(**config_params)
optimizers = config.build(base_params_dict)
# Verify optimizer count and keys
assert len(optimizers) == len(expected_values)
assert set(optimizers.keys()) == set(expected_values.keys())
# Check that all optimizers are Adam instances
for opt in optimizers.values():
assert isinstance(opt, torch.optim.Adam)
# Verify hyperparameters for each optimizer
for name, expected in expected_values.items():
optimizer = optimizers[name]
for param, value in expected.items():
assert optimizer.defaults[param] == value
@pytest.fixture
def multi_optimizers(base_params_dict):
config = MultiAdamConfig(
lr=1e-3,
optimizer_groups={
"actor": {"lr": 1e-4},
"critic": {"lr": 5e-4},
"temperature": {"lr": 2e-3},
},
)
return config.build(base_params_dict)
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
# Save optimizer states
save_optimizer_state(multi_optimizers, tmp_path)
# Verify that directories were created for each optimizer
for name in multi_optimizers:
assert (tmp_path / name).is_dir()
assert (tmp_path / name / OPTIMIZER_STATE).is_file()
assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
# Option 1: Add a minimal backward pass to populate optimizer states
for name, params in base_params_dict.items():
if name in multi_optimizers:
# Create a dummy loss and do backward
dummy_loss = params[0].sum()
dummy_loss.backward()
# Perform an optimization step
multi_optimizers[name].step()
# Zero gradients for next steps
multi_optimizers[name].zero_grad()
# Save optimizer states
save_optimizer_state(multi_optimizers, tmp_path)
# Create new optimizers with the same config
config = MultiAdamConfig(
lr=1e-3,
optimizer_groups={
"actor": {"lr": 1e-4},
"critic": {"lr": 5e-4},
"temperature": {"lr": 2e-3},
},
)
new_optimizers = config.build(base_params_dict)
# Load optimizer states
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
# Verify state dictionaries match
for name in multi_optimizers:
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
"""Test saving and loading optimizer states even when the state is empty (no backward pass)."""
# Create config and build optimizers
config = MultiAdamConfig(
lr=1e-3,
optimizer_groups={
"actor": {"lr": 1e-4},
"critic": {"lr": 5e-4},
"temperature": {"lr": 2e-3},
},
)
optimizers = config.build(base_params_dict)
# Save optimizer states without any backward pass (empty state)
save_optimizer_state(optimizers, tmp_path)
# Create new optimizers with the same config
new_optimizers = config.build(base_params_dict)
# Load optimizer states
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
# Verify hyperparameters match even with empty state
for name, optimizer in optimizers.items():
assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
# Verify state dictionaries match (they will be empty)
torch.testing.assert_close(
optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
)

View File

@@ -37,7 +37,6 @@ def test_diffuser_scheduler(optimizer):
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict
@@ -56,7 +55,6 @@ def test_vqbet_scheduler(optimizer):
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict
@@ -77,7 +75,6 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict

View File

@@ -1,123 +0,0 @@
import torch
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.common.policies.reward_model.modeling_classifier import ClassifierOutput
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from tests.utils import require_package
def test_classifier_output():
output = ClassifierOutput(
logits=torch.tensor([1, 2, 3]),
probabilities=torch.tensor([0.1, 0.2, 0.3]),
hidden_states=None,
)
assert (
f"{output}"
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
)
@require_package("transformers")
def test_binary_classifier_with_default_params():
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
config.num_cameras = 1
classifier = Classifier(config)
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
}
images, labels = classifier.extract_images_and_labels(input)
assert len(images) == 1
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
assert labels.shape == torch.Size([batch_size])
output = classifier.predict(images)
assert output is not None
assert output.logits.size() == torch.Size([batch_size])
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 256])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@require_package("transformers")
def test_multiclass_classifier():
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
num_classes = 5
config = RewardClassifierConfig()
config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
}
config.num_cameras = 1
config.num_classes = num_classes
classifier = Classifier(config)
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.rand((batch_size, num_classes)),
}
images, labels = classifier.extract_images_and_labels(input)
assert len(images) == 1
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
assert labels.shape == torch.Size([batch_size, num_classes])
output = classifier.predict(images)
assert output is not None
assert output.logits.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 256])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@require_package("transformers")
def test_default_device():
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
assert config.device == "cpu"
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")
@require_package("transformers")
def test_explicit_device_setup():
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig(device="cpu")
assert config.device == "cpu"
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")

View File

@@ -20,6 +20,7 @@ from pathlib import Path
import einops
import pytest
import torch
from packaging import version
from safetensors.torch import load_file
from lerobot import available_policies
@@ -408,7 +409,16 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
is out of date. For example, some PyTorch versions have different randomness, see this PR:
https://github.com/huggingface/lerobot/pull/1127.
"""
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")

View File

@@ -1,217 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from lerobot.common.policies.sac.configuration_sac import (
ActorLearnerConfig,
ActorNetworkConfig,
ConcurrencyConfig,
CriticNetworkConfig,
PolicyConfig,
SACConfig,
)
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
def test_sac_config_default_initialization():
config = SACConfig()
assert config.normalization_mapping == {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ENV": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
assert config.dataset_stats == {
"observation.image": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
}
# Basic parameters
assert config.device == "cpu"
assert config.storage_device == "cpu"
assert config.discount == 0.99
assert config.temperature_init == 1.0
assert config.num_critics == 2
# Architecture specifics
assert config.vision_encoder_name is None
assert config.freeze_vision_encoder is True
assert config.image_encoder_hidden_dim == 32
assert config.shared_encoder is True
assert config.num_discrete_actions is None
assert config.image_embedding_pooling_dim == 8
# Training parameters
assert config.online_steps == 1000000
assert config.online_env_seed == 10000
assert config.online_buffer_capacity == 100000
assert config.offline_buffer_capacity == 100000
assert config.async_prefetch is False
assert config.online_step_before_learning == 100
assert config.policy_update_freq == 1
# SAC algorithm parameters
assert config.num_subsample_critics is None
assert config.critic_lr == 3e-4
assert config.actor_lr == 3e-4
assert config.temperature_lr == 3e-4
assert config.critic_target_update_weight == 0.005
assert config.utd_ratio == 1
assert config.state_encoder_hidden_dim == 256
assert config.latent_dim == 256
assert config.target_entropy is None
assert config.use_backup_entropy is True
assert config.grad_clip_norm == 40.0
# Dataset stats defaults
expected_dataset_stats = {
"observation.image": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
}
assert config.dataset_stats == expected_dataset_stats
# Critic network configuration
assert config.critic_network_kwargs.hidden_dims == [256, 256]
assert config.critic_network_kwargs.activate_final is True
assert config.critic_network_kwargs.final_activation is None
# Actor network configuration
assert config.actor_network_kwargs.hidden_dims == [256, 256]
assert config.actor_network_kwargs.activate_final is True
# Policy configuration
assert config.policy_kwargs.use_tanh_squash is True
assert config.policy_kwargs.log_std_min == 1e-5
assert config.policy_kwargs.log_std_max == 10.0
assert config.policy_kwargs.init_final == 0.05
# Discrete critic network configuration
assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256]
assert config.discrete_critic_network_kwargs.activate_final is True
assert config.discrete_critic_network_kwargs.final_activation is None
# Actor learner configuration
assert config.actor_learner_config.learner_host == "127.0.0.1"
assert config.actor_learner_config.learner_port == 50051
assert config.actor_learner_config.policy_parameters_push_frequency == 4
# Concurrency configuration
assert config.concurrency.actor == "threads"
assert config.concurrency.learner == "threads"
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
assert isinstance(config.policy_kwargs, PolicyConfig)
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
assert isinstance(config.concurrency, ConcurrencyConfig)
def test_critic_network_kwargs():
config = CriticNetworkConfig()
assert config.hidden_dims == [256, 256]
assert config.activate_final is True
assert config.final_activation is None
def test_actor_network_kwargs():
config = ActorNetworkConfig()
assert config.hidden_dims == [256, 256]
assert config.activate_final is True
def test_policy_kwargs():
config = PolicyConfig()
assert config.use_tanh_squash is True
assert config.log_std_min == 1e-5
assert config.log_std_max == 10.0
assert config.init_final == 0.05
def test_actor_learner_config():
config = ActorLearnerConfig()
assert config.learner_host == "127.0.0.1"
assert config.learner_port == 50051
assert config.policy_parameters_push_frequency == 4
def test_concurrency_config():
config = ConcurrencyConfig()
assert config.actor == "threads"
assert config.learner == "threads"
def test_sac_config_custom_initialization():
config = SACConfig(
device="cpu",
discount=0.95,
temperature_init=0.5,
num_critics=3,
)
assert config.device == "cpu"
assert config.discount == 0.95
assert config.temperature_init == 0.5
assert config.num_critics == 3
def test_validate_features():
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
config.validate_features()
def test_validate_features_missing_observation():
config = SACConfig(
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(
ValueError, match="You must provide either 'observation.state' or an image observation"
):
config.validate_features()
def test_validate_features_missing_action():
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
config.validate_features()

View File

@@ -1,519 +0,0 @@
import math
import pytest
import torch
from torch import Tensor, nn
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.types import FeatureType, PolicyFeature
try:
import transformers # noqa: F401
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
def test_mlp_with_default_args():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(10)
y = mlp(x)
assert y.shape == (256,)
def test_mlp_with_batch_dim():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(2, 10)
y = mlp(x)
assert y.shape == (2, 256)
def test_forward_with_empty_hidden_dims():
mlp = MLP(input_dim=10, hidden_dims=[])
x = torch.randn(1, 10)
assert mlp(x).shape == (1, 10)
def test_mlp_with_dropout():
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 11)
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
assert drop_out_layers_count == 2
def test_mlp_with_custom_final_activation():
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 256)
assert (y >= -1).all() and (y <= 1).all()
def test_sac_policy_with_default_args():
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
SACPolicy()
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
return {
"observation.state": torch.randn(batch_size, state_dim),
}
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
return {
"observation.image": torch.randn(batch_size, 3, 84, 84),
"observation.state": torch.randn(batch_size, state_dim),
}
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
return torch.randn(batch_size, action_dim)
def create_default_train_batch(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
"action": create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_state(batch_size, state_dim),
"next_state": create_dummy_state(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_train_batch_with_visual_input(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
"action": create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_with_visual_input(batch_size, state_dim),
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
}
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
"observation.image": torch.randn(batch_size, 3, 84, 84),
}
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
"""Create optimizers for the SAC policy."""
optimizer_actor = torch.optim.Adam(
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=policy.config.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(),
lr=policy.config.critic_lr,
)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha],
lr=policy.config.critic_lr,
)
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if has_discrete_action:
optimizers["discrete_critic"] = torch.optim.Adam(
params=policy.discrete_critic.parameters(),
lr=policy.config.critic_lr,
)
return optimizers
def create_default_config(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
action_dim = continuous_action_dim
if has_discrete_action:
action_dim += 1
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
"observation.state": {
"min": [0.0] * state_dim,
"max": [1.0] * state_dim,
},
"action": {
"min": [0.0] * continuous_action_dim,
"max": [1.0] * continuous_action_dim,
},
},
)
config.validate_features()
return config
def create_config_with_visual_input(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
config = create_default_config(
state_dim=state_dim,
continuous_action_dim=continuous_action_dim,
has_discrete_action=has_discrete_action,
)
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats["observation.image"] = {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
}
# Let make tests a little bit faster
config.state_encoder_hidden_dim = 32
config.latent_dim = 32
config.validate_features()
return config
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy.train()
optimizers = make_optimizers(policy)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
assert temperature_loss.item() is not None
assert temperature_loss.shape == ()
temperature_loss.backward()
optimizers["temperature"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
assert temperature_loss.item() is not None
assert temperature_loss.shape == ()
temperature_loss.backward()
optimizers["temperature"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
# Let's check best candidates for pretrained encoders
@pytest.mark.parametrize(
"batch_size,state_dim,action_dim,vision_encoder_name",
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
def test_sac_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.vision_encoder_name = vision_encoder_name
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
optimizers = make_optimizers(policy)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
def test_sac_policy_with_shared_encoder():
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.shared_encoder = True
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
def test_sac_policy_with_discrete_critic():
batch_size = 2
continuous_action_dim = 9
full_action_dim = continuous_action_dim + 1 # the last action is discrete
state_dim = 10
config = create_config_with_visual_input(
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
)
num_discrete_actions = 5
config.num_discrete_actions = num_discrete_actions
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
)
policy.train()
optimizers = make_optimizers(policy, has_discrete_action=True)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
assert discrete_critic_loss.item() is not None
assert discrete_critic_loss.shape == ()
discrete_critic_loss.backward()
optimizers["discrete_critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, full_action_dim)
discrete_actions = selected_action[:, -1].long()
discrete_action_values = set(discrete_actions.tolist())
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
)
def test_sac_policy_with_default_entropy():
config = create_default_config(continuous_action_dim=10, state_dim=10)
policy = SACPolicy(config=config)
assert policy.target_entropy == -5.0
def test_sac_policy_default_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
policy = SACPolicy(config=config)
assert policy.target_entropy == -3.0
def test_sac_policy_with_predefined_entropy():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.target_entropy = -3.5
policy = SACPolicy(config=config)
assert policy.target_entropy == pytest.approx(-3.5)
def test_sac_policy_update_temperature():
config = create_default_config(continuous_action_dim=10, state_dim=10)
policy = SACPolicy(config=config)
assert policy.temperature == pytest.approx(1.0)
policy.log_alpha.data = torch.tensor([math.log(0.1)])
policy.update_temperature()
assert policy.temperature == pytest.approx(0.1)
def test_sac_policy_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
policy = SACPolicy(config=config)
policy.train()
for p in policy.critic_ensemble.parameters():
p.data = torch.ones_like(p.data)
policy.update_target_networks()
for p in policy.critic_target.parameters():
assert torch.allclose(p.data, torch.ones_like(p.data)), (
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
)
@pytest.mark.parametrize("num_critics", [1, 3])
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_critics = num_critics
policy = SACPolicy(config=config)
policy.train()
assert len(policy.critic_ensemble.critics) == num_critics
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
def test_sac_policy_save_and_load(tmp_path):
root = tmp_path / "test_sac_save_and_load"
state_dim = 10
action_dim = 10
batch_size = 2
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
with torch.no_grad():
with seeded_context(12):
# Collect policy values before saving
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
actions = policy.select_action(observation_batch)
with seeded_context(12):
# Collect policy values after loading
loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
# Compare values before and after saving and loading
# They should be the same
assert torch.allclose(cirtic_loss, loaded_cirtic_loss)
assert torch.allclose(actor_loss, loaded_actor_loss)
assert torch.allclose(temperature_loss, loaded_temperature_loss)
assert torch.allclose(actions, loaded_actions)

View File

@@ -1,599 +0,0 @@
import sys
from typing import Callable, Optional
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
from tests.fixtures.constants import DUMMY_REPO_ID
def state_dims() -> list[str]:
return ["observation.image", "observation.state"]
@pytest.fixture
def replay_buffer() -> ReplayBuffer:
return create_empty_replay_buffer()
def clone_state(state: dict) -> dict:
return {k: v.clone() for k, v in state.items()}
def create_empty_replay_buffer(
optimize_memory: bool = False,
use_drq: bool = False,
image_augmentation_function: Optional[Callable] = None,
) -> ReplayBuffer:
buffer_capacity = 10
device = "cpu"
return ReplayBuffer(
buffer_capacity,
device,
state_dims(),
optimize_memory=optimize_memory,
use_drq=use_drq,
image_augmentation_function=image_augmentation_function,
)
def create_random_image() -> torch.Tensor:
return torch.rand(3, 84, 84)
def create_dummy_transition() -> dict:
return {
"observation.image": create_random_image(),
"action": torch.randn(4),
"reward": torch.tensor(1.0),
"observation.state": torch.randn(
10,
),
"done": torch.tensor(False),
"truncated": torch.tensor(False),
"complementary_info": {},
}
def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]:
dummy_state_1 = create_dummy_state()
dummy_action_1 = create_dummy_action()
dummy_state_2 = create_dummy_state()
dummy_action_2 = create_dummy_action()
dummy_state_3 = create_dummy_state()
dummy_action_3 = create_dummy_action()
dummy_state_4 = create_dummy_state()
dummy_action_4 = create_dummy_action()
replay_buffer = create_empty_replay_buffer()
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
root = tmp_path / "test"
return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer)
def create_dummy_state() -> dict:
return {
"observation.image": create_random_image(),
"observation.state": torch.randn(
10,
),
}
def get_tensor_memory_consumption(tensor):
return tensor.nelement() * tensor.element_size()
def get_tensors_memory_consumption(obj, visited_addresses):
total_size = 0
address = id(obj)
if address in visited_addresses:
return 0
visited_addresses.add(address)
if isinstance(obj, torch.Tensor):
return get_tensor_memory_consumption(obj)
elif isinstance(obj, (list, tuple)):
for item in obj:
total_size += get_tensors_memory_consumption(item, visited_addresses)
elif isinstance(obj, dict):
for value in obj.values():
total_size += get_tensors_memory_consumption(value, visited_addresses)
elif hasattr(obj, "__dict__"):
# It's an object, we need to get the size of the attributes
for _, attr in vars(obj).items():
total_size += get_tensors_memory_consumption(attr, visited_addresses)
return total_size
def get_object_memory(obj):
# Track visited addresses to avoid infinite loops
# and cases when two properties point to the same object
visited_addresses = set()
# Get the size of the object in bytes
total_size = sys.getsizeof(obj)
# Get the size of the tensor attributes
total_size += get_tensors_memory_consumption(obj, visited_addresses)
return total_size
def create_dummy_action() -> torch.Tensor:
return torch.randn(4)
def dict_properties() -> list:
return ["state", "next_state"]
@pytest.fixture
def dummy_state() -> dict:
return create_dummy_state()
@pytest.fixture
def next_dummy_state() -> dict:
return create_dummy_state()
@pytest.fixture
def dummy_action() -> torch.Tensor:
return torch.randn(4)
def test_empty_buffer_sample_raises_error(replay_buffer):
assert len(replay_buffer) == 0, "Replay buffer should be empty."
assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10."
with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
replay_buffer.sample(1)
def test_zero_capacity_buffer_raises_error():
with pytest.raises(ValueError, match="Capacity must be greater than 0."):
ReplayBuffer(0, "cpu", ["observation", "next_observation"])
def test_add_transition(replay_buffer, dummy_state, dummy_action):
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding."
assert torch.equal(replay_buffer.actions[0], dummy_action), (
"Action should be equal to the first transition."
)
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
assert not replay_buffer.dones[0], "Done should be False for the first transition."
assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."
for dim in state_dims():
assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
"Observation should be equal to the first transition."
)
assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), (
"Next observation should be equal to the first transition."
)
def test_add_over_capacity():
replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"])
dummy_state_1 = create_dummy_state()
dummy_action_1 = create_dummy_action()
dummy_state_2 = create_dummy_state()
dummy_action_2 = create_dummy_action()
dummy_state_3 = create_dummy_state()
dummy_action_3 = create_dummy_action()
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3."
for dim in state_dims():
assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), (
"Observation should be equal to the first transition."
)
assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), (
"Next observation should be equal to the first transition."
)
assert torch.equal(replay_buffer.actions[0], dummy_action_3), (
"Action should be equal to the last transition."
)
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
assert replay_buffer.dones[0], "Done should be True for the first transition."
assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."
def test_sample_from_empty_buffer(replay_buffer):
with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
replay_buffer.sample(1)
def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action):
replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
got_batch_transition = replay_buffer.sample(1)
expected_batch_transition = BatchTransition(
state=clone_state(dummy_state),
action=dummy_action.clone(),
reward=1.0,
next_state=clone_state(next_dummy_state),
done=False,
truncated=False,
)
for buffer_property in dict_properties():
for k, v in expected_batch_transition[buffer_property].items():
got_state = got_batch_transition[buffer_property][k]
assert got_state.shape[0] == 1, f"{k} should have 1 transition."
assert got_state.device.type == "cpu", f"{k} should be on cpu."
assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition."
for key, _value in expected_batch_transition.items():
if key in dict_properties():
continue
got_value = got_batch_transition[key]
v_tensor = expected_batch_transition[key]
if not isinstance(v_tensor, torch.Tensor):
v_tensor = torch.tensor(v_tensor)
assert got_value.shape[0] == 1, f"{key} should have 1 transition."
assert got_value.device.type == "cpu", f"{key} should be on cpu."
assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition."
def test_sample_with_batch_bigger_than_buffer_size(
replay_buffer, dummy_state, next_dummy_state, dummy_action
):
replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
got_batch_transition = replay_buffer.sample(10)
expected_batch_transition = BatchTransition(
state=dummy_state,
action=dummy_action,
reward=1.0,
next_state=next_dummy_state,
done=False,
truncated=False,
)
for buffer_property in dict_properties():
for k in expected_batch_transition[buffer_property]:
got_state = got_batch_transition[buffer_property][k]
assert got_state.shape[0] == 1, f"{k} should have 1 transition."
for key in expected_batch_transition:
if key in dict_properties():
continue
got_value = got_batch_transition[key]
assert got_value.shape[0] == 1, f"{key} should have 1 transition."
def test_sample_batch(replay_buffer):
dummy_state_1 = create_dummy_state()
dummy_action_1 = create_dummy_action()
dummy_state_2 = create_dummy_state()
dummy_action_2 = create_dummy_action()
dummy_state_3 = create_dummy_state()
dummy_action_3 = create_dummy_action()
dummy_state_4 = create_dummy_state()
dummy_action_4 = create_dummy_action()
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False)
replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True)
replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True)
dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4]
dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4]
got_batch_transition = replay_buffer.sample(3)
for buffer_property in dict_properties():
for k in got_batch_transition[buffer_property]:
got_state = got_batch_transition[buffer_property][k]
assert got_state.shape[0] == 3, f"{k} should have 3 transition."
for got_state_item in got_state:
assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), (
f"{k} should be equal to one of the dummy states."
)
for got_action_item in got_batch_transition["action"]:
assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), (
"Actions should be equal to the dummy actions."
)
for k in got_batch_transition:
if k in dict_properties() or k == "complementary_info":
continue
got_value = got_batch_transition[k]
assert got_value.shape[0] == 3, f"{k} should have 3 transition."
def test_to_lerobot_dataset_with_empty_buffer(replay_buffer):
with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."):
replay_buffer.to_lerobot_dataset("dummy_repo")
def test_to_lerobot_dataset(tmp_path):
ds, buffer = create_dataset_from_replay_buffer(tmp_path)
assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer"
assert ds.fps == 1, "FPS should be 1"
assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id"
for dim in state_dims():
assert dim in ds.features
assert ds.features[dim]["shape"] == buffer.states[dim][0].shape
assert ds.num_episodes == 2
assert ds.num_frames == 4
for j, value in enumerate(ds):
print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j]))
for i in range(len(ds)):
for feature, value in ds[i].items():
if feature == "action":
assert torch.equal(value, buffer.actions[i])
elif feature == "next.reward":
assert torch.equal(value, buffer.rewards[i])
elif feature == "next.done":
assert torch.equal(value, buffer.dones[i])
elif feature == "observation.image":
# Tenssor -> numpy is not precise, so we have some diff there
# TODO: Check and fix it
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
elif feature == "observation.state":
assert torch.equal(value, buffer.states["observation.state"][i])
def test_from_lerobot_dataset(tmp_path):
dummy_state_1 = create_dummy_state()
dummy_action_1 = create_dummy_action()
dummy_state_2 = create_dummy_state()
dummy_action_2 = create_dummy_action()
dummy_state_3 = create_dummy_state()
dummy_action_3 = create_dummy_action()
dummy_state_4 = create_dummy_state()
dummy_action_4 = create_dummy_action()
replay_buffer = create_empty_replay_buffer()
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
root = tmp_path / "test"
ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root)
reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
)
# Check only the part of the buffer that's actually filled with data
assert torch.equal(
reconverted_buffer.actions[: len(replay_buffer)],
replay_buffer.actions[: len(replay_buffer)],
), "Actions from converted buffer should be equal to the original replay buffer."
assert torch.equal(
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
), "Rewards from converted buffer should be equal to the original replay buffer."
assert torch.equal(
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
), "Dones from converted buffer should be equal to the original replay buffer."
# Lerobot DS haven't supported truncateds yet
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
"Truncateds from converted buffer should be equal False"
)
assert torch.equal(
replay_buffer.states["observation.state"][: len(replay_buffer)],
reconverted_buffer.states["observation.state"][: len(replay_buffer)],
), "State should be the same after converting to dataset and return back"
for i in range(4):
torch.testing.assert_close(
replay_buffer.states["observation.image"][i],
reconverted_buffer.states["observation.image"][i],
rtol=0.4,
atol=0.004,
)
# The 2, 3 frames have done flag, so their values will be equal to the current state
for i in range(2):
# In the current implementation we take the next state from the `states` and ignore `next_states`
next_index = (i + 1) % 4
torch.testing.assert_close(
replay_buffer.states["observation.image"][next_index],
reconverted_buffer.next_states["observation.image"][i],
rtol=0.4,
atol=0.004,
)
for i in range(2, 4):
assert torch.equal(
replay_buffer.states["observation.state"][i],
reconverted_buffer.next_states["observation.state"][i],
)
def test_buffer_sample_alignment():
# Initialize buffer
buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu")
# Fill buffer with patterned data
for i in range(100):
signature = float(i) / 100.0
state = {"state_value": torch.tensor([[signature]]).float()}
action = torch.tensor([[2.0 * signature]]).float()
reward = 3.0 * signature
is_end = (i + 1) % 10 == 0
if is_end:
next_state = {"state_value": torch.tensor([[signature]]).float()}
done = True
else:
next_signature = float(i + 1) / 100.0
next_state = {"state_value": torch.tensor([[next_signature]]).float()}
done = False
buffer.add(state, action, reward, next_state, done, False)
# Sample and verify
batch = buffer.sample(50)
for i in range(50):
state_sig = batch["state"]["state_value"][i].item()
action_val = batch["action"][i].item()
reward_val = batch["reward"][i].item()
next_state_sig = batch["next_state"]["state_value"][i].item()
is_done = batch["done"][i].item() > 0.5
# Verify relationships
assert abs(action_val - 2.0 * state_sig) < 1e-4, (
f"Action {action_val} should be 2x state signature {state_sig}"
)
assert abs(reward_val - 3.0 * state_sig) < 1e-4, (
f"Reward {reward_val} should be 3x state signature {state_sig}"
)
if is_done:
assert abs(next_state_sig - state_sig) < 1e-4, (
f"For done states, next_state {next_state_sig} should equal state {state_sig}"
)
else:
# Either it's the next sequential state (+0.01) or same state (for episode boundaries)
valid_next = (
abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
)
assert valid_next, (
f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}"
)
def test_memory_optimization():
dummy_state_1 = create_dummy_state()
dummy_action_1 = create_dummy_action()
dummy_state_2 = create_dummy_state()
dummy_action_2 = create_dummy_action()
dummy_state_3 = create_dummy_state()
dummy_action_3 = create_dummy_action()
dummy_state_4 = create_dummy_state()
dummy_action_4 = create_dummy_action()
replay_buffer = create_empty_replay_buffer()
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
optimized_replay_buffer = create_empty_replay_buffer(True)
optimized_replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
optimized_replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
optimized_replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
optimized_replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, None, True, True)
assert get_object_memory(optimized_replay_buffer) < get_object_memory(replay_buffer), (
"Optimized replay buffer should be smaller than the original replay buffer"
)
def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action):
def dummy_image_augmentation_function(x):
return torch.ones_like(x) * 10
replay_buffer = create_empty_replay_buffer(
use_drq=True, image_augmentation_function=dummy_image_augmentation_function
)
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
sampled_transitions = replay_buffer.sample(1)
assert torch.all(sampled_transitions["state"]["observation.image"] == 10), (
"Image augmentations should be applied"
)
assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), (
"Image augmentations should be applied"
)
def test_check_image_augmentations_with_drq_and_default_image_augmentation_function(
dummy_state, dummy_action
):
replay_buffer = create_empty_replay_buffer(use_drq=True)
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
# Let's check that it doesn't fail and shapes are correct
sampled_transitions = replay_buffer.sample(1)
assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84)
assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84)
def test_random_crop_vectorized_basic():
# Create a batch of 2 images with known patterns
batch_size, channels, height, width = 2, 3, 10, 8
images = torch.zeros((batch_size, channels, height, width))
# Fill with unique values for testing
for b in range(batch_size):
images[b] = b + 1
crop_size = (6, 4) # Smaller than original
cropped = random_crop_vectorized(images, crop_size)
# Check output shape
assert cropped.shape == (batch_size, channels, *crop_size)
# Check that values are preserved (should be either 1s or 2s for respective batches)
assert torch.all(cropped[0] == 1)
assert torch.all(cropped[1] == 2)
def test_random_crop_vectorized_invalid_size():
images = torch.zeros((2, 3, 10, 8))
# Test crop size larger than image
with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
random_crop_vectorized(images, (12, 8))
with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
random_crop_vectorized(images, (10, 10))