Compare commits

...

21 Commits

Author SHA1 Message Date
Francesco Capuano
2b5fce823f fix: camera and motors modules for mock robots 2025-04-23 20:13:38 +02:00
Francesco Capuano
2cce85b5dd fix: action chunks predicted using policy, and timed to observation used 2025-04-19 14:34:36 +02:00
Francesco Capuano
b2d003e6eb fix: client sends timed objects only, and uses lock to read & write robot status 2025-04-19 14:30:29 +02:00
Francesco Capuano
200ba1feb5 add: precommits ignore proto file 2025-04-19 14:18:01 +02:00
Francesco Capuano
0fc9a4341f fix: separate threads for obs streaming, action receiving & execution + action queue reconciliation 2025-04-17 21:09:58 +02:00
Francesco Capuano
d40e74f371 fix: streams inference process using LIFO on obs 2025-04-17 21:09:04 +02:00
Francesco Capuano
40237f5ea3 fix: ruff, get your hands off compiled files 2025-04-17 20:33:54 +02:00
Francesco Capuano
2bcdb57854 fix: bus ids 2025-04-17 20:02:59 +02:00
Francesco Capuano
e9ca1b612d fix: send obs, receives and queues actions chunk, overwrites queue periodically 2025-04-17 19:50:13 +02:00
Francesco Capuano
169babd621 fix: server predicts multiple actions for a given observation, VLA-like 2025-04-17 19:50:02 +02:00
Francesco Capuano
a9031ee1be add: server computes action, robot's daemon constantly reads it 2025-04-17 19:47:20 +02:00
Francesco Capuano
fc107a2c6e add: robot can send observations 2025-04-17 19:47:11 +02:00
Francesco Capuano
84fabbf4af add: grpc service between robot and remote policy server 2025-04-17 19:47:03 +02:00
Steven Palma
5322417c03 fix(examples): removes extra backtick (#948) 2025-04-09 17:44:32 +02:00
Steven Palma
4041f57943 feat(visualization): replace cv2 GUI with Rerun (and solves ffmpeg versioning issues) (#903) 2025-04-09 17:33:01 +02:00
Simon Alibert
2c86fea78a Switch typos pre-commit to mirror (#953) 2025-04-08 12:44:09 +02:00
pre-commit-ci[bot]
437fc29e12 [pre-commit.ci] pre-commit autoupdate (#871)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-04-08 06:58:46 +02:00
Junwu Zhang
aee86b4b18 typo fix: example_1 python script (#631)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-04-07 17:41:10 +02:00
mshukor
1c873df5c0 Support for PI0+FAST (#921)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com>
Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-04-04 11:51:11 +02:00
Steven Palma
145fe4cd17 fix(deps): avoid torchcodec in macos x86_64 (#925) 2025-04-01 15:51:59 +02:00
Mariusz Dubielecki
e004247ed4 docs: add tip for Mac users regarding Terminal permissions for keyboard (#917)
Signed-off-by: cranberrysoft <dubielecki.mariusz@gmail.com>
2025-03-31 09:44:05 +02:00
45 changed files with 2708 additions and 104 deletions

View File

@@ -36,8 +36,8 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/crate-ci/typos
rev: v1.30.2
- repo: https://github.com/adhtruong/mirrors-typos
rev: v1.31.1
hooks:
- id: typos
args: [--force-exclude]
@@ -48,7 +48,7 @@ repos:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.10
rev: v0.11.4
hooks:
- id: ruff
args: [--fix]
@@ -57,12 +57,12 @@ repos:
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.0
rev: v8.24.2
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.4.1
rev: v1.5.2
hooks:
- id: zizmor

View File

@@ -98,14 +98,14 @@ conda create -y -n lerobot python=3.10
conda activate lerobot
```
When using `miniconda`, if you don't have `ffmpeg` in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg
conda install ffmpeg -c conda-forge
```
Install 🤗 LeRobot:
```bash
pip install --no-binary=av -e .
pip install -e .
```
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
@@ -118,7 +118,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
For instance, to install 🤗 LeRobot with aloha and pusht, use:
```bash
pip install --no-binary=av -e ".[aloha, pusht]"
pip install -e ".[aloha, pusht]"
```
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with

View File

@@ -17,12 +17,21 @@
import argparse
import datetime as dt
import os
import time
from pathlib import Path
import cv2
import rerun as rr
# see https://rerun.io/docs/howto/visualization/limit-ram
RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int):
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int):
rr.init("lerobot_capture_camera_feed")
rr.spawn(memory_limit=RERUN_MEMORY_LIMIT)
now = dt.datetime.now()
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
if not capture_dir.exists():
@@ -39,24 +48,21 @@ def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
frame_index = 0
while True:
start_time = time.time()
while time.time() - start_time < duration:
ret, frame = cap.read()
if not ret:
print("Error: Could not read frame.")
break
cv2.imshow("Video Stream", frame)
rr.log("video/stream", rr.Image(frame.numpy()), static=True)
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
frame_index += 1
# Break the loop on 'q' key press
if cv2.waitKey(1) & 0xFF == ord("q"):
break
# Release the capture and destroy all windows
# Release the capture
cap.release()
cv2.destroyAllWindows()
# TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API.
if __name__ == "__main__":
@@ -86,5 +92,11 @@ if __name__ == "__main__":
default=720,
help="Height of the captured images.",
)
parser.add_argument(
"--duration",
type=int,
default=20,
help="Duration in seconds for which the video stream should be captured.",
)
args = parser.parse_args()
display_and_save_video_stream(**vars(args))

View File

@@ -57,9 +57,15 @@ conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
#### 5. Install LeRobot with dependencies for the feetech motors:
#### 5. Install ffmpeg in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
conda install ffmpeg -c conda-forge
```
#### 6. Install LeRobot with dependencies for the feetech motors:
```bash
cd ~/lerobot && pip install -e ".[feetech]"
```
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
@@ -491,6 +497,9 @@ python lerobot/scripts/control_robot.py \
#### a. Teleop with displaying cameras
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=so100 \

View File

@@ -67,9 +67,15 @@ conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
#### 5. Install LeRobot with dependencies for the feetech motors:
#### 5. Install ffmpeg in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
conda install ffmpeg -c conda-forge
```
#### 6. Install LeRobot with dependencies for the feetech motors:
```bash
cd ~/lerobot && pip install -e ".[feetech]"
```
## C. Install LeRobot on laptop
@@ -108,9 +114,15 @@ conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
#### 5. Install LeRobot with dependencies for the feetech motors:
#### 5. Install ffmpeg in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
conda install ffmpeg -c conda-forge
```
#### 6. Install LeRobot with dependencies for the feetech motors:
```bash
cd ~/lerobot && pip install -e ".[feetech]"
```
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
@@ -393,6 +405,10 @@ python lerobot/scripts/control_robot.py \
```
# F. Teleoperate
> [!TIP]
> If you're using a Mac, you might need to give Terminal permission to access your keyboard. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
To teleoperate SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
```bash
python lerobot/scripts/control_robot.py \
@@ -408,6 +424,8 @@ python lerobot/scripts/control_robot.py \
--control.fps=30
```
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. For the `--control.type=remote_robot` you will also need to set `--control.viewer_ip` and `--control.viewer_port`
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
| ---------- | ------------------ | ---------------------- |

View File

@@ -31,9 +31,15 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
5. Install LeRobot with dependencies for the feetech motors:
5. Install ffmpeg in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
conda install ffmpeg -c conda-forge
```
6. Install LeRobot with dependencies for the feetech motors:
```bash
cd ~/lerobot && pip install -e ".[feetech]"
```
## Configure the motors
@@ -212,6 +218,9 @@ python lerobot/scripts/control_robot.py \
**Teleop with displaying cameras**
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=moss \

View File

@@ -119,7 +119,7 @@ print(dataset.features[camera_key]["shape"])
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
camera_key: [-1, -0.5, -0.20, 0],
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
@@ -143,6 +143,6 @@ dataloader = torch.utils.data.DataLoader(
for batch in dataloader:
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32, 5, c)
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
break

View File

@@ -18,7 +18,7 @@ training outputs directory. In the latter case, you might want to run examples/3
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
```bash
pip install --no-binary=av -e ".[pusht]"`
pip install -e ".[pusht]"
```
"""

View File

@@ -33,7 +33,7 @@ First, install the additional dependencies required for robots built with dynami
Using `pip`:
```bash
pip install --no-binary=av -e ".[dynamixel]"
pip install -e ".[dynamixel]"
```
Using `poetry`:
@@ -55,6 +55,9 @@ Finally, connect both arms to your computer via USB. Note that the USB doesn't p
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.
If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=koch \
@@ -828,10 +831,10 @@ It contains:
Troubleshooting:
- On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
- or, install [Homebrew](https://brew.sh) and run `brew install ffmpeg` (it should be compiled with `libsvtav1`),
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform (check the version installed with `ffmpeg -encoders | grep libsvtav1`). If it isn't `ffmpeg 7.X` or lacks `libsvtav1` support, you can explicitly install `ffmpeg 7.X` using: `conda install ffmpeg=7.1.1 -c conda-forge`
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/koch_test) that you can obtain by running:

View File

@@ -43,14 +43,19 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
6. Install LeRobot with stretch dependencies:
6. When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[stretch]"
conda install ffmpeg -c conda-forge
```
7. Install LeRobot with stretch dependencies:
```bash
cd ~/lerobot && pip install -e ".[stretch]"
```
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
7. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
8. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
```bash
stretch_system_check.py
```
@@ -97,6 +102,8 @@ This is equivalent to running `stretch_robot_home.py`
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
Now try out teleoperation (see above documentation to learn about the gamepad controls):
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \

View File

@@ -30,9 +30,14 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
5. When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[dynamixel, intelrealsense]"
conda install ffmpeg -c conda-forge
```
6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
```bash
cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
```
## Teleoperate
@@ -43,6 +48,9 @@ Teleoperation consists in manually operating the leader arms to move the followe
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
By running the following code, you can start your first **SAFE** teleoperation:
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \

View File

@@ -1053,7 +1053,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [

View File

@@ -240,7 +240,7 @@ def load_episodes_stats(local_dir: Path) -> dict:
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
return {ep_idx: stats for ep_idx in episodes}
return dict.fromkeys(episodes, stats)
def load_image_as_numpy(

View File

@@ -481,7 +481,7 @@ def convert_dataset(
# Tasks
if single_task:
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_path:

View File

@@ -13,7 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any
import einops
import gymnasium as gym
import numpy as np
import torch
from torch import Tensor
@@ -86,3 +90,38 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
policy_features[policy_key] = feature
return policy_features
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
first_type = type(env.envs[0]) # Get type of first env
return all(type(e) is first_type for e in env.envs) # Fast type check
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
with warnings.catch_warnings():
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
warnings.warn(
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
UserWarning,
stacklevel=2,
)
if not are_all_envs_same_type(env):
warnings.warn(
"The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
UserWarning,
stacklevel=2,
)
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
"""Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"):
observation["task"] = env.call("task_description")
elif hasattr(env.envs[0], "task"):
observation["task"] = env.call("task")
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
num_envs = observation[list(observation.keys())[0]].shape[0]
observation["task"] = ["" for _ in range(num_envs)]
return observation

View File

@@ -0,0 +1 @@
# Common mocks for robot devices and testing

View File

View File

@@ -0,0 +1,101 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cache
import numpy as np
CAP_V4L2 = 200
CAP_DSHOW = 700
CAP_AVFOUNDATION = 1200
CAP_ANY = -1
CAP_PROP_FPS = 5
CAP_PROP_FRAME_WIDTH = 3
CAP_PROP_FRAME_HEIGHT = 4
COLOR_RGB2BGR = 4
COLOR_BGR2RGB = 4
ROTATE_90_COUNTERCLOCKWISE = 2
ROTATE_90_CLOCKWISE = 0
ROTATE_180 = 1
@cache
def _generate_image(width: int, height: int):
return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
def cvtColor(color_image, color_conversion): # noqa: N802
if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]:
return color_image[:, :, [2, 1, 0]]
else:
raise NotImplementedError(color_conversion)
def rotate(color_image, rotation):
if rotation is None:
return color_image
elif rotation == ROTATE_90_CLOCKWISE:
return np.rot90(color_image, k=1)
elif rotation == ROTATE_180:
return np.rot90(color_image, k=2)
elif rotation == ROTATE_90_COUNTERCLOCKWISE:
return np.rot90(color_image, k=3)
else:
raise NotImplementedError(rotation)
class VideoCapture:
def __init__(self, *args, **kwargs):
self._mock_dict = {
CAP_PROP_FPS: 30,
CAP_PROP_FRAME_WIDTH: 640,
CAP_PROP_FRAME_HEIGHT: 480,
}
self._is_opened = True
def isOpened(self): # noqa: N802
return self._is_opened
def set(self, propId: int, value: float) -> bool: # noqa: N803
if not self._is_opened:
raise RuntimeError("Camera is not opened")
self._mock_dict[propId] = value
return True
def get(self, propId: int) -> float: # noqa: N803
if not self._is_opened:
raise RuntimeError("Camera is not opened")
value = self._mock_dict[propId]
if value == 0:
if propId == CAP_PROP_FRAME_HEIGHT:
value = 480
elif propId == CAP_PROP_FRAME_WIDTH:
value = 640
return value
def read(self):
if not self._is_opened:
raise RuntimeError("Camera is not opened")
h = self.get(CAP_PROP_FRAME_HEIGHT)
w = self.get(CAP_PROP_FRAME_WIDTH)
ret = True
return ret, _generate_image(width=w, height=h)
def release(self):
self._is_opened = False
def __del__(self):
if self._is_opened:
self.release()

View File

@@ -0,0 +1,148 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import numpy as np
class stream(enum.Enum): # noqa: N801
color = 0
depth = 1
class format(enum.Enum): # noqa: N801
rgb8 = 0
z16 = 1
class config: # noqa: N801
def enable_device(self, device_id: str):
self.device_enabled = device_id
def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
self.stream_type = stream_type
# Overwrite default values when possible
self.width = 848 if width is None else width
self.height = 480 if height is None else height
self.color_format = format.rgb8 if color_format is None else color_format
self.fps = 30 if fps is None else fps
class RSColorProfile:
def __init__(self, config):
self.config = config
def fps(self):
return self.config.fps
def width(self):
return self.config.width
def height(self):
return self.config.height
class RSColorStream:
def __init__(self, config):
self.config = config
def as_video_stream_profile(self):
return RSColorProfile(self.config)
class RSProfile:
def __init__(self, config):
self.config = config
def get_stream(self, color_format):
del color_format # unused
return RSColorStream(self.config)
class pipeline: # noqa: N801
def __init__(self):
self.started = False
self.config = None
def start(self, config):
self.started = True
self.config = config
return RSProfile(self.config)
def stop(self):
if not self.started:
raise RuntimeError("You need to start the camera before stop.")
self.started = False
self.config = None
def wait_for_frames(self, timeout_ms=50000):
del timeout_ms # unused
return RSFrames(self.config)
class RSFrames:
def __init__(self, config):
self.config = config
def get_color_frame(self):
return RSColorFrame(self.config)
def get_depth_frame(self):
return RSDepthFrame(self.config)
class RSColorFrame:
def __init__(self, config):
self.config = config
def get_data(self):
data = np.ones((self.config.height, self.config.width, 3), dtype=np.uint8)
# Create a difference between rgb and bgr
data[:, :, 0] = 2
return data
class RSDepthFrame:
def __init__(self, config):
self.config = config
def get_data(self):
return np.ones((self.config.height, self.config.width), dtype=np.uint16)
class RSDevice:
def __init__(self):
pass
def get_info(self, camera_info) -> str:
del camera_info # unused
# return fake serial number
return "123456789"
class context: # noqa: N801
def __init__(self):
pass
def query_devices(self):
return [RSDevice()]
class camera_info: # noqa: N801
# fake name
name = "Intel RealSense D435I"
def __init__(self, serial_number):
del serial_number
pass

View File

@@ -0,0 +1 @@
# Mocks for motor modules

View File

@@ -0,0 +1,107 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
Warning: These mocked versions are minimalist. They do not exactly mock every behaviors
from the original classes and functions (e.g. return types might be None instead of boolean).
"""
# from dynamixel_sdk import COMM_SUCCESS
DEFAULT_BAUDRATE = 9_600
COMM_SUCCESS = 0 # tx or rx packet communication success
def convert_to_bytes(value, bytes):
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
# `convert_bytes_to_value`
del bytes # unused
return value
def get_default_motor_values(motor_index):
return {
# Key (int) are from X_SERIES_CONTROL_TABLE
7: motor_index, # ID
8: DEFAULT_BAUDRATE, # Baud_rate
10: 0, # Drive_Mode
64: 0, # Torque_Enable
# Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144
# For other joints, 2560 will be autocorrected to be in calibration range
132: 2560, # Present_Position
}
class PortHandler:
def __init__(self, port):
self.port = port
# factory default baudrate
self.baudrate = DEFAULT_BAUDRATE
def openPort(self): # noqa: N802
return True
def closePort(self): # noqa: N802
pass
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
del timeout_ms # unused
def getBaudRate(self): # noqa: N802
return self.baudrate
def setBaudRate(self, baudrate): # noqa: N802
self.baudrate = baudrate
class PacketHandler:
def __init__(self, protocol_version):
del protocol_version # unused
# Use packet_handler.data to communicate across Read and Write
self.data = {}
class GroupSyncRead:
def __init__(self, port_handler, packet_handler, address, bytes):
self.packet_handler = packet_handler
def addParam(self, motor_index): # noqa: N802
# Initialize motor default values
if motor_index not in self.packet_handler.data:
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS
def getData(self, index, address, bytes): # noqa: N802
return self.packet_handler.data[index][address]
class GroupSyncWrite:
def __init__(self, port_handler, packet_handler, address, bytes):
self.packet_handler = packet_handler
self.address = address
def addParam(self, index, data): # noqa: N802
# Initialize motor default values
if index not in self.packet_handler.data:
self.packet_handler.data[index] = get_default_motor_values(index)
self.changeParam(index, data)
def txPacket(self): # noqa: N802
return COMM_SUCCESS
def changeParam(self, index, data): # noqa: N802
self.packet_handler.data[index][self.address] = data

View File

@@ -0,0 +1,125 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
Warning: These mocked versions are minimalist. They do not exactly mock every behaviors
from the original classes and functions (e.g. return types might be None instead of boolean).
"""
# from dynamixel_sdk import COMM_SUCCESS
DEFAULT_BAUDRATE = 1_000_000
COMM_SUCCESS = 0 # tx or rx packet communication success
def convert_to_bytes(value, bytes):
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
# `convert_bytes_to_value`
del bytes # unused
return value
def get_default_motor_values(motor_index):
return {
# Key (int) are from SCS_SERIES_CONTROL_TABLE
5: motor_index, # ID
6: DEFAULT_BAUDRATE, # Baud_rate
10: 0, # Drive_Mode
21: 32, # P_Coefficient
22: 32, # D_Coefficient
23: 0, # I_Coefficient
40: 0, # Torque_Enable
41: 254, # Acceleration
31: -2047, # Offset
33: 0, # Mode
55: 1, # Lock
# Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144
# For other joints, 2560 will be autocorrected to be in calibration range
56: 2560, # Present_Position
58: 0, # Present_Speed
69: 0, # Present_Current
85: 150, # Maximum_Acceleration
}
class PortHandler:
def __init__(self, port):
self.port = port
# factory default baudrate
self.baudrate = DEFAULT_BAUDRATE
self.ser = SerialMock()
def openPort(self): # noqa: N802
return True
def closePort(self): # noqa: N802
pass
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
del timeout_ms # unused
def getBaudRate(self): # noqa: N802
return self.baudrate
def setBaudRate(self, baudrate): # noqa: N802
self.baudrate = baudrate
class PacketHandler:
def __init__(self, protocol_version):
del protocol_version # unused
# Use packet_handler.data to communicate across Read and Write
self.data = {}
class GroupSyncRead:
def __init__(self, port_handler, packet_handler, address, bytes):
self.packet_handler = packet_handler
def addParam(self, motor_index): # noqa: N802
# Initialize motor default values
if motor_index not in self.packet_handler.data:
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS
def getData(self, index, address, bytes): # noqa: N802
return self.packet_handler.data[index][address]
class GroupSyncWrite:
def __init__(self, port_handler, packet_handler, address, bytes):
self.packet_handler = packet_handler
self.address = address
def addParam(self, index, data): # noqa: N802
if index not in self.packet_handler.data:
self.packet_handler.data[index] = get_default_motor_values(index)
self.changeParam(index, data)
def txPacket(self): # noqa: N802
return COMM_SUCCESS
def changeParam(self, index, data): # noqa: N802
self.packet_handler.data[index][self.address] = data
class SerialMock:
def reset_output_buffer(self):
pass
def reset_input_buffer(self):
pass

View File

@@ -25,6 +25,7 @@ from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -54,6 +55,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "pi0fast":
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -69,6 +74,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")

View File

@@ -24,7 +24,7 @@ Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Install pi0 extra dependencies:
```bash
pip install --no-binary=av -e ".[pi0]"
pip install -e ".[pi0]"
```
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):

View File

@@ -0,0 +1,136 @@
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@PreTrainedConfig.register_subclass("pi0fast")
@dataclass
class PI0FASTConfig(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 10
n_action_steps: int = 5
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32 # 32
max_action_dim: int = 32 # 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (224, 224)
interpolate_like_pi: bool = False
# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Projector
proj_width: int = 1024
# Decoding
max_decoding_steps: int = 256
fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
max_input_seq_len: int = 256 # 512
# Utils
use_cache: bool = True
# Frozen parameters
freeze_vision_encoder: bool = True
freeze_lm_head: bool = True
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
checkpoint_path: str = None
padding_side: str = "right"
precision: str = "bfloat16"
grad_clip_norm: float = 1
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
relaxed_action_decoding: bool = True
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
def validate_features(self) -> None:
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,973 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
[Paper](https://arxiv.org/abs/2501.09747)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
```bash
python lerobot/scripts/train.py \
--policy.path=lerobot/pi0fast_base \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of training the pi0+FAST neural network with from scratch:
```bash
python lerobot/scripts/train.py \
--policy.type=pi0fast \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of using the pi0 pretrained model outside LeRobot training framework:
```python
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
```
"""
from collections import deque
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from scipy.fft import idct
from torch import Tensor, nn
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
PRECISION = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with pi0 which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class PI0FASTPolicy(PreTrainedPolicy):
"""Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
config_class = PI0FASTConfig
name = "pi0fast"
def __init__(
self,
config: PI0FASTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def get_optim_params(self) -> dict:
return self.parameters()
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
actions = self.model.generate_actions(batch)
actions = actions[:, : self.config.n_action_steps]
original_action_dim = self.config.action_feature.shape[
0
] # self.config.max_action_dim # self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss_dict = self.model.forward(batch)
return loss_dict["loss"], loss_dict
def block_causal_update_causal_mask(
attention_mask,
token_type_ids=None,
past_key_values=None,
cache_position=None,
input_tensor=None,
attn_implementation: str = "eager",
dtype: torch.dtype = "float32",
):
"""
Update the causal mask during training and generation. It can be customized to different attention masks.
"""
if attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(dtype).min
if input_tensor is None:
input_tensor = attention_mask
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
if using_static_cache or isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else cache_position[0] + sequence_length + 1
)
# Handle precomputed attention masks
if attention_mask is not None and attention_mask.dim() == 4:
return attention_mask
# Causal mask initialization
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
# Standard causal masking (triu ensures tokens can only attend to past)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
# Apply block causal mask
if token_type_ids is not None:
token_type_ids = token_type_ids.to(causal_mask.device).bool()
cumsum = torch.cumsum(token_type_ids, dim=1)
block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
# Combine causal_mask with block-wise attention mask
causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
causal_mask = causal_mask[:, None, :, :]
else:
# Apply past cache position constraint
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
else:
# Apply past cache position constraint
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits
mask_length = attention_mask.shape[-1]
# Apply padding mask
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def prepare_inputs_for_generation(
# self,
input_ids,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
pixel_values=None,
attention_mask=None,
token_type_ids=None,
use_cache=True,
num_logits_to_keep=None,
labels=None,
self=None,
**kwargs,
):
# create block causal attention
if cache_position[0] > 0 and input_ids.shape[1] > 0:
input_tensor = input_ids[:, -1:]
new_positions = (
torch.ones(
(position_ids.shape[0], input_ids.shape[1]),
dtype=position_ids.dtype,
device=position_ids.device,
).cumsum(-1)
+ position_ids[:, -1:]
)
position_ids = torch.cat([position_ids, new_positions], dim=-1)
else:
input_tensor = inputs_embeds
attention_mask = block_causal_update_causal_mask(
attention_mask=attention_mask,
past_key_values=past_key_values,
cache_position=cache_position,
input_tensor=input_tensor,
token_type_ids=token_type_ids,
dtype=self.dtype,
attn_implementation=self.config.text_config._attn_implementation,
)
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
use_cache=use_cache,
num_logits_to_keep=num_logits_to_keep,
token_type_ids=token_type_ids,
**kwargs,
)
# Position_ids in Paligemma are 1-indexed
if model_inputs.get("position_ids") is not None:
model_inputs["position_ids"] += 1
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
class PI0FAST(nn.Module):
def __init__(self, config: PI0FASTConfig):
super().__init__()
self.config = config
# TODO: move tokenizers in Policy
fast_tokenizer_path = "physical-intelligence/fast"
pi0_paligemma_path = "google/paligemma-3b-pt-224"
self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
self.fast_skip_tokens = self.config.fast_skip_tokens
self.max_input_seq_len = self.config.max_input_seq_len
self.action_horizon = self.config.chunk_size
self.action_dim = self.config.action_feature.shape[
0
] # self.config.max_action_dim # self.config.action_feature.shape[0]
precision = config.precision
torch_precision = PRECISION.get(precision, torch.float32)
self.pad_token_id = (
self.paligemma_tokenizer.pad_token_id
if hasattr(self.paligemma_tokenizer, "pad_token_id")
else self.paligemma_tokenizer.eos_token_id
)
paligemma_config = CONFIG_MAPPING["paligemma"](
transformers_version="4.48.1",
_vocab_size=257152,
bos_token_id=2,
eos_token_id=1,
hidden_size=2048,
image_token_index=257152,
model_type="paligemma",
pad_token_id=0,
projection_dim=2048,
text_config={
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2048,
"intermediate_size": 16384,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_image_tokens": 256,
"num_key_value_heads": 1,
"torch_dtype": precision,
"vocab_size": 257152,
"_attn_implementation": "eager",
},
vision_config={
"hidden_size": 1152,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_image_tokens": 256,
"patch_size": 14,
"projection_dim": 2048,
"projector_hidden_act": "gelu_pytorch_tanh",
"torch_dtype": precision,
"vision_use_head": False,
},
)
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
self.pi0_paligemma.prepare_inputs_for_generation = partial(
prepare_inputs_for_generation, self=self.pi0_paligemma
)
# change important stuff in bf16
params_to_change_dtype = [
"language_model",
"vision_tower",
"multi_modal",
]
for name, param in self.pi0_paligemma.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch_precision)
self.set_requires_grad()
self.image_keys = self.config.image_features.keys()
self.ignore_index = self.pi0_paligemma.config.ignore_index
self.padding_side = self.config.padding_side
def set_requires_grad(self):
if self.config.freeze_vision_encoder:
self.pi0_paligemma.vision_tower.eval()
for params in self.pi0_paligemma.vision_tower.parameters():
params.requires_grad = False
# To avoid unused params issue with distributed training
if self.config.freeze_lm_head:
for name, params in self.pi0_paligemma.named_parameters():
if "embed_tokens" in name: # lm heads and embedding layer are tied
params.requires_grad = False
def embed_tokens(self, tokens: torch.Tensor):
return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
def prepare_inputs_for_generation(self, *args, **kwargs):
return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
def prepare_images(self, batch):
"""Preprocess LeRobot batch into Pi0 inputs"""
images = []
img_masks = []
present_img_keys = [key for key in self.image_keys if key in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
num_empty_cameras = 0
for key in self.image_keys:
if key in present_img_keys:
img = batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(
img,
*self.config.resize_imgs_with_padding,
pad_value=0,
interpolate_like_pi=self.config.interpolate_like_pi,
)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
else:
if num_empty_cameras >= self.config.empty_cameras:
continue
img = torch.ones_like(img) * -1
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
num_empty_cameras += 1
images.append(img)
img_masks.append(mask)
return images, img_masks
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
mins = actions.amin(dim=(1, 2), keepdim=True) # [0]
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
return out
def fast_tokenizer_wrapper(self, actions_norm):
"""
A wrapper for self.fast_tokenizer that ensures batch processing,
conversion to PyTorch tensors, and returns a dictionary without padding.
"""
batch_tokens = self.fast_tokenizer(actions_norm)
fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
return fast_out
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
# Compute cumulative sum mask
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
# Suffix block (everything after prefix_len)
suffix_mask = cumsum_mask > prefix_len
token_type_ids = suffix_mask
return token_type_ids
def create_input_tokens(self, state, lang_text, actions=None):
bsize = state.shape[0]
device = state.device
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
discretized = torch.bucketize(state, bins) - 1
discretized = discretized[:, :32]
prefix_texts = []
state_text = []
for txt, disc in zip(lang_text, discretized, strict=False):
cleaned = txt.lower().strip().replace("_", " ")
state_str = " ".join(str(val.item()) for val in disc)
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
state_text.append(f"State: {state_str};\n")
prefix_out = self.paligemma_tokenizer(
prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
)
prefix_ids = prefix_out["input_ids"].to(device)
prefix_mask = prefix_out["attention_mask"].to(device)
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
if actions is not None:
actions_norm = self.normalize_actions(actions)
actions_pad = F.pad(
actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
)[:, :, : self.config.max_action_dim]
fast_out = self.fast_tokenizer_wrapper(
actions_pad.cpu(),
)
act_ids = fast_out["input_ids"]
act_mask = fast_out["attention_mask"].to(device)
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
# Replace action with 0 to pad tokens
act_ids = torch.where(
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
self.pad_token_id,
act_ids,
)
eos_token = torch.tensor(
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
).expand(bsize, -1)
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
act_mask = act_mask.to(device)
else:
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
# Use tokenizer pad function
padded_output = self.paligemma_tokenizer.pad(
batch_inputs, padding="longest", max_length=180, return_tensors="pt"
)
padded_mask = padded_output["attention_mask"]
# define tensor of padding lengths
att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
padded_output["padded_mask"] = padded_output.pop("attention_mask")
padded_output["attention_mask"] = att_mask
# loss is computed not on prefix, and not on padding
padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
padded_output["token_type_ids"] = token_type_ids
return padded_output
def shift_padding_side(
self,
tokens: torch.Tensor,
ar_mask: torch.Tensor,
padding_mask: torch.Tensor,
loss_mask: torch.Tensor,
targets: torch.Tensor,
token_type_ids: torch.Tensor,
padding_side: str = "right",
) -> tuple[torch.Tensor]:
if padding_side not in ["right", "left"]:
return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
new_tokens = torch.empty_like(tokens)
new_ar_masks = torch.empty_like(ar_mask)
new_padding_mask = torch.empty_like(padding_mask)
new_loss_mask = torch.empty_like(loss_mask)
new_targets = torch.empty_like(targets)
new_token_type_ids = torch.empty_like(token_type_ids)
batch_size = tokens.shape[0]
for i in range(batch_size):
padding_indices = torch.where(padding_mask[i] == 0)[0]
non_padding_indices = torch.where(padding_mask[i] == 1)[0]
if padding_side == "left":
new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
else:
new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
new_tokens[i] = tokens[i].index_select(0, new_indices)
new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
new_targets[i] = targets[i].index_select(0, new_indices)
new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
def forward(self, batch: dict[str, Tensor]):
device = batch[OBS_ROBOT].device
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(
state=batch[OBS_ROBOT],
lang_text=batch["task"],
actions=batch[ACTION],
)
embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,
padded_outs["input_ids"],
padded_outs["padded_mask"],
padded_outs["attention_mask"],
padded_outs["loss_mask"],
padded_outs["token_type_ids"],
padding_side=self.padding_side,
)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
token_type_ids = token_type_ids.to(dtype=torch.int64)
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
pad_masks = block_causal_update_causal_mask(
attention_mask=pad_masks,
past_key_values=None,
cache_position=cache_position,
input_tensor=embs,
token_type_ids=token_type_ids,
dtype=self.pi0_paligemma.dtype,
attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
)
outputs = self.pi0_paligemma.forward(
input_ids=None,
token_type_ids=None,
attention_mask=pad_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=embs,
use_cache=False,
labels=None,
)
logits = outputs.logits
loss_fct = nn.CrossEntropyLoss(reduction="none")
# Shift left for next-step prediction
logits = logits[:, :-1, :]
targets = targets[:, 1:].to(device) # Shift targets
loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape
# Compute per-token loss
token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
# Apply loss mask
token_loss = token_loss * loss_mask.reshape(-1)
# Compute final loss
loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
# Return loss dictionary
loss_dict = {"ce_loss": loss.item(), "loss": loss}
return loss_dict
def decode_actions_with_fast(
self,
tokens: list[list[int]],
*,
time_horizon: int | None = None,
action_dim: int | None = None,
relaxed_decoding: bool = True,
) -> np.array:
"""
Adapt original decoding in FAST to always return actions instead of zeros.
"""
self.time_horizon = (
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
)
self.action_dim = (
action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
)
# Cache the time horizon and action dimension for the next call
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert self.time_horizon is not None and self.action_dim is not None, (
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
)
decoded_actions = []
for token in tokens:
try:
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
if relaxed_decoding:
# Expected sequence length
expected_seq_len = self.time_horizon * self.action_dim
diff = expected_seq_len - decoded_dct_coeff.shape[0]
# Apply truncation if too long
if diff < 0:
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
# Apply padding if too short
elif diff > 0:
decoded_dct_coeff = np.pad(
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
)
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert decoded_dct_coeff.shape == (
self.time_horizon,
self.action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
"""
Extracts actions from predicted output tokens using the FAST model.
Args:
tokens (torch.Tensor): The input tensor of tokenized outputs.
action_horizon (int): The number of timesteps for actions.
action_dim (int): The dimensionality of each action.
Returns:
torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
"""
# Decode predicted output tokens
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
cleaned_tokens = [
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
for tokens_sequence in decoded_tokens
]
raw_action_tokens = [
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
for sample_tokens in cleaned_tokens
] # something like this should be robust #looks good
action_tokens = [
self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
]
# returns the tensor of decoded actions per sample in a list
decoded_actions = [
torch.tensor(
self.decode_actions_with_fast(
tok.tolist(),
time_horizon=action_horizon,
action_dim=action_dim,
relaxed_decoding=self.config.relaxed_action_decoding,
),
device=tokens.device,
).squeeze(0)
for tok in action_tokens
]
return torch.stack(
decoded_actions,
dim=0,
)
def generate_actions(self, batch: dict[str, Tensor]):
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,
padded_outs["input_ids"],
padded_outs["padded_mask"],
padded_outs["attention_mask"],
padded_outs["loss_mask"],
padded_outs["token_type_ids"],
padding_side="left",
)
token_type_ids = token_type_ids.to(dtype=torch.int64)
prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
output_tokens = self.pi0_paligemma.generate(
input_ids=None,
attention_mask=pad_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=embs,
use_cache=self.config.use_cache,
max_new_tokens=self.config.max_decoding_steps,
do_sample=False,
num_beams=1,
token_type_ids=token_type_ids,
)
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
return actions
def embed_image(self, image: torch.Tensor):
return self.pi0_paligemma.get_image_features(image)
def embed_inputs(
self,
images,
img_masks,
tokens,
pad_mask,
ar_mask,
loss_mask,
token_type_ids,
padding_side: str = "right",
):
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
# images are a list of same size
# vectorizing everything!
device = images[0].device
image_embedding_dim = images[0].shape[-1] # TODO should be from self.config
all_images = torch.stack(images, dim=1).to(device)
b, n, c, h, w = all_images.shape
all_images = all_images.view(b * n, c, h, w)
embedded = self.embed_image(all_images).to(device)
b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions
m = b_n // b # Compute the number of images per sample dynamically
# Reshape dynamically
embedded = embedded.view(b, m, p, image_embedding_dim)
tokens_embs = self.embed_tokens(tokens.to(device))
img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
num_img_emb = embedded.shape[2]
img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
image_target_tokens = (
torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
).reshape(b, -1)
image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D)
embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
# Shift pad tokens to the left (.generate()) or right (.train())
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
)
targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
if interpolate_like_pi:
img = (img * 255.0).to(dtype=torch.uint8)
img = img.permute(0, 2, 3, 1)
original_device = img.device
img = img.to(device="cpu").numpy()
imgs = []
for sub_img in img:
sub_img = Image.fromarray(sub_img)
resized_img = sub_img.resize((resized_width, resized_height), resample=2)
resized_img = torch.from_numpy(np.array(resized_img))
imgs.append(resized_img)
img = torch.stack(imgs, dim=0)
img = img.permute(0, 3, 1, 2)
resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
else:
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img

View File

@@ -122,7 +122,7 @@ class TDMPCPolicy(PreTrainedPolicy):
# When the action queue is depleted, populate it again by querying the policy.
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
# Remove the time dimensions as it is not handled yet.
for key in batch:

View File

@@ -48,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
connected to the computer.
"""
if mock:
import tests.cameras.mock_pyrealsense2 as rs
import lerobot.common.mocks.cameras.mock_pyrealsense2 as rs
else:
import pyrealsense2 as rs
@@ -100,7 +100,7 @@ def save_images_from_cameras(
serial_numbers = [cam["serial_number"] for cam in camera_infos]
if mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2
@@ -253,7 +253,7 @@ class IntelRealSenseCamera:
self.logs = {}
if self.mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2
@@ -287,7 +287,7 @@ class IntelRealSenseCamera:
)
if self.mock:
import tests.cameras.mock_pyrealsense2 as rs
import lerobot.common.mocks.cameras.mock_pyrealsense2 as rs
else:
import pyrealsense2 as rs
@@ -375,7 +375,7 @@ class IntelRealSenseCamera:
)
if self.mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2

View File

@@ -80,7 +80,7 @@ def _find_cameras(
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
) -> list[int | str]:
if mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2
@@ -269,7 +269,7 @@ class OpenCVCamera:
self.logs = {}
if self.mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2
@@ -286,7 +286,7 @@ class OpenCVCamera:
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
if self.mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2
@@ -398,7 +398,7 @@ class OpenCVCamera:
# so we convert the image color from BGR to RGB.
if requested_color_mode == "rgb":
if self.mock:
import tests.cameras.mock_cv2 as cv2
import lerobot.common.mocks.cameras.mock_cv2 as cv2
else:
import cv2

View File

@@ -41,7 +41,7 @@ class TeleoperateControlConfig(ControlConfig):
fps: int | None = None
teleop_time_s: float | None = None
# Display all cameras on screen
display_cameras: bool = True
display_data: bool = False
@ControlConfig.register_subclass("record")
@@ -82,7 +82,7 @@ class RecordControlConfig(ControlConfig):
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Display all cameras on screen
display_cameras: bool = True
display_data: bool = False
# Use vocal synthesis to read events.
play_sounds: bool = True
# Resume recording on an existing dataset.
@@ -116,6 +116,11 @@ class ReplayControlConfig(ControlConfig):
@dataclass
class RemoteRobotConfig(ControlConfig):
log_interval: int = 100
# Display all cameras on screen
display_data: bool = False
# Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp)
viewer_ip: str | None = None
viewer_port: str | None = None
@dataclass

View File

@@ -24,7 +24,7 @@ from contextlib import nullcontext
from copy import copy
from functools import cache
import cv2
import rerun as rr
import torch
from deepdiff import DeepDiff
from termcolor import colored
@@ -174,13 +174,13 @@ def warmup_record(
events,
enable_teleoperation,
warmup_time_s,
display_cameras,
display_data,
fps,
):
control_loop(
robot=robot,
control_time_s=warmup_time_s,
display_cameras=display_cameras,
display_data=display_data,
events=events,
fps=fps,
teleoperate=enable_teleoperation,
@@ -192,7 +192,7 @@ def record_episode(
dataset,
events,
episode_time_s,
display_cameras,
display_data,
policy,
fps,
single_task,
@@ -200,7 +200,7 @@ def record_episode(
control_loop(
robot=robot,
control_time_s=episode_time_s,
display_cameras=display_cameras,
display_data=display_data,
dataset=dataset,
events=events,
policy=policy,
@@ -215,7 +215,7 @@ def control_loop(
robot,
control_time_s=None,
teleoperate=False,
display_cameras=False,
display_data=False,
dataset: LeRobotDataset | None = None,
events=None,
policy: PreTrainedPolicy = None,
@@ -264,11 +264,15 @@ def control_loop(
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
if display_cameras and not is_headless():
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
for k, v in action.items():
for i, vv in enumerate(v):
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
rr.log(key, rr.Image(observation[key].numpy()), static=True)
if fps is not None:
dt_s = time.perf_counter() - start_loop_t
@@ -297,15 +301,11 @@ def reset_environment(robot, events, reset_time_s, fps):
)
def stop_recording(robot, listener, display_cameras):
def stop_recording(robot, listener, display_data):
robot.disconnect()
if not is_headless():
if listener is not None:
listener.stop()
if display_cameras:
cv2.destroyAllWindows()
if not is_headless() and listener is not None:
listener.stop()
def sanity_check_dataset_name(repo_id, policy_cfg):

View File

@@ -332,7 +332,7 @@ class DynamixelMotorsBus:
)
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -356,7 +356,7 @@ class DynamixelMotorsBus:
def reconnect(self):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -646,7 +646,7 @@ class DynamixelMotorsBus:
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -691,7 +691,7 @@ class DynamixelMotorsBus:
start_time = time.perf_counter()
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -757,7 +757,7 @@ class DynamixelMotorsBus:
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -793,7 +793,7 @@ class DynamixelMotorsBus:
start_time = time.perf_counter()
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl

View File

@@ -313,7 +313,7 @@ class FeetechMotorsBus:
)
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -337,7 +337,7 @@ class FeetechMotorsBus:
def reconnect(self):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -664,7 +664,7 @@ class FeetechMotorsBus:
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -702,7 +702,7 @@ class FeetechMotorsBus:
def read(self, data_name, motor_names: str | list[str] | None = None):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -782,7 +782,7 @@ class FeetechMotorsBus:
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -818,7 +818,7 @@ class FeetechMotorsBus:
start_time = time.perf_counter()
if self.mock:
import tests.motors.mock_scservo_sdk as scs
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs

View File

@@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091",
port="/dev/tty.usbmodem58760429101",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
port="/dev/tty.usbmodem58760435821",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],

View File

@@ -94,7 +94,7 @@ class MetricsTracker:
metrics: dict[str, AverageMeter],
initial_step: int = 0,
):
self.__dict__.update({k: None for k in self.__keys__})
self.__dict__.update(dict.fromkeys(self.__keys__))
self._batch_size = batch_size
self._num_frames = num_frames
self._avg_samples_per_ep = num_frames / num_episodes

View File

@@ -135,15 +135,19 @@ python lerobot/scripts/control_robot.py \
"""
import logging
import os
import time
from dataclasses import asdict
from pprint import pformat
import rerun as rr
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
ControlConfig,
ControlPipelineConfig,
RecordControlConfig,
RemoteRobotConfig,
@@ -153,6 +157,7 @@ from lerobot.common.robot_devices.control_configs import (
from lerobot.common.robot_devices.control_utils import (
control_loop,
init_keyboard_listener,
is_headless,
log_control_info,
record_episode,
reset_environment,
@@ -232,7 +237,7 @@ def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
control_time_s=cfg.teleop_time_s,
fps=cfg.fps,
teleoperate=True,
display_cameras=cfg.display_cameras,
display_data=cfg.display_data,
)
@@ -280,7 +285,7 @@ def record(
# 3. place the cameras windows on screen
enable_teleoperation = policy is None
log_say("Warmup record", cfg.play_sounds)
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps)
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
@@ -296,7 +301,7 @@ def record(
dataset=dataset,
events=events,
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
display_data=cfg.display_data,
policy=policy,
fps=cfg.fps,
single_task=cfg.single_task,
@@ -326,7 +331,7 @@ def record(
break
log_say("Stop recording", cfg.play_sounds, blocking=True)
stop_recording(robot, listener, cfg.display_cameras)
stop_recording(robot, listener, cfg.display_data)
if cfg.push_to_hub:
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
@@ -363,6 +368,40 @@ def replay(
log_control_info(robot, dt_s, fps=cfg.fps)
def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None:
"""Initializes the Rerun SDK for visualizing the control loop.
Args:
control_config: Configuration determining data display and robot type.
session_name: Rerun session name. Defaults to "lerobot_control_loop".
Raises:
ValueError: If viewer IP is missing for non-remote configurations with display enabled.
"""
if (control_config.display_data and not is_headless()) or (
control_config.display_data and isinstance(control_config, RemoteRobotConfig)
):
# Configure Rerun flush batch size default to 8KB if not set
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
# Initialize Rerun based on configuration
rr.init(session_name)
if isinstance(control_config, RemoteRobotConfig):
viewer_ip = control_config.viewer_ip
viewer_port = control_config.viewer_port
if not viewer_ip or not viewer_port:
raise ValueError(
"Viewer IP & Port are required for remote config. Set via config file/CLI or disable control_config.display_data."
)
logging.info(f"Connecting to viewer at {viewer_ip}:{viewer_port}")
rr.connect_tcp(f"{viewer_ip}:{viewer_port}")
else:
# Get memory limit for rerun viewer parameters
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
rr.spawn(memory_limit=memory_limit)
@parser.wrap()
def control_robot(cfg: ControlPipelineConfig):
init_logging()
@@ -370,17 +409,22 @@ def control_robot(cfg: ControlPipelineConfig):
robot = make_robot_from_config(cfg.robot)
# TODO(Steven): Blueprint for fixed window size
if isinstance(cfg.control, CalibrateControlConfig):
calibrate(robot, cfg.control)
elif isinstance(cfg.control, TeleoperateControlConfig):
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop")
teleoperate(robot, cfg.control)
elif isinstance(cfg.control, RecordControlConfig):
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record")
record(robot, cfg.control)
elif isinstance(cfg.control, ReplayControlConfig):
replay(robot, cfg.control)
elif isinstance(cfg.control, RemoteRobotConfig):
from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote")
run_lekiwi(cfg.robot)
if robot.is_connected:

View File

@@ -66,7 +66,7 @@ from torch import Tensor, nn
from tqdm import trange
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
@@ -124,7 +124,6 @@ def rollout(
# Reset the policy and environments.
policy.reset()
observation, info = env.reset(seed=seeds)
if render_callback is not None:
render_callback(env)
@@ -145,6 +144,7 @@ def rollout(
disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs
leave=False,
)
check_env_attributes_and_types(env)
while not np.all(done):
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
@@ -155,6 +155,10 @@ def rollout(
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
with torch.inference_mode():
action = policy.select_action(observation)

View File

@@ -0,0 +1,53 @@
// fmt: off
// flake8: noqa
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package async_inference;
// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc StreamActions(Empty) returns (stream Action);
rpc Ready(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1;
bytes data = 2;
}
message Action {
// sent by remote Policy, to Robot
TransferState transfer_state = 1;
bytes data = 2;
}
message Empty {}

View File

@@ -0,0 +1,46 @@
# fmt: off
# flake8: noqa
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'async_inference.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x06\x41\x63tion\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xd9\x01\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12\x42\n\rStreamActions\x12\x16.async_inference.Empty\x1a\x17.async_inference.Action0\x01\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=216
_globals['_TRANSFERSTATE']._serialized_end=312
_globals['_OBSERVATION']._serialized_start=42
_globals['_OBSERVATION']._serialized_end=125
_globals['_ACTION']._serialized_start=127
_globals['_ACTION']._serialized_end=205
_globals['_EMPTY']._serialized_start=207
_globals['_EMPTY']._serialized_end=214
_globals['_ASYNCINFERENCE']._serialized_start=315
_globals['_ASYNCINFERENCE']._serialized_end=532
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,193 @@
# fmt: off
# flake8: noqa
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import async_inference_pb2 as async__inference__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/async_inference.AsyncInference/SendObservations',
request_serializer=async__inference__pb2.Observation.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.StreamActions = channel.unary_stream(
'/async_inference.AsyncInference/StreamActions',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Action.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/async_inference.AsyncInference/Ready',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StreamActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=async__inference__pb2.Observation.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'StreamActions': grpc.unary_stream_rpc_method_handler(
servicer.StreamActions,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Action.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'async_inference.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/async_inference.AsyncInference/SendObservations',
async__inference__pb2.Observation.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def StreamActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/async_inference.AsyncInference/StreamActions',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Action.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Ready',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -0,0 +1,199 @@
import itertools
import pickle # nosec
import time
from concurrent import futures
from queue import Queue
from typing import Generator, List, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import torch
from datasets import load_dataset
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.scripts.server.robot_client import TimedAction, TimedObservation, environment_dt
inference_latency = 1 / 3
idle_wait = 0.1
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def __init__(self):
# TODO: Add device specification for policy inference at init
self.device = "mps"
start = time.time()
self.policy = ACTPolicy.from_pretrained("fracapuano/act_so100_test")
self.policy.to(self.device)
end = time.time()
print(f"Time taken to put policy on {self.device}: {end - start} seconds")
# Initialize dataset action generator
self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())
self._setup_server()
self.actions_per_chunk = 20
self.actions_overlap = 10
def _setup_server(self) -> None:
"""Flushes server state when new client connects."""
# only running inference on the latest observation received by the server
self.observation_queue = Queue(maxsize=1)
def Ready(self, request, context): # noqa: N802
self._setup_server()
print("Client connected and ready")
return async_inference_pb2.Empty()
def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client"""
# client_id = context.peer()
# print(f"Receiving observations from {client_id}")
for observation in request_iterator:
timed_observation = pickle.loads(observation.data) # nosec
# If queue is full, get the old observation to make room
if self.observation_queue.full():
# pops from queue
_ = self.observation_queue.get_nowait()
# Now put the new observation (never blocks as queue is non-full here)
self.observation_queue.put(timed_observation)
print("Received observation no: ", timed_observation.get_timestep())
return async_inference_pb2.Empty()
def StreamActions(self, request, context): # noqa: N802
"""Stream actions to the robot client"""
# client_id = context.peer()
# print(f"Client {client_id} connected for action streaming")
# Generate action based on the most recent observation and its timestep
obs = self.observation_queue.get()
print("Running inference for timestep: ", obs.get_timestep())
if obs:
yield self._predict_action_chunk(obs)
else:
print("No observation in queue yet!")
time.sleep(idle_wait)
return async_inference_pb2.Empty()
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
"""Turn a chunk of actions into a list of TimedAction instances,
with the first action corresponding to t_0 and the rest corresponding to
t_0 + i*environment_dt for i in range(len(action_chunk))
"""
return [
TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
]
@torch.no_grad()
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
"""Predict an action based on the observation"""
self.policy.eval()
observation = {}
for k, v in observation_t.get_observation().items():
if "image" in k:
observation[k] = v.permute(2, 0, 1).unsqueeze(0).to(self.device)
else:
observation[k] = v.unsqueeze(0).to(self.device)
# Remove batch dimension
action_tensor = self.policy.select_action(observation).squeeze(0)
if action_tensor.dim() == 1:
# No chunk dimension, so repeat action to create a (dummy) chunk of actions
action_tensor = action_tensor.cpu().repeat(self.actions_per_chunk, 1)
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
)
action_bytes = pickle.dumps(action_chunk) # nosec
# Create and return the Action message
action = async_inference_pb2.Action(transfer_state=observation_t.transfer_state, data=action_bytes)
time.sleep(inference_latency) # slow action generation, emulates inference time (ACT is very fast)
return action
def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
"""Stream chunks of actions from a prerecorded dataset.
Returns:
Generator that yields chunks of actions from the dataset
"""
dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch")
# 1. Select the action column only, where you will find tensors with 6 elements
actions = dataset["action"]
action_indices = torch.arange(len(actions))
# 2. Chunk the iterable of tensors into chunks with 10 elements each
# sending only first element for debugging
indices_chunks = action_indices.unfold(
0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
)
for idx_chunk in indices_chunks:
yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
def _read_action_chunk(self, observation: Optional[TimedObservation] = None):
"""Dummy function for predicting action chunk given observation.
Instead of computing actions on-the-fly, this method streams
actions from a prerecorded dataset.
"""
import warnings
warnings.warn(
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
)
if not observation:
observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
transfer_state = 0
else:
transfer_state = observation.transfer_state
# Get chunk of actions from the generator
actions_chunk = next(self.action_generator)
# Return a list of TimedActions, with timestamps starting from the observation timestamp
action_data = self._time_action_chunk(
observation.get_timestamp(), actions_chunk, observation.get_timestep()
)
action_bytes = pickle.dumps(action_data) # nosec
# Create and return the Action message
action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_bytes)
time.sleep(inference_latency) # slow action generation, emulates inference time
return action
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(PolicyServer(), server)
server.add_insecure_port("[::]:50051")
server.start()
print("PolicyServer started on port 50051")
try:
while True:
time.sleep(86400) # Sleep for a day, or until interrupted
except KeyboardInterrupt:
server.stop(0)
print("Server stopped")
if __name__ == "__main__":
serve()

View File

@@ -0,0 +1,357 @@
import pickle # nosec
import threading
import time
from queue import Empty, Queue
from typing import Any, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import torch
from lerobot.common.robot_devices.robots.utils import make_robot
environment_dt = 1 / 30
idle_wait = 0.1
class TimedData:
def __init__(self, timestamp: float, data: Any, timestep: int):
"""Initialize a TimedData object.
Args:
timestamp: Unix timestamp relative to data's creation.
data: The actual data to wrap a timestamp around.
"""
self.timestamp = timestamp
self.data = data
self.timestep = timestep
def get_data(self):
return self.data
def get_timestamp(self):
return self.timestamp
def get_timestep(self):
return self.timestep
class TimedAction(TimedData):
def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
super().__init__(timestamp=timestamp, data=action, timestep=timestep)
def get_action(self):
return self.get_data()
class TimedObservation(TimedData):
def __init__(
self, timestamp: float, observation: dict[str, torch.Tensor], timestep: int, transfer_state: int = 0
):
super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
self.transfer_state = transfer_state
def get_observation(self):
return self.get_data()
class RobotClient:
def __init__(
self,
# cfg: RobotConfig,
server_address="localhost:50051",
use_robot=True,
):
self.channel = grpc.insecure_channel(server_address)
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
self.running = False
self.first_observation_sent = False
self.latest_action = 0
self.action_chunk_size = 20
self.action_queue = Queue()
self.start_barrier = threading.Barrier(3)
# Create a lock for robot access
self.robot_lock = threading.Lock()
self.use_robot = use_robot
if self.use_robot:
self.robot = make_robot("so100")
self.robot.connect()
time.sleep(idle_wait) # sleep waiting for cameras to activate
print("Robot connected")
self.robot_reading = True
def timestamps(self):
"""Get the timestamps of the actions in the queue"""
return sorted([action.get_timestep() for action in self.action_queue.queue])
def start(self):
"""Start the robot client and connect to the policy server"""
try:
# client-server handshake
self.stub.Ready(async_inference_pb2.Empty())
print("Connected to policy server")
self.running = True
return True
except grpc.RpcError as e:
print(f"Failed to connect to policy server: {e}")
return False
def stop(self):
"""Stop the robot client"""
self.running = False
if self.use_robot and hasattr(self, "robot"):
self.robot.disconnect()
self.channel.close()
def send_observation(
self,
obs: TimedObservation,
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
) -> bool:
"""Send observation to the policy server.
Returns True if the observation was sent successfully, False otherwise."""
if not self.running:
print("Client not running")
return False
assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"
observation_bytes = pickle.dumps(obs)
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)
try:
_ = self.stub.SendObservations(iter([observation]))
if transfer_state == async_inference_pb2.TRANSFER_BEGIN:
self.first_observation_sent = True
return True
except grpc.RpcError as e:
print(f"Error sending observation: {e}")
return False
def _validate_action(self, action: TimedAction):
"""Received actions are keps only when they have been produced for now or later, never before"""
return not action.get_timestamp() < self.latest_action
def _validate_action_chunk(self, actions: list[TimedAction]):
assert len(actions) == self.action_chunk_size, (
f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}"
)
assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"
return True
def _inspect_action_queue(self):
print("Queue size: ", self.action_queue.qsize())
print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue]))
def _clear_queue(self):
"""Clear the existing queue"""
while not self.action_queue.empty():
try:
self.action_queue.get_nowait()
except Empty:
break
def _fill_action_queue(self, actions: list[TimedAction]):
"""Fill the action queue with incoming valid actions"""
for action in actions:
if self._validate_action(action):
self.action_queue.put(action)
def _update_action_queue(self, actions: list[TimedAction]):
"""Aggregate incoming actions into the action queue.
Raises NotImplementedError as this is not implemented yet.
Args:
actions: List of TimedAction instances to queue
"""
# TODO: Implement this
raise NotImplementedError("Not implemented")
def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
"""Clear the existing queue and fill it with new actions.
This is a higher-level function that combines clearing and filling operations.
Args:
actions: List of TimedAction instances to queue
"""
print("*** Current latest action: ", self.latest_action, "***")
print("\t**** Current queue content ****: ")
self._inspect_action_queue()
print("\t*** Incoming actions ****: ")
print([a.get_timestep() for a in actions])
self._clear_queue()
self._fill_action_queue(actions)
print("\t*** Queue after clearing and filling ****: ")
self._inspect_action_queue()
def receive_actions(self):
"""Receive actions from the policy server"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
print("Action receiving thread starting")
while self.running:
try:
# Use StreamActions to get a stream of actions from the server
for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
# Deserialize bytes back into list[TimedAction]
timed_actions = pickle.loads(actions_chunk.data) # nosec
# strategy for queue composition is specified in the method
self._clear_and_fill_action_queue(timed_actions)
except grpc.RpcError as e:
print(f"Error receiving actions: {e}")
time.sleep(idle_wait) # Avoid tight loop on error
def _get_next_action(self) -> Optional[TimedAction]:
"""Get the next action from the queue"""
try:
action = self.action_queue.get_nowait()
return action
except Empty:
return None
def execute_actions(self):
"""Continuously execute actions from the queue"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
print("Action execution thread starting")
while self.running:
# Get the next action from the queue
time.sleep(environment_dt)
timed_action = self._get_next_action()
if timed_action is not None:
# self.latest_action = timed_action.get_timestep()
self.latest_action = timed_action.get_timestamp()
# Convert action to tensor and send to robot
if self.use_robot:
# Acquire lock before accessing the robot
if self.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
try:
self.robot.send_action(timed_action.get_action())
finally:
# Always release the lock in a finally block to ensure it's released
self.robot_lock.release()
else:
print("Could not acquire robot lock for action execution, retrying next cycle")
else:
# No action available, wait and retry fetching from queue
time.sleep(idle_wait)
def stream_observations(self, get_observation_fn):
"""Continuously stream observations to the server"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
print("Observation streaming thread starting")
first_observation = True
while self.running:
try:
# Get serialized observation bytes from the function
time.sleep(environment_dt)
observation = get_observation_fn()
# Skip if observation is None (couldn't acquire lock)
if observation is None:
continue
# Set appropriate transfer state
if first_observation:
state = async_inference_pb2.TRANSFER_BEGIN
first_observation = False
else:
state = async_inference_pb2.TRANSFER_MIDDLE
self.send_observation(observation, state)
except Exception as e:
print(f"Error in observation sender: {e}")
time.sleep(idle_wait)
def async_client():
# Example of how to use the RobotClient
client = RobotClient()
if client.start():
# Function to generate mock observations
def get_observation():
# Create a counter attribute if it doesn't exist
if not hasattr(get_observation, "counter"):
get_observation.counter = 0
# Acquire lock before accessing the robot
observation_content = None
if client.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
try:
observation_content = client.robot.capture_observation()
finally:
# Always release the lock in a finally block to ensure it's released
client.robot_lock.release()
else:
print("Could not acquire robot lock for observation capture, skipping this cycle")
return None # Return None to indicate no observation was captured
observation = TimedObservation(
timestamp=time.time(), observation=observation_content, timestep=get_observation.counter
)
# Increment counter for next call
get_observation.counter += 1
return observation
print("Starting all threads...")
# Create and start observation sender thread
obs_thread = threading.Thread(target=client.stream_observations, args=(get_observation,))
obs_thread.daemon = True
# Create and start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions)
action_receiver_thread.daemon = True
# Create action execution thread
action_execution_thread = threading.Thread(target=client.execute_actions)
action_execution_thread.daemon = True
# Start all threads
obs_thread.start()
action_receiver_thread.start()
action_execution_thread.start()
try:
# Main thread just keeps everything alive
while client.running:
time.sleep(idle_wait)
except KeyboardInterrupt:
pass
finally:
client.stop()
print("Client stopped")
if __name__ == "__main__":
async_client()

View File

@@ -133,7 +133,7 @@ def train(cfg: TrainPipelineConfig):
eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None:
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Creating policy")
policy = make_policy(

View File

@@ -60,16 +60,16 @@ dependencies = [
"jsonlines>=4.0.0",
"numba>=0.59.0",
"omegaconf>=2.3.0",
"opencv-python>=4.9.0",
"opencv-python-headless>=4.9.0",
"packaging>=24.2",
"av>=12.0.5,<13.0.0",
"av>=12.0.5",
"pymunk>=6.6.0",
"pynput>=1.7.7",
"pyzmq>=26.2.1",
"rerun-sdk>=0.21.0",
"termcolor>=2.4.0",
"torch>=2.2.1",
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchvision>=0.21.0",
"wandb>=0.16.3",
"zarr>=2.17.0",

View File

@@ -172,8 +172,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
push_to_hub=False,
# TODO(rcadene, aliberts): test video=True
video=False,
# TODO(rcadene): display cameras through cv2 sometimes crashes on mac
display_cameras=False,
display_data=False,
play_sounds=False,
)
dataset = record(robot, rec_cfg)
@@ -226,7 +225,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
num_episodes=2,
push_to_hub=False,
video=False,
display_cameras=False,
display_data=False,
play_sounds=False,
num_image_writer_processes=num_image_writer_processes,
)
@@ -273,7 +272,7 @@ def test_resume_record(tmp_path, request, robot_type, mock):
episode_time_s=1,
push_to_hub=False,
video=False,
display_cameras=False,
display_data=False,
play_sounds=False,
num_episodes=1,
)
@@ -330,7 +329,7 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
display_data=False,
play_sounds=False,
)
dataset = record(robot, rec_cfg)
@@ -380,7 +379,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
display_data=False,
play_sounds=False,
)
@@ -433,7 +432,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n
num_episodes=2,
push_to_hub=False,
video=False,
display_cameras=False,
display_data=False,
play_sounds=False,
num_image_writer_processes=num_image_writer_processes,
)