Package folder structure (#1417)

* Move files

* Replace imports & paths

* Update relative paths

* Update doc symlinks

* Update instructions paths

* Fix imports

* Update grpc files

* Update more instructions

* Downgrade grpc-tools

* Update manifest

* Update more paths

* Update config paths

* Update CI paths

* Update bandit exclusions

* Remove walkthrough section
This commit is contained in:
Simon Alibert
2025-07-01 16:34:46 +02:00
committed by GitHub
parent 483be9aac2
commit d4ee470b00
268 changed files with 862 additions and 890 deletions

212
src/lerobot/__init__.py Normal file
View File

@@ -0,0 +1,212 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
Example:
```python
import lerobot
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets)
print(lerobot.available_datasets_per_env)
print(lerobot.available_real_world_datasets)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
print(lerobot.available_robots)
print(lerobot.available_cameras)
print(lerobot.available_motors)
```
When implementing a new dataset loadable with LeRobotDataset follow these steps:
- Update `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
"""
import itertools
from lerobot.__version__ import __version__ # noqa: F401
# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
# a yaml file AND a environment name. The difference should be more obvious.
available_tasks_per_env = {
"aloha": [
"AlohaInsertion-v0",
"AlohaTransferCube-v0",
],
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
}
available_envs = list(available_tasks_per_env.keys())
available_datasets_per_env = {
"aloha": [
"lerobot/aloha_sim_insertion_human",
"lerobot/aloha_sim_insertion_scripted",
"lerobot/aloha_sim_transfer_cube_human",
"lerobot/aloha_sim_transfer_cube_scripted",
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_transfer_cube_scripted_image",
],
# TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
# coupled with tests.
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
"xarm": [
"lerobot/xarm_lift_medium",
"lerobot/xarm_lift_medium_replay",
"lerobot/xarm_push_medium",
"lerobot/xarm_push_medium_replay",
"lerobot/xarm_lift_medium_image",
"lerobot/xarm_lift_medium_replay_image",
"lerobot/xarm_push_medium_image",
"lerobot/xarm_push_medium_replay_image",
],
}
available_real_world_datasets = [
"lerobot/aloha_mobile_cabinet",
"lerobot/aloha_mobile_chair",
"lerobot/aloha_mobile_elevator",
"lerobot/aloha_mobile_shrimp",
"lerobot/aloha_mobile_wash_pan",
"lerobot/aloha_mobile_wipe_wine",
"lerobot/aloha_static_battery",
"lerobot/aloha_static_candy",
"lerobot/aloha_static_coffee",
"lerobot/aloha_static_coffee_new",
"lerobot/aloha_static_cups_open",
"lerobot/aloha_static_fork_pick_up",
"lerobot/aloha_static_pingpong_test",
"lerobot/aloha_static_pro_pencil",
"lerobot/aloha_static_screw_driver",
"lerobot/aloha_static_tape",
"lerobot/aloha_static_thread_velcro",
"lerobot/aloha_static_towel",
"lerobot/aloha_static_vinh_cup",
"lerobot/aloha_static_vinh_cup_left",
"lerobot/aloha_static_ziploc_slide",
"lerobot/umi_cup_in_the_wild",
"lerobot/unitreeh1_fold_clothes",
"lerobot/unitreeh1_rearrange_objects",
"lerobot/unitreeh1_two_robot_greeting",
"lerobot/unitreeh1_warehouse",
"lerobot/nyu_rot_dataset",
"lerobot/utokyo_saytap",
"lerobot/imperialcollege_sawyer_wrist_cam",
"lerobot/utokyo_xarm_bimanual",
"lerobot/tokyo_u_lsmo",
"lerobot/utokyo_pr2_opening_fridge",
"lerobot/cmu_franka_exploration_dataset",
"lerobot/cmu_stretch",
"lerobot/asu_table_top",
"lerobot/utokyo_pr2_tabletop_manipulation",
"lerobot/utokyo_xarm_pick_and_place",
"lerobot/ucsd_kitchen_dataset",
"lerobot/austin_buds_dataset",
"lerobot/dlr_sara_grid_clamp",
"lerobot/conq_hose_manipulation",
"lerobot/columbia_cairlab_pusht_real",
"lerobot/dlr_sara_pour",
"lerobot/dlr_edan_shared_control",
"lerobot/ucsd_pick_and_place_dataset",
"lerobot/berkeley_cable_routing",
"lerobot/nyu_franka_play_dataset",
"lerobot/austin_sirius_dataset",
"lerobot/cmu_play_fusion",
"lerobot/berkeley_gnm_sac_son",
"lerobot/nyu_door_opening_surprising_effectiveness",
"lerobot/berkeley_fanuc_manipulation",
"lerobot/jaco_play",
"lerobot/viola",
"lerobot/kaist_nonprehensile",
"lerobot/berkeley_mvp",
"lerobot/uiuc_d3field",
"lerobot/berkeley_gnm_recon",
"lerobot/austin_sailor_dataset",
"lerobot/utaustin_mutex",
"lerobot/roboturk",
"lerobot/stanford_hydra_dataset",
"lerobot/berkeley_autolab_ur5",
"lerobot/stanford_robocook",
"lerobot/toto",
"lerobot/fmb",
"lerobot/droid_100",
"lerobot/berkeley_rpt",
"lerobot/stanford_kuka_multimodal_dataset",
"lerobot/iamlab_cmu_pickup_insert",
"lerobot/taco_play",
"lerobot/berkeley_gnm_cory_hall",
"lerobot/usc_cloth_sim",
]
available_datasets = sorted(
set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
)
# lists all available policies from `lerobot/policies`
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
# lists all available robots from `lerobot/robot_devices/robots`
available_robots = [
"koch",
"koch_bimanual",
"aloha",
"so100",
"so101",
]
# lists all available cameras from `lerobot/robot_devices/cameras`
available_cameras = [
"opencv",
"intelrealsense",
]
# lists all available motors from `lerobot/robot_devices/motors`
available_motors = [
"dynamixel",
"feetech",
]
# keys and values refer to yaml files
available_policies_per_env = {
"aloha": ["act"],
"pusht": ["diffusion", "vqbet"],
"xarm": ["tdmpc"],
"koch_real": ["act_koch_real"],
"aloha_real": ["act_aloha_real"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_dataset_pairs = [
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
]
env_dataset_policy_triplets = [
(env, dataset, policy)
for env, datasets in available_datasets_per_env.items()
for dataset in datasets
for policy in available_policies_per_env[env]
]

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""To enable `lerobot.__version__`"""
from importlib.metadata import PackageNotFoundError, version
try:
__version__ = version("lerobot")
except PackageNotFoundError:
__version__ = "unknown"

84
src/lerobot/calibrate.py Normal file
View File

@@ -0,0 +1,84 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper to recalibrate your device (robot or teleoperator).
Example:
```shell
python -m lerobot.calibrate \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue
```
"""
import logging
from dataclasses import asdict, dataclass
from pprint import pformat
import draccus
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
lekiwi,
make_robot_from_config,
so100_follower,
so101_follower,
)
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
koch_leader,
make_teleoperator_from_config,
so100_leader,
so101_leader,
)
from lerobot.utils.utils import init_logging
@dataclass
class CalibrateConfig:
teleop: TeleoperatorConfig | None = None
robot: RobotConfig | None = None
def __post_init__(self):
if bool(self.teleop) == bool(self.robot):
raise ValueError("Choose either a teleop or a robot.")
self.device = self.robot if self.robot else self.teleop
@draccus.wrap()
def calibrate(cfg: CalibrateConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
if isinstance(cfg.device, RobotConfig):
device = make_robot_from_config(cfg.device)
elif isinstance(cfg.device, TeleoperatorConfig):
device = make_teleoperator_from_config(cfg.device)
device.connect(calibrate=False)
device.calibrate()
device.disconnect()
if __name__ == "__main__":
calibrate()

View File

@@ -0,0 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .camera import Camera
from .configs import CameraConfig, ColorMode, Cv2Rotation
from .utils import make_cameras_from_configs

View File

@@ -0,0 +1,120 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Any, Dict, List
import numpy as np
from .configs import CameraConfig, ColorMode
class Camera(abc.ABC):
"""Base class for camera implementations.
Defines a standard interface for camera operations across different backends.
Subclasses must implement all abstract methods.
Manages basic camera properties (FPS, resolution) and core operations:
- Connection/disconnection
- Frame capture (sync/async)
Attributes:
fps (int | None): Configured frames per second
width (int | None): Frame width in pixels
height (int | None): Frame height in pixels
Example:
class MyCamera(Camera):
def __init__(self, config): ...
@property
def is_connected(self) -> bool: ...
def connect(self, warmup=True): ...
# Plus other required methods
"""
def __init__(self, config: CameraConfig):
"""Initialize the camera with the given configuration.
Args:
config: Camera configuration containing FPS and resolution.
"""
self.fps: int | None = config.fps
self.width: int | None = config.width
self.height: int | None = config.height
@property
@abc.abstractmethod
def is_connected(self) -> bool:
"""Check if the camera is currently connected.
Returns:
bool: True if the camera is connected and ready to capture frames,
False otherwise.
"""
pass
@staticmethod
@abc.abstractmethod
def find_cameras() -> List[Dict[str, Any]]:
"""Detects available cameras connected to the system.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains information about a detected camera.
"""
pass
@abc.abstractmethod
def connect(self, warmup: bool = True) -> None:
"""Establish connection to the camera.
Args:
warmup: If True (default), captures a warmup frame before returning. Useful
for cameras that require time to adjust capture settings.
If False, skips the warmup frame.
"""
pass
@abc.abstractmethod
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""Capture and return a single frame from the camera.
Args:
color_mode: Desired color mode for the output frame. If None,
uses the camera's default color mode.
Returns:
np.ndarray: Captured frame as a numpy array.
"""
pass
@abc.abstractmethod
def async_read(self, timeout_ms: float = ...) -> np.ndarray:
"""Asynchronously capture and return a single frame from the camera.
Args:
timeout_ms: Maximum time to wait for a frame in milliseconds.
Defaults to implementation-specific timeout.
Returns:
np.ndarray: Captured frame as a numpy array.
"""
pass
@abc.abstractmethod
def disconnect(self) -> None:
"""Disconnect from the camera and release resources."""
pass

View File

@@ -0,0 +1,44 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import dataclass
from enum import Enum
import draccus
class ColorMode(str, Enum):
RGB = "rgb"
BGR = "bgr"
class Cv2Rotation(int, Enum):
NO_ROTATION = 0
ROTATE_90 = 90
ROTATE_180 = 180
ROTATE_270 = -90
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
fps: int | None = None
width: int | None = None
height: int | None = None
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)

View File

@@ -0,0 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .camera_opencv import OpenCVCamera
from .configuration_opencv import OpenCVCameraConfig

View File

@@ -0,0 +1,482 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the OpenCVCamera class for capturing frames from cameras using OpenCV.
"""
import logging
import math
import platform
import time
from pathlib import Path
from threading import Event, Lock, Thread
from typing import Any, Dict, List
import cv2
import numpy as np
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..utils import get_cv2_backend, get_cv2_rotation
from .configuration_opencv import ColorMode, OpenCVCameraConfig
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case
# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23.
# When you change the USB port or reboot the computer, the operating system might
# treat the same cameras as new devices. Thus we select a higher bound to search indices.
MAX_OPENCV_INDEX = 60
logger = logging.getLogger(__name__)
class OpenCVCamera(Camera):
"""
Manages camera interactions using OpenCV for efficient frame recording.
This class provides a high-level interface to connect to, configure, and read
frames from cameras compatible with OpenCV's VideoCapture. It supports both
synchronous and asynchronous frame reading.
An OpenCVCamera instance requires a camera index (e.g., 0) or a device path
(e.g., '/dev/video0' on Linux). Camera indices can be unstable across reboots
or port changes, especially on Linux. Use the provided utility script to find
available camera indices or paths:
```bash
python -m lerobot.find_cameras opencv
```
The camera's default settings (FPS, resolution, color mode) are used unless
overridden in the configuration.
Example:
```python
from lerobot.cameras.opencv import OpenCVCamera
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation
# Basic usage with camera index 0
config = OpenCVCameraConfig(index_or_path=0)
camera = OpenCVCamera(config)
camera.connect()
# Read 1 frame synchronously
color_image = camera.read()
print(color_image.shape)
# Read 1 frame asynchronously
async_image = camera.async_read()
# When done, properly disconnect the camera using
camera.disconnect()
# Example with custom settings
custom_config = OpenCVCameraConfig(
index_or_path='/dev/video0', # Or use an index
fps=30,
width=1280,
height=720,
color_mode=ColorMode.RGB,
rotation=Cv2Rotation.ROTATE_90
)
custom_camera = OpenCVCamera(custom_config)
# ... connect, read, disconnect ...
```
"""
def __init__(self, config: OpenCVCameraConfig):
"""
Initializes the OpenCVCamera instance.
Args:
config: The configuration settings for the camera.
"""
super().__init__(config)
self.config = config
self.index_or_path = config.index_or_path
self.fps = config.fps
self.color_mode = config.color_mode
self.warmup_s = config.warmup_s
self.videocapture: cv2.VideoCapture | None = None
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
self.backend: int = get_cv2_backend()
if self.height and self.width:
self.capture_width, self.capture_height = self.width, self.height
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
self.capture_width, self.capture_height = self.height, self.width
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.index_or_path})"
@property
def is_connected(self) -> bool:
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
def connect(self, warmup: bool = True):
"""
Connects to the OpenCV camera specified in the configuration.
Initializes the OpenCV VideoCapture object, sets desired camera properties
(FPS, width, height), and performs initial checks.
Raises:
DeviceAlreadyConnectedError: If the camera is already connected.
ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open.
RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
# Use 1 thread for OpenCV operations to avoid potential conflicts or
# blocking in multi-threaded applications, especially during data collection.
cv2.setNumThreads(1)
self.videocapture = cv2.VideoCapture(self.index_or_path, self.backend)
if not self.videocapture.isOpened():
self.videocapture.release()
self.videocapture = None
raise ConnectionError(
f"Failed to open {self}."
f"Run `python -m lerobot.find_cameras opencv` to find available cameras."
)
self._configure_capture_settings()
if warmup:
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.read()
time.sleep(0.1)
logger.info(f"{self} connected.")
def _configure_capture_settings(self) -> None:
"""
Applies the specified FPS, width, and height settings to the connected camera.
This method attempts to set the camera properties via OpenCV. It checks if
the camera successfully applied the settings and raises an error if not.
Args:
fps: The desired frames per second. If None, the setting is skipped.
width: The desired capture width. If None, the setting is skipped.
height: The desired capture height. If None, the setting is skipped.
Raises:
RuntimeError: If the camera fails to set any of the specified properties
to the requested value.
DeviceNotConnectedError: If the camera is not connected when attempting
to configure settings.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
if self.fps is None:
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
else:
self._validate_fps()
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
if self.width is None or self.height is None:
self.width, self.height = default_width, default_height
self.capture_width, self.capture_height = default_width, default_height
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
self.width, self.height = default_height, default_width
self.capture_width, self.capture_height = default_width, default_height
else:
self._validate_width_and_height()
def _validate_fps(self) -> None:
"""Validates and sets the camera's frames per second (FPS)."""
success = self.videocapture.set(cv2.CAP_PROP_FPS, float(self.fps))
actual_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
# Use math.isclose for robust float comparison
if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
raise RuntimeError(f"{self} failed to set fps={self.fps} ({actual_fps=}).")
def _validate_width_and_height(self) -> None:
"""Validates and sets the camera's frame capture width and height."""
width_success = self.videocapture.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.capture_width))
height_success = self.videocapture.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.capture_height))
actual_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
if not width_success or self.capture_width != actual_width:
raise RuntimeError(
f"{self} failed to set capture_width={self.capture_width} ({actual_width=}, {width_success=})."
)
actual_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
if not height_success or self.capture_height != actual_height:
raise RuntimeError(
f"{self} failed to set capture_height={self.capture_height} ({actual_height=}, {height_success=})."
)
@staticmethod
def find_cameras() -> List[Dict[str, Any]]:
"""
Detects available OpenCV cameras connected to the system.
On Linux, it scans '/dev/video*' paths. On other systems (like macOS, Windows),
it checks indices from 0 up to `MAX_OPENCV_INDEX`.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains 'type', 'id' (port index or path),
and the default profile properties (width, height, fps, format).
"""
found_cameras_info = []
if platform.system() == "Linux":
possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name)
targets_to_scan = [str(p) for p in possible_paths]
else:
targets_to_scan = list(range(MAX_OPENCV_INDEX))
for target in targets_to_scan:
camera = cv2.VideoCapture(target)
if camera.isOpened():
default_width = int(camera.get(cv2.CAP_PROP_FRAME_WIDTH))
default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
default_fps = camera.get(cv2.CAP_PROP_FPS)
default_format = camera.get(cv2.CAP_PROP_FORMAT)
camera_info = {
"name": f"OpenCV Camera @ {target}",
"type": "OpenCV",
"id": target,
"backend_api": camera.getBackendName(),
"default_stream_profile": {
"format": default_format,
"width": default_width,
"height": default_height,
"fps": default_fps,
},
}
found_cameras_info.append(camera_info)
camera.release()
return found_cameras_info
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Reads a single frame synchronously from the camera.
This is a blocking call. It waits for the next available frame from the
camera hardware via OpenCV.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
Returns:
np.ndarray: The captured frame as a NumPy array in the format
(height, width, channels), using the specified or default
color mode and applying any configured rotation.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If reading the frame from the camera fails or if the
received frame dimensions don't match expectations before rotation.
ValueError: If an invalid `color_mode` is requested.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
ret, frame = self.videocapture.read()
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
processed_frame = self._postprocess_image(frame, color_mode)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return processed_frame
def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Applies color conversion, dimension validation, and rotation to a raw frame.
Args:
image (np.ndarray): The raw image frame (expected BGR format from OpenCV).
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
uses the instance's default `self.color_mode`.
Returns:
np.ndarray: The processed image frame.
Raises:
ValueError: If the requested `color_mode` is invalid.
RuntimeError: If the raw frame dimensions do not match the configured
`width` and `height`.
"""
requested_color_mode = self.color_mode if color_mode is None else color_mode
if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
h, w, c = image.shape
if h != self.capture_height or w != self.capture_width:
raise RuntimeError(
f"{self} frame width={w} or height={h} do not match configured width={self.capture_width} or height={self.capture_height}."
)
if c != 3:
raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")
processed_image = image
if requested_color_mode == ColorMode.RGB:
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
processed_image = cv2.rotate(processed_image, self.rotation)
return processed_image
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
while not self.stop_event.is_set():
try:
color_image = self.read()
with self.frame_lock:
self.latest_frame = color_image
self.new_frame_event.set()
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame asynchronously.
This method retrieves the most recent frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms (0.2 seconds).
Returns:
np.ndarray: The latest captured frame as a NumPy array in the format
(height, width, channels), processed according to configuration.
Raises:
DeviceNotConnectedError: If the camera is not connected.
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
)
with self.frame_lock:
frame = self.latest_frame
self.new_frame_event.clear()
if frame is None:
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
return frame
def disconnect(self):
"""
Disconnects from the camera and cleans up resources.
Stops the background read thread (if running) and releases the OpenCV
VideoCapture object.
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected and self.thread is None:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.thread is not None:
self._stop_read_thread()
if self.videocapture is not None:
self.videocapture.release()
self.videocapture = None
logger.info(f"{self} disconnected.")

View File

@@ -0,0 +1,73 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Rotation
@CameraConfig.register_subclass("opencv")
@dataclass
class OpenCVCameraConfig(CameraConfig):
"""Configuration class for OpenCV-based camera devices or video files.
This class provides configuration options for cameras accessed through OpenCV,
supporting both physical camera devices and video files. It includes settings
for resolution, frame rate, color mode, and image rotation.
Example configurations:
```python
# Basic configurations
OpenCVCameraConfig(0, 30, 1280, 720) # 1280x720 @ 30FPS
OpenCVCameraConfig(/dev/video4, 60, 640, 480) # 640x480 @ 60FPS
# Advanced configurations
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
```
Attributes:
index_or_path: Either an integer representing the camera device index,
or a Path object pointing to a video file.
fps: Requested frames per second for the color stream.
width: Requested frame width in pixels for the color stream.
height: Requested frame height in pixels for the color stream.
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
"""
index_or_path: int | Path
color_mode: ColorMode = ColorMode.RGB
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
def __post_init__(self):
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)

View File

@@ -0,0 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .camera_realsense import RealSenseCamera
from .configuration_realsense import RealSenseCameraConfig

View File

@@ -0,0 +1,556 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the RealSenseCamera class for capturing frames from Intel RealSense cameras.
"""
import logging
import time
from threading import Event, Lock, Thread
from typing import Any, Dict, List
import cv2
import numpy as np
try:
import pyrealsense2 as rs
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
from ..utils import get_cv2_rotation
from .configuration_realsense import RealSenseCameraConfig
logger = logging.getLogger(__name__)
class RealSenseCamera(Camera):
"""
Manages interactions with Intel RealSense cameras for frame and depth recording.
This class provides an interface similar to `OpenCVCamera` but tailored for
RealSense devices, leveraging the `pyrealsense2` library. It uses the camera's
unique serial number for identification, offering more stability than device
indices, especially on Linux. It also supports capturing depth maps alongside
color frames.
Use the provided utility script to find available camera indices and default profiles:
```bash
python -m lerobot.find_cameras realsense
```
A `RealSenseCamera` instance requires a configuration object specifying the
camera's serial number or a unique device name. If using the name, ensure only
one camera with that name is connected.
The camera's default settings (FPS, resolution, color mode) from the stream
profile are used unless overridden in the configuration.
Example:
```python
from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
from lerobot.cameras import ColorMode, Cv2Rotation
# Basic usage with serial number
config = RealSenseCameraConfig(serial_number_or_name="0123456789") # Replace with actual SN
camera = RealSenseCamera(config)
camera.connect()
# Read 1 frame synchronously
color_image = camera.read()
print(color_image.shape)
# Read 1 frame asynchronously
async_image = camera.async_read()
# When done, properly disconnect the camera using
camera.disconnect()
# Example with depth capture and custom settings
custom_config = RealSenseCameraConfig(
serial_number_or_name="0123456789", # Replace with actual SN
fps=30,
width=1280,
height=720,
color_mode=ColorMode.BGR, # Request BGR output
rotation=Cv2Rotation.NO_ROTATION,
use_depth=True
)
depth_camera = RealSenseCamera(custom_config)
depth_camera.connect()
# Read 1 depth frame
depth_map = depth_camera.read_depth()
# Example using a unique camera name
name_config = RealSenseCameraConfig(serial_number_or_name="Intel RealSense D435") # If unique
name_camera = RealSenseCamera(name_config)
# ... connect, read, disconnect ...
```
"""
def __init__(self, config: RealSenseCameraConfig):
"""
Initializes the RealSenseCamera instance.
Args:
config: The configuration settings for the camera.
"""
super().__init__(config)
self.config = config
if config.serial_number_or_name.isdigit():
self.serial_number = config.serial_number_or_name
else:
self.serial_number = self._find_serial_number_from_name(config.serial_number_or_name)
self.fps = config.fps
self.color_mode = config.color_mode
self.use_depth = config.use_depth
self.warmup_s = config.warmup_s
self.rs_pipeline: rs.pipeline | None = None
self.rs_profile: rs.pipeline_profile | None = None
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
if self.height and self.width:
self.capture_width, self.capture_height = self.width, self.height
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
self.capture_width, self.capture_height = self.height, self.width
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.serial_number})"
@property
def is_connected(self) -> bool:
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
def connect(self, warmup: bool = True):
"""
Connects to the RealSense camera specified in the configuration.
Initializes the RealSense pipeline, configures the required streams (color
and optionally depth), starts the pipeline, and validates the actual stream settings.
Raises:
DeviceAlreadyConnectedError: If the camera is already connected.
ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique).
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
RuntimeError: If the pipeline starts but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
self.rs_pipeline = rs.pipeline()
rs_config = rs.config()
self._configure_rs_pipeline_config(rs_config)
try:
self.rs_profile = self.rs_pipeline.start(rs_config)
except RuntimeError as e:
self.rs_profile = None
self.rs_pipeline = None
raise ConnectionError(
f"Failed to open {self}."
"Run `python -m lerobot.find_cameras realsense` to find available cameras."
) from e
self._configure_capture_settings()
if warmup:
time.sleep(
1
) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise.
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.read()
time.sleep(0.1)
logger.info(f"{self} connected.")
@staticmethod
def find_cameras() -> List[Dict[str, Any]]:
"""
Detects available Intel RealSense cameras connected to the system.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains 'type', 'id' (serial number), 'name',
firmware version, USB type, and other available specs, and the default profile properties (width, height, fps, format).
Raises:
OSError: If pyrealsense2 is not installed.
ImportError: If pyrealsense2 is not installed.
"""
found_cameras_info = []
context = rs.context()
devices = context.query_devices()
for device in devices:
camera_info = {
"name": device.get_info(rs.camera_info.name),
"type": "RealSense",
"id": device.get_info(rs.camera_info.serial_number),
"firmware_version": device.get_info(rs.camera_info.firmware_version),
"usb_type_descriptor": device.get_info(rs.camera_info.usb_type_descriptor),
"physical_port": device.get_info(rs.camera_info.physical_port),
"product_id": device.get_info(rs.camera_info.product_id),
"product_line": device.get_info(rs.camera_info.product_line),
}
# Get stream profiles for each sensor
sensors = device.query_sensors()
for sensor in sensors:
profiles = sensor.get_stream_profiles()
for profile in profiles:
if profile.is_video_stream_profile() and profile.is_default():
vprofile = profile.as_video_stream_profile()
stream_info = {
"stream_type": vprofile.stream_name(),
"format": vprofile.format().name,
"width": vprofile.width(),
"height": vprofile.height(),
"fps": vprofile.fps(),
}
camera_info["default_stream_profile"] = stream_info
found_cameras_info.append(camera_info)
return found_cameras_info
def _find_serial_number_from_name(self, name: str) -> str:
"""Finds the serial number for a given unique camera name."""
camera_infos = self.find_cameras()
found_devices = [cam for cam in camera_infos if str(cam["name"]) == name]
if not found_devices:
available_names = [cam["name"] for cam in camera_infos]
raise ValueError(
f"No RealSense camera found with name '{name}'. Available camera names: {available_names}"
)
if len(found_devices) > 1:
serial_numbers = [dev["serial_number"] for dev in found_devices]
raise ValueError(
f"Multiple RealSense cameras found with name '{name}'. "
f"Please use a unique serial number instead. Found SNs: {serial_numbers}"
)
serial_number = str(found_devices[0]["serial_number"])
return serial_number
def _configure_rs_pipeline_config(self, rs_config):
"""Creates and configures the RealSense pipeline configuration object."""
rs.config.enable_device(rs_config, self.serial_number)
if self.width and self.height and self.fps:
rs_config.enable_stream(
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
)
if self.use_depth:
rs_config.enable_stream(
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
)
else:
rs_config.enable_stream(rs.stream.color)
if self.use_depth:
rs_config.enable_stream(rs.stream.depth)
def _configure_capture_settings(self) -> None:
"""Sets fps, width, and height from device stream if not already configured.
Uses the color stream profile to update unset attributes. Handles rotation by
swapping width/height when needed. Original capture dimensions are always stored.
Raises:
DeviceNotConnectedError: If device is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()
if self.fps is None:
self.fps = stream.fps()
if self.width is None or self.height is None:
actual_width = int(round(stream.width()))
actual_height = int(round(stream.height()))
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
self.width, self.height = actual_height, actual_width
self.capture_width, self.capture_height = actual_width, actual_height
else:
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
"""
Reads a single frame (depth) synchronously from the camera.
This is a blocking call. It waits for a coherent set of frames (depth)
from the camera hardware via the RealSense pipeline.
Args:
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
Returns:
np.ndarray: The depth map as a NumPy array (height, width)
of type `np.uint16` (raw depth values in millimeters) and rotation.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If reading frames from the pipeline fails or frames are invalid.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if not self.use_depth:
raise RuntimeError(
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
)
start_time = time.perf_counter()
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
raise RuntimeError(f"{self} read_depth failed (status={ret}).")
depth_frame = frame.get_depth_frame()
depth_map = np.asanyarray(depth_frame.get_data())
depth_map_processed = self._postprocess_image(depth_map, depth_frame=True)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return depth_map_processed
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray:
"""
Reads a single frame (color) synchronously from the camera.
This is a blocking call. It waits for a coherent set of frames (color)
from the camera hardware via the RealSense pipeline.
Args:
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
Returns:
np.ndarray: The captured color frame as a NumPy array
(height, width, channels), processed according to `color_mode` and rotation.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If reading frames from the pipeline fails or frames are invalid.
ValueError: If an invalid `color_mode` is requested.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
color_frame = frame.get_color_frame()
color_image_raw = np.asanyarray(color_frame.get_data())
color_image_processed = self._postprocess_image(color_image_raw, color_mode)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return color_image_processed
def _postprocess_image(
self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False
) -> np.ndarray:
"""
Applies color conversion, dimension validation, and rotation to a raw color frame.
Args:
image (np.ndarray): The raw image frame (expected RGB format from RealSense).
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
uses the instance's default `self.color_mode`.
Returns:
np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`.
Raises:
ValueError: If the requested `color_mode` is invalid.
RuntimeError: If the raw frame dimensions do not match the configured
`width` and `height`.
"""
if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
if depth_frame:
h, w = image.shape
else:
h, w, c = image.shape
if c != 3:
raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")
if h != self.capture_height or w != self.capture_width:
raise RuntimeError(
f"{self} frame width={w} or height={h} do not match configured width={self.capture_width} or height={self.capture_height}."
)
processed_image = image
if self.color_mode == ColorMode.BGR:
processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
processed_image = cv2.rotate(processed_image, self.rotation)
return processed_image
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame with 500ms timeout
2. Stores result in latest_frame (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
while not self.stop_event.is_set():
try:
color_image = self.read(timeout_ms=500)
with self.frame_lock:
self.latest_frame = color_image
self.new_frame_event.set()
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self):
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
# NOTE(Steven): Missing implementation for depth for now
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame data (color) asynchronously.
This method retrieves the most recent color frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms (0.2 seconds).
Returns:
np.ndarray:
The latest captured frame data (color image), processed according to configuration.
Raises:
DeviceNotConnectedError: If the camera is not connected.
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread died unexpectedly or another error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
)
with self.frame_lock:
frame = self.latest_frame
self.new_frame_event.clear()
if frame is None:
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
return frame
def disconnect(self):
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.
Stops the background read thread (if running) and stops the RealSense pipeline.
Raises:
DeviceNotConnectedError: If the camera is already disconnected (pipeline not running).
"""
if not self.is_connected and self.thread is None:
raise DeviceNotConnectedError(
f"Attempted to disconnect {self}, but it appears already disconnected."
)
if self.thread is not None:
self._stop_read_thread()
if self.rs_pipeline is not None:
self.rs_pipeline.stop()
self.rs_pipeline = None
self.rs_profile = None
logger.info(f"{self} disconnected.")

View File

@@ -0,0 +1,82 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode, Cv2Rotation
@CameraConfig.register_subclass("intelrealsense")
@dataclass
class RealSenseCameraConfig(CameraConfig):
"""Configuration class for Intel RealSense cameras.
This class provides specialized configuration options for Intel RealSense cameras,
including support for depth sensing and device identification via serial number or name.
Example configurations for Intel RealSense D405:
```python
# Basic configurations
RealSenseCameraConfig("0123456789", 30, 1280, 720) # 1280x720 @ 30FPS
RealSenseCameraConfig("0123456789", 60, 640, 480) # 640x480 @ 60FPS
# Advanced configurations
RealSenseCameraConfig("0123456789", 30, 640, 480, use_depth=True) # With depth sensing
RealSenseCameraConfig("0123456789", 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
```
Attributes:
fps: Requested frames per second for the color stream.
width: Requested frame width in pixels for the color stream.
height: Requested frame height in pixels for the color stream.
serial_number_or_name: Unique serial number or human-readable name to identify the camera.
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
use_depth: Whether to enable depth stream. Defaults to False.
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
Note:
- Either name or serial_number must be specified.
- Depth stream configuration (if enabled) will use the same FPS as the color stream.
- The actual resolution and FPS may be adjusted by the camera to the nearest supported mode.
- For `fps`, `width` and `height`, either all of them need to be set, or none of them.
"""
serial_number_or_name: str
color_mode: ColorMode = ColorMode.RGB
use_depth: bool = False
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
def __post_init__(self):
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
values = (self.fps, self.width, self.height)
if any(v is not None for v in values) and any(v is None for v in values):
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them."
)

View File

@@ -0,0 +1,65 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from pathlib import Path
from typing import TypeAlias
from .camera import Camera
from .configs import CameraConfig, Cv2Rotation
IndexOrPath: TypeAlias = int | Path
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
cameras = {}
for key, cfg in camera_configs.items():
if cfg.type == "opencv":
from .opencv import OpenCVCamera
cameras[key] = OpenCVCamera(cfg)
elif cfg.type == "intelrealsense":
from .realsense.camera_realsense import RealSenseCamera
cameras[key] = RealSenseCamera(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
return cameras
def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
import cv2
if rotation == Cv2Rotation.ROTATE_90:
return cv2.ROTATE_90_CLOCKWISE
elif rotation == Cv2Rotation.ROTATE_180:
return cv2.ROTATE_180
elif rotation == Cv2Rotation.ROTATE_270:
return cv2.ROTATE_90_COUNTERCLOCKWISE
else:
return None
def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return cv2.CAP_AVFOUNDATION
else:
return cv2.CAP_ANY

View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot import (
policies, # noqa: F401
)
from lerobot.datasets.transforms import ImageTransformsConfig
from lerobot.datasets.video_utils import get_safe_default_codec
@dataclass
class DatasetConfig:
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
@dataclass
class WandBConfig:
enable: bool = False
# Set to true to disable saving an artifact despite training.save_checkpoint=True
disable_artifact: bool = False
project: str = "lerobot"
entity: str | None = None
notes: str | None = None
run_id: str | None = None
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
@dataclass
class EvalConfig:
n_episodes: int = 50
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
batch_size: int = 50
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: bool = False
def __post_init__(self):
if self.batch_size > self.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "
f"({self.batch_size} > {self.n_episodes}). As a result, {self.batch_size} "
f"eval environments will be instantiated, but only {self.n_episodes} will be used. "
"This might significantly slow down evaluation. To fix this, you should update your command "
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
)

View File

@@ -0,0 +1,65 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime as dt
import logging
from dataclasses import dataclass, field
from pathlib import Path
from lerobot import envs, policies # noqa: F401
from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
@dataclass
class EvalPipelineConfig:
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch
# (useful for debugging). This argument is mutually exclusive with `--config`.
env: envs.EnvConfig
eval: EvalConfig = field(default_factory=EvalConfig)
policy: PreTrainedConfig | None = None
output_dir: Path | None = None
job_name: str | None = None
seed: int | None = 1000
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
logging.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
)
if not self.job_name:
if self.env is None:
self.job_name = f"{self.policy.type}"
else:
self.job_name = f"{self.env.type}_{self.policy.type}"
if not self.output_dir:
now = dt.datetime.now()
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/eval") / eval_dir
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]

View File

@@ -0,0 +1,231 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import pkgutil
import sys
from argparse import ArgumentError
from functools import wraps
from pathlib import Path
from typing import Sequence
import draccus
from lerobot.utils.utils import has_method
PATH_KEY = "path"
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
"""Parses arguments from cli at a given nested attribute level.
For example, supposing the main script was called with:
python myscript.py --arg1=1 --arg2.subarg1=abc --arg2.subarg2=some/path
If called during execution of myscript.py, get_cli_overrides("arg2") will return:
["--subarg1=abc" "--subarg2=some/path"]
"""
if args is None:
args = sys.argv[1:]
attr_level_args = []
detect_string = f"--{field_name}."
exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
for arg in args:
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
denested_arg = f"--{arg.removeprefix(detect_string)}"
attr_level_args.append(denested_arg)
return attr_level_args
def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
if args is None:
args = sys.argv[1:]
prefix = f"--{arg_name}="
for arg in args:
if arg.startswith(prefix):
return arg[len(prefix) :]
return None
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
"""Parse plugin-related arguments from command-line arguments.
This function extracts arguments from command-line arguments that match a specified suffix pattern.
It processes arguments in the format '--key=value' and returns them as a dictionary.
Args:
plugin_arg_suffix (str): The suffix to identify plugin-related arguments.
cli_args (Sequence[str]): A sequence of command-line arguments to parse.
Returns:
dict: A dictionary containing the parsed plugin arguments where:
- Keys are the argument names (with '--' prefix removed if present)
- Values are the corresponding argument values
Example:
>>> args = ['--env.discover_packages_path=my_package',
... '--other_arg=value']
>>> parse_plugin_args('discover_packages_path', args)
{'env.discover_packages_path': 'my_package'}
"""
plugin_args = {}
for arg in args:
if "=" in arg and plugin_arg_suffix in arg:
key, value = arg.split("=", 1)
# Remove leading '--' if present
if key.startswith("--"):
key = key[2:]
plugin_args[key] = value
return plugin_args
class PluginLoadError(Exception):
"""Raised when a plugin fails to load."""
def load_plugin(plugin_path: str) -> None:
"""Load and initialize a plugin from a given Python package path.
This function attempts to load a plugin by importing its package and any submodules.
Plugin registration is expected to happen during package initialization, i.e. when
the package is imported the gym environment should be registered and the config classes
registered with their parents using the `register_subclass` decorator.
Args:
plugin_path (str): The Python package path to the plugin (e.g. "mypackage.plugins.myplugin")
Raises:
PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path is invalid.
Examples:
>>> load_plugin("external_plugin.core") # Loads plugin from external package
Notes:
- The plugin package should handle its own registration during import
- All submodules in the plugin package will be imported
- Implementation follows the plugin discovery pattern from Python packaging guidelines
See Also:
https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/
"""
try:
package_module = importlib.import_module(plugin_path, __package__)
except (ImportError, ModuleNotFoundError) as e:
raise PluginLoadError(
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
) from e
def iter_namespace(ns_pkg):
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
try:
for _finder, pkg_name, _ispkg in iter_namespace(package_module):
importlib.import_module(pkg_name)
except ImportError as e:
raise PluginLoadError(
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
) from e
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
return parse_arg(f"{field_name}.{PATH_KEY}", args)
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
return parse_arg(f"{field_name}.{draccus.CHOICE_TYPE_KEY}", args)
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
"""
Filters command-line arguments related to fields with specific path arguments.
Args:
fields_to_filter (str | list[str]): A single str or a list of str whose arguments need to be filtered.
args (Sequence[str] | None): The sequence of command-line arguments to be filtered.
Defaults to None.
Returns:
list[str]: A filtered list of arguments, with arguments related to the specified
fields removed.
Raises:
ArgumentError: If both a path argument (e.g., `--field_name.path`) and a type
argument (e.g., `--field_name.type`) are specified for the same field.
"""
if isinstance(fields_to_filter, str):
fields_to_filter = [fields_to_filter]
filtered_args = args
for field in fields_to_filter:
if get_path_arg(field, args):
if get_type_arg(field, args):
raise ArgumentError(
argument=None,
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
)
filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
return filtered_args
def wrap(config_path: Path | None = None):
"""
HACK: Similar to draccus.wrap but does three additional things:
- Will remove '.path' arguments from CLI in order to process them later on.
- If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
initialize it from there to allow to fetch configs from the hub directly
- Will load plugins specified in the CLI arguments. These plugins will typically register
their own subclasses of config classes, so that draccus can find the right class to instantiate
from the CLI '.type' arguments
"""
def wrapper_outer(fn):
@wraps(fn)
def wrapper_inner(*args, **kwargs):
argspec = inspect.getfullargspec(fn)
argtype = argspec.annotations[argspec.args[0]]
if len(args) > 0 and type(args[0]) is argtype:
cfg = args[0]
args = args[1:]
else:
cli_args = sys.argv[1:]
plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
for plugin_cli_arg, plugin_path in plugin_args.items():
try:
load_plugin(plugin_path)
except PluginLoadError as e:
# add the relevant CLI arg to the error message
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
cli_args = filter_arg(plugin_cli_arg, cli_args)
config_path_cli = parse_arg("config_path", cli_args)
if has_method(argtype, "__get_path_fields__"):
path_fields = argtype.__get_path_fields__()
cli_args = filter_path_args(path_fields, cli_args)
if has_method(argtype, "from_pretrained") and config_path_cli:
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else:
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
response = fn(cfg, *args, **kwargs)
return response
return wrapper_inner
return wrapper_outer

View File

@@ -0,0 +1,190 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type, TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
# Generic variable that is either PreTrainedConfig or a subclass thereof
T = TypeVar("T", bound="PreTrainedConfig")
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""
Base configuration class for policy models.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
input_shapes: A dictionary defining the shapes of the input data for the policy.
output_shapes: A dictionary defining the shapes of the output data for the policy.
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
normalization mode to apply.
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
the original scale.
"""
n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
device: str | None = None # cuda | cpu | mp
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
push_to_hub: bool = True
repo_id: str | None = None
# Upload on private repository on the Hugging Face hub.
private: bool | None = None
# Add tags to your policy on the hub.
tags: list[str] | None = None
# Add tags to your policy on the hub.
license: str | None = None
def __post_init__(self):
self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@property
@abc.abstractmethod
def observation_delta_indices(self) -> list | None:
raise NotImplementedError
@property
@abc.abstractmethod
def action_delta_indices(self) -> list | None:
raise NotImplementedError
@property
@abc.abstractmethod
def reward_delta_indices(self) -> list | None:
raise NotImplementedError
@abc.abstractmethod
def get_optimizer_preset(self) -> OptimizerConfig:
raise NotImplementedError
@abc.abstractmethod
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
raise NotImplementedError
@abc.abstractmethod
def validate_features(self) -> None:
raise NotImplementedError
@property
def robot_state_feature(self) -> PolicyFeature | None:
for _, ft in self.input_features.items():
if ft.type is FeatureType.STATE:
return ft
return None
@property
def env_state_feature(self) -> PolicyFeature | None:
for _, ft in self.input_features.items():
if ft.type is FeatureType.ENV:
return ft
return None
@property
def image_features(self) -> dict[str, PolicyFeature]:
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
@property
def action_feature(self) -> PolicyFeature | None:
for _, ft in self.output_features.items():
if ft.type is FeatureType.ACTION:
return ft
return None
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: Type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**policy_kwargs,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
# something like --policy.path (in addition to --policy.type)
cli_overrides = policy_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_overrides)

View File

@@ -0,0 +1,184 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime as dt
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.optim import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.hub import HubMixin
TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
output_dir: Path | None = None
job_name: str | None = None
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
# `dir` is the directory of an existing run with at least one checkpoint in it.
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption.
resume: bool = False
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: int | None = 1000
# Number of workers for the dataloader.
num_workers: int = 4
batch_size: int = 8
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
save_freq: int = 20_000
use_policy_training_preset: bool = True
optimizer: OptimizerConfig | None = None
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
def __post_init__(self):
self.checkpoint_path = None
def validate(self):
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
if policy_path:
# Only load the policy config
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path")
if not config_path:
raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
)
if not Path(config_path).resolve().exists():
raise NotADirectoryError(
f"{config_path=} is expected to be a local path. "
"Resuming from the hub is not supported for now."
)
policy_path = Path(config_path).parent
self.policy.pretrained_path = policy_path
self.checkpoint_path = policy_path.parent
if not self.job_name:
if self.env is None:
self.job_name = f"{self.policy.type}"
else:
self.job_name = f"{self.env.type}_{self.policy.type}"
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
raise FileExistsError(
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
f"Please change your output directory so that {self.output_dir} is not overwritten."
)
elif not self.output_dir:
now = dt.datetime.now()
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
elif self.use_policy_training_preset and not self.resume:
self.optimizer = self.policy.get_optimizer_preset()
self.scheduler = self.policy.get_scheduler_preset()
if self.policy.push_to_hub and not self.policy.repo_id:
raise ValueError(
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
)
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def to_dict(self) -> dict:
return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: Type["TrainPipelineConfig"],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**kwargs,
) -> "TrainPipelineConfig":
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if TRAIN_CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
else:
print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
elif Path(model_id).is_file():
config_file = model_id
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=TRAIN_CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
cli_args = kwargs.pop("cli_args", [])
with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_args)
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset

View File

@@ -0,0 +1,42 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: We subclass str so that serialization is straightforward
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
from dataclasses import dataclass
from enum import Enum
from typing import Any, Protocol
class FeatureType(str, Enum):
STATE = "STATE"
VISUAL = "VISUAL"
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"
class NormalizationMode(str, Enum):
MIN_MAX = "MIN_MAX"
MEAN_STD = "MEAN_STD"
IDENTITY = "IDENTITY"
class DictLike(Protocol):
def __getitem__(self, key: Any) -> Any: ...
@dataclass
class PolicyFeature:
type: FeatureType
shape: tuple

54
src/lerobot/constants.py Normal file
View File

@@ -0,0 +1,54 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# keys
import os
from pathlib import Path
from huggingface_hub.constants import HF_HOME
OBS_ENV_STATE = "observation.environment_state"
OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"
REWARD = "next.reward"
ROBOTS = "robots"
ROBOT_TYPE = "robot_type"
TELEOPERATORS = "teleoperators"
# files & directories
CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last"
PRETRAINED_MODEL_DIR = "pretrained_model"
TRAINING_STATE_DIR = "training_state"
RNG_STATE = "rng_state.safetensors"
TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
if "LEROBOT_HOME" in os.environ:
raise ValueError(
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
)
# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
# calibration dir
default_calibration_path = HF_LEROBOT_HOME / "calibration"
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()

View File

@@ -0,0 +1,68 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import packaging.version
V2_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
V21_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id={repo_id}
```
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
FUTURE_MESSAGE = """
The dataset you requested ({repo_id}) is only available in {version} format.
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
"""
class CompatibilityError(Exception): ...
class BackwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)
class ForwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)

View File

@@ -0,0 +1,27 @@
---
# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/datasets-cards
{{ card_data }}
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
## Dataset Description
{{ dataset_description | default("", true) }}
- **Homepage:** {{ url | default("[More Information Needed]", true)}}
- **Paper:** {{ paper | default("[More Information Needed]", true)}}
- **License:** {{ license | default("[More Information Needed]", true)}}
## Dataset Structure
{{ dataset_structure | default("[More Information Needed]", true)}}
## Citation
**BibTeX:**
```bibtex
{{ citation_bibtex | default("[More Information Needed]", true)}}
```

View File

@@ -0,0 +1,176 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from lerobot.datasets.utils import load_image_as_numpy
def estimate_num_samples(
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
) -> int:
"""Heuristic to estimate the number of samples based on dataset size.
The power controls the sample growth relative to dataset size.
Lower the power for less number of samples.
For default arguments, we have:
- from 1 to ~500, num_samples=100
- at 1000, num_samples=177
- at 2000, num_samples=299
- at 5000, num_samples=594
- at 10000, num_samples=1000
- at 20000, num_samples=1681
"""
if dataset_len < min_num_samples:
min_num_samples = dataset_len
return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
def sample_indices(data_len: int) -> list[int]:
num_samples = estimate_num_samples(data_len)
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
_, height, width = img.shape
if max(width, height) < max_size_threshold:
# no downsampling needed
return img
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
return img[:, ::downsample_factor, ::downsample_factor]
def sample_images(image_paths: list[str]) -> np.ndarray:
sampled_indices = sample_indices(len(image_paths))
images = None
for i, idx in enumerate(sampled_indices):
path = image_paths[idx]
# we load as uint8 to reduce memory usage
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
img = auto_downsample_height_width(img)
if images is None:
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
images[i] = img
return images
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
return {
"min": np.min(array, axis=axis, keepdims=keepdims),
"max": np.max(array, axis=axis, keepdims=keepdims),
"mean": np.mean(array, axis=axis, keepdims=keepdims),
"std": np.std(array, axis=axis, keepdims=keepdims),
"count": np.array([len(array)]),
}
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
continue # HACK: we should receive np.arrays of strings
elif features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data) # data is a list of image paths
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
else:
ep_ft_array = data # data is already a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
# finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
for i in range(len(stats_list)):
for fkey in stats_list[i]:
for k, v in stats_list[i][fkey].items():
if not isinstance(v, np.ndarray):
raise ValueError(
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
)
if v.ndim == 0:
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
if k == "count" and v.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
counts = np.stack([s["count"] for s in stats_ft_list])
total_count = counts.sum(axis=0)
# Prepare weighted mean by matching number of dimensions
while counts.ndim < means.ndim:
counts = np.expand_dims(counts, axis=-1)
# Compute the weighted mean
weighted_means = means * counts
total_mean = weighted_means.sum(axis=0) / total_count
# Compute the variance using the parallel algorithm
delta_means = means - total_mean
weighted_variances = (variances + delta_means**2) * counts
total_variance = weighted_variances.sum(axis=0) / total_count
return {
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
"mean": total_mean,
"std": np.sqrt(total_variance),
"count": total_count,
}
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
The final stats will have the union of all data keys from each of the stats dicts.
For instance:
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_mean = (mean of all data, weighted by counts)
- new_std = (std of all data)
"""
_assert_type_and_shape(stats_list)
data_keys = {key for stats in stats_list for key in stats}
aggregated_stats = {key: {} for key in data_keys}
for key in data_keys:
stats_with_key = [stats[key] for stats in stats_list if key in stats]
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
return aggregated_stats

View File

@@ -0,0 +1,118 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pprint import pformat
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
from lerobot.datasets.transforms import ImageTransforms
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
}
def resolve_delta_timestamps(
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
) -> dict[str, list] | None:
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
Args:
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
delta_timestamps against.
Returns:
dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
{
"observation.state": [-0.04, -0.02, 0]
"observation.action": [-0.02, 0, 0.02]
}
returns `None` if the resulting dict is empty.
"""
delta_timestamps = {}
for key in ds_meta.features:
if key == "next.reward" and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == "action" and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0:
delta_timestamps = None
return delta_timestamps
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
Args:
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
Raises:
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
Returns:
LeRobotDataset | MultiLeRobotDataset
"""
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)
if isinstance(cfg.dataset.repo_id, str):
ds_meta = LeRobotDatasetMetadata(
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
dataset = MultiLeRobotDataset(
cfg.dataset.repo_id,
# TODO(aliberts): add proper support for multi dataset
# delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.dataset.video_backend,
)
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(dataset.repo_id_to_index, indent=2)}"
)
if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset

View File

@@ -0,0 +1,178 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import queue
import threading
from pathlib import Path
import numpy as np
import PIL.Image
import torch
def safe_stop_image_writer(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
dataset = kwargs.get("dataset")
image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None:
print("Waiting for image writer to terminate...")
image_writer.stop()
raise e
return wrapper
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
elif image_array.shape[-1] != 3:
raise NotImplementedError(
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
)
if image_array.dtype != np.uint8:
if range_check:
max_ = image_array.max().item()
min_ = image_array.min().item()
if max_ > 1.0 or min_ < 0.0:
raise ValueError(
"The image data type is float, which requires values in the range [0.0, 1.0]. "
f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
"provide a uint8 image with values in the range [0, 255]."
)
image_array = (image_array * 255).astype(np.uint8)
return PIL.Image.fromarray(image_array)
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
try:
if isinstance(image, np.ndarray):
img = image_array_to_pil_image(image)
elif isinstance(image, PIL.Image.Image):
img = image
else:
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath)
except Exception as e:
print(f"Error writing image {fpath}: {e}")
def worker_thread_loop(queue: queue.Queue):
while True:
item = queue.get()
if item is None:
queue.task_done()
break
image_array, fpath = item
write_image(image_array, fpath)
queue.task_done()
def worker_process(queue: queue.Queue, num_threads: int):
threads = []
for _ in range(num_threads):
t = threading.Thread(target=worker_thread_loop, args=(queue,))
t.daemon = True
t.start()
threads.append(t)
for t in threads:
t.join()
class AsyncImageWriter:
"""
This class abstract away the initialisation of processes or/and threads to
save images on disk asynchronously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it creates a threads pool of size `num_threads`.
When `num_processes>0`, it creates processes pool of size `num_processes`, where each subprocess starts
their own threads pool of size `num_threads`.
The optimal number of processes and threads depends on your computer capabilities.
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""
def __init__(self, num_processes: int = 0, num_threads: int = 1):
self.num_processes = num_processes
self.num_threads = num_threads
self.queue = None
self.threads = []
self.processes = []
self._stopped = False
if num_threads <= 0 and num_processes <= 0:
raise ValueError("Number of threads and processes must be greater than zero.")
if self.num_processes == 0:
# Use threading
self.queue = queue.Queue()
for _ in range(self.num_threads):
t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
t.daemon = True
t.start()
self.threads.append(t)
else:
# Use multiprocessing
self.queue = multiprocessing.JoinableQueue()
for _ in range(self.num_processes):
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
p.daemon = True
p.start()
self.processes.append(p)
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time
image = image.cpu().numpy()
self.queue.put((image, fpath))
def wait_until_done(self):
self.queue.join()
def stop(self):
if self._stopped:
return
if self.num_processes == 0:
for _ in self.threads:
self.queue.put(None)
for t in self.threads:
t.join()
else:
num_nones = self.num_processes * self.num_threads
for _ in range(num_nones):
self.queue.put(None)
for p in self.processes:
p.join()
if p.is_alive():
p.terminate()
self.queue.close()
self.queue.join_thread()
self._stopped = True

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,384 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An online buffer for the online training loop in train.py
Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
supports in-place slicing and mutation which is very handy for a dynamic buffer.
"""
import os
from pathlib import Path
from typing import Any
import numpy as np
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def _make_memmap_safe(**kwargs) -> np.memmap:
"""Make a numpy memmap with checks on available disk space first.
Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape"
For information on dtypes:
https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
"""
if kwargs["mode"].startswith("w"):
required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes
stats = os.statvfs(Path(kwargs["filename"]).parent)
available_space = stats.f_bavail * stats.f_frsize # bytes
if required_space >= available_space * 0.8:
raise RuntimeError(
f"You're about to take up {required_space} of {available_space} bytes available."
)
return np.memmap(**kwargs)
class OnlineBuffer(torch.utils.data.Dataset):
"""FIFO data buffer for the online training loop in train.py.
Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
loop in the same way that a LeRobotDataset would be used.
The underlying data structure will have data inserted in a circular fashion. Always insert after the
last index, and when you reach the end, wrap around to the start.
The data is stored in a numpy memmap.
"""
NEXT_INDEX_KEY = "_next_index"
OCCUPANCY_MASK_KEY = "_occupancy_mask"
INDEX_KEY = "index"
FRAME_INDEX_KEY = "frame_index"
EPISODE_INDEX_KEY = "episode_index"
TIMESTAMP_KEY = "timestamp"
IS_PAD_POSTFIX = "_is_pad"
def __init__(
self,
write_dir: str | Path,
data_spec: dict[str, Any] | None,
buffer_capacity: int | None,
fps: float | None = None,
delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
):
"""
The online buffer can be provided from scratch or you can load an existing online buffer by passing
a `write_dir` associated with an existing buffer.
Args:
write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
Note that if the files already exist, they are opened in read-write mode (used for training
resumption.)
data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
"dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
but note that "index", "frame_index" and "episode_index" are already accounted for by this
class, so you don't need to include them.
buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
system's available disk space when choosing this.
fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
delta_timestamps logic. You can pass None if you are not using delta_timestamps.
delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
converted to dict[str, np.ndarray] for optimization purposes.
"""
self.set_delta_timestamps(delta_timestamps)
self._fps = fps
# Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from
# the requested frames. It is only used when `delta_timestamps` is provided.
# minus 1e-4 to account for possible numerical error
self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
self._buffer_capacity = buffer_capacity
data_spec = self._make_data_spec(data_spec, buffer_capacity)
Path(write_dir).mkdir(parents=True, exist_ok=True)
self._data = {}
for k, v in data_spec.items():
self._data[k] = _make_memmap_safe(
filename=Path(write_dir) / k,
dtype=v["dtype"] if v is not None else None,
mode="r+" if (Path(write_dir) / k).exists() else "w+",
shape=tuple(v["shape"]) if v is not None else None,
)
@property
def delta_timestamps(self) -> dict[str, np.ndarray] | None:
return self._delta_timestamps
def set_delta_timestamps(self, value: dict[str, list[float]] | None):
"""Set delta_timestamps converting the values to numpy arrays.
The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
need to be converted into numpy arrays.
"""
if value is not None:
self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
else:
self._delta_timestamps = None
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
"""Makes the data spec for np.memmap."""
if any(k.startswith("_") for k in data_spec):
raise ValueError(
"data_spec keys should not start with '_'. This prefix is reserved for internal logic."
)
preset_keys = {
OnlineBuffer.INDEX_KEY,
OnlineBuffer.FRAME_INDEX_KEY,
OnlineBuffer.EPISODE_INDEX_KEY,
OnlineBuffer.TIMESTAMP_KEY,
}
if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
raise ValueError(
f"data_spec should not contain any of {preset_keys} as these are handled internally. "
f"The provided data_spec has {intersection}."
)
complete_data_spec = {
# _next_index will be a pointer to the next index that we should start filling from when we add
# more data.
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
# with real data rather than the dummy initialization.
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
}
for k, v in data_spec.items():
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
return complete_data_spec
def add_data(self, data: dict[str, np.ndarray]):
"""Add new data to the buffer, which could potentially mean shifting old data out.
The new data should contain all the frames (in order) of any number of episodes. The indices should
start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
`eval_policy` functions in `eval.py` for more information on how the data is constructed.
Shift the incoming data index and episode_index to continue on from the last frame. Note that this
will be done in place!
"""
if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
raise ValueError(f"Missing data keys: {missing_keys}")
new_data_length = len(data[self.data_keys[0]])
if not all(len(data[k]) == new_data_length for k in self.data_keys):
raise ValueError("All data items should have the same length")
next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
# Sanity check to make sure that the new data indices start from 0.
assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
# Shift the incoming indices if necessary.
if self.num_frames > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
# Insert the new data starting from next_index. It may be necessary to wrap around to the start.
n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
for k in self.data_keys:
if n_surplus == 0:
slc = slice(next_index, next_index + new_data_length)
self._data[k][slc] = data[k]
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
else:
self._data[k][next_index:] = data[k][:-n_surplus]
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
self._data[k][:n_surplus] = data[k][-n_surplus:]
if n_surplus == 0:
self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
else:
self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
@property
def data_keys(self) -> list[str]:
keys = set(self._data)
keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
return sorted(keys)
@property
def fps(self) -> float | None:
return self._fps
@property
def num_episodes(self) -> int:
return len(
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
)
@property
def num_frames(self) -> int:
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
def __len__(self):
return self.num_frames
def _item_to_tensors(self, item: dict) -> dict:
item_ = {}
for k, v in item.items():
if isinstance(v, torch.Tensor):
item_[k] = v
elif isinstance(v, np.ndarray):
item_[k] = torch.from_numpy(v)
else:
item_[k] = torch.tensor(v)
return item_
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self) or idx < -len(self):
raise IndexError
item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}
if self.delta_timestamps is None:
return self._item_to_tensors(item)
episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
episode_data_indices = np.where(
np.bitwise_and(
self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
)
)[0]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
for data_key in self.delta_timestamps:
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
# Get timestamps used as query to retrieve data of previous/future frames.
query_ts = current_ts + self.delta_timestamps[data_key]
# Compute distances between each query timestamp and all timestamps of all the frames belonging to
# the episode.
dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
argmin_ = np.argmin(dist, axis=1)
min_ = dist[np.arange(dist.shape[0]), argmin_]
is_pad = min_ > self.tolerance_s
# Check violated query timestamps are all outside the episode range.
assert (
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
") inside the episode range."
)
# Load frames for this data key.
item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
return self._item_to_tensors(item)
def get_data_by_key(self, key: str) -> torch.Tensor:
"""Returns all data for a given data key as a Tensor."""
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
def compute_sampler_weights(
offline_dataset: LeRobotDataset,
offline_drop_n_last_frames: int = 0,
online_dataset: OnlineBuffer | None = None,
online_sampling_ratio: float | None = None,
online_drop_n_last_frames: int = 0,
) -> torch.Tensor:
"""Compute the sampling weights for the online training dataloader in train.py.
Args:
offline_dataset: The LeRobotDataset used for offline pre-training.
online_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
online_dataset: The OnlineBuffer used in online training.
online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
online dataset is provided, this value must also be provided.
online_drop_n_first_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
dataset.
Returns:
Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
Notes to maintainers:
- This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
- When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
`EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
is the ability to turn shuffling off.
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
included here to avoid adding complexity.
"""
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
if (online_dataset is None) ^ (online_sampling_ratio is None):
raise ValueError(
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
)
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
weights = []
if len(offline_dataset) > 0:
offline_data_mask_indices = []
for start_index, end_index in zip(
offline_dataset.episode_data_index["from"],
offline_dataset.episode_data_index["to"],
strict=True,
):
offline_data_mask_indices.extend(
range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
)
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
weights.append(
torch.full(
size=(len(offline_dataset),),
fill_value=offline_sampling_ratio / offline_data_mask.sum(),
)
* offline_data_mask
)
if online_dataset is not None and len(online_dataset) > 0:
online_data_mask_indices = []
episode_indices = online_dataset.get_data_by_key("episode_index")
for episode_idx in torch.unique(episode_indices):
where_episode = torch.where(episode_indices == episode_idx)
start_index = where_episode[0][0]
end_index = where_episode[0][-1] + 1
online_data_mask_indices.extend(
range(start_index.item(), end_index.item() - online_drop_n_last_frames)
)
online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
online_data_mask[torch.tensor(online_data_mask_indices)] = True
weights.append(
torch.full(
size=(len(online_dataset),),
fill_value=online_sampling_ratio / online_data_mask.sum(),
)
* online_data_mask
)
weights = torch.cat(weights)
if weights.sum() == 0:
weights += 1 / len(weights)
else:
weights /= weights.sum()
return weights

View File

@@ -0,0 +1,131 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict
import datasets
import numpy
import PIL
import torch
from lerobot.datasets.video_utils import encode_video_frames
def concatenate_episodes(ep_dicts):
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
def save_image(img_array, i, out_dir):
img = PIL.Image.fromarray(img_array)
img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)
num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
def get_default_encoding() -> dict:
"""Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
signature = inspect.signature(encode_video_frames)
return {
k: v.default
for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
}
def check_repo_id(repo_id: str) -> None:
if len(repo_id.split("/")) != 2:
raise ValueError(
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
)
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
Parameters:
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
Returns:
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
- "from": A tensor containing the starting index of each episode.
- "to": A tensor containing the ending index of each episode.
"""
episode_data_index = {"from": [], "to": []}
current_episode = None
"""
The episode_index is a list of integers, each representing the episode index of the corresponding example.
For instance, the following is a valid episode_index:
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
{
"from": [0, 3, 7],
"to": [3, 7, 12]
}
"""
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list
episode_data_index["from"].append(idx)
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
if current_episode is not None:
episode_data_index["to"].append(idx)
# Let's keep track of the current episode index
current_episode = episode_idx
else:
# We are still in the same episode, so there is nothing for us to do here
pass
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
episode_data_index["to"].append(idx + 1)
for k in ["from", "to"]:
episode_data_index[k] = torch.tensor(episode_data_index[k])
return episode_data_index

View File

@@ -0,0 +1,61 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterator, Union
import torch
class EpisodeAwareSampler:
def __init__(
self,
episode_data_index: dict,
episode_indices_to_use: Union[list, None] = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
):
"""Sampler that optionally incorporates episode boundary information.
Args:
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
"""
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
)
self.indices = indices
self.shuffle = shuffle
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for i in self.indices:
yield i
def __len__(self) -> int:
return len(self.indices)

View File

@@ -0,0 +1,249 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from dataclasses import dataclass, field
from typing import Any, Callable, Sequence
import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F # noqa: N812
class RandomSubsetApply(Transform):
"""Apply a random subset of N transformations from a list of transformations.
Args:
transforms: list of transformations.
p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
have the same probability.
n_subset: number of transformations to apply. If ``None``, all transforms are applied.
Must be in [1, len(transforms)].
random_order: apply transformations in a random order.
"""
def __init__(
self,
transforms: Sequence[Callable],
p: list[float] | None = None,
n_subset: int | None = None,
random_order: bool = False,
) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if p is None:
p = [1] * len(transforms)
elif len(p) != len(transforms):
raise ValueError(
f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
)
if n_subset is None:
n_subset = len(transforms)
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
self.transforms = transforms
total = sum(p)
self.p = [prob / total for prob in p]
self.n_subset = n_subset
self.random_order = random_order
self.selected_transforms = None
def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1
selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
if not self.random_order:
selected_indices = selected_indices.sort().values
self.selected_transforms = [self.transforms[i] for i in selected_indices]
for transform in self.selected_transforms:
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs
def extra_repr(self) -> str:
return (
f"transforms={self.transforms}, "
f"p={self.p}, "
f"n_subset={self.n_subset}, "
f"random_order={self.random_order}"
)
class SharpnessJitter(Transform):
"""Randomly change the sharpness of an image or video.
Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image,
SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
augmentations as a result.
A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
by a factor of 2.
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
[max(0, 1 - sharpness), 1 + sharpness] or the given
[min, max]. Should be non negative numbers.
"""
def __init__(self, sharpness: float | Sequence[float]) -> None:
super().__init__()
self.sharpness = self._check_input(sharpness)
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness]
sharpness[0] = max(sharpness[0], 0.0)
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
sharpness = [float(v) for v in sharpness]
else:
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
if not 0.0 <= sharpness[0] <= sharpness[1]:
raise ValueError(f"sharpness values should be between (0., inf), but got {sharpness}.")
return float(sharpness[0]), float(sharpness[1])
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
return {"sharpness_factor": sharpness_factor}
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
sharpness_factor = params["sharpness_factor"]
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
@dataclass
class ImageTransformConfig:
"""
For each transform, the following parameters are available:
weight: This represents the multinomial probability (with no replacement)
used for sampling the transform. If the sum of the weights is not 1,
they will be normalized.
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
custom transform defined here.
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
(following uniform distribution) when it's applied.
"""
weight: float = 1.0
type: str = "Identity"
kwargs: dict[str, Any] = field(default_factory=dict)
@dataclass
class ImageTransformsConfig:
"""
These transforms are all using standard torchvision.transforms.v2
You can find out how these transformations affect images here:
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
We use a custom RandomSubsetApply container to sample them.
"""
# Set this flag to `true` to enable transforms during training
enable: bool = False
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [1, number_of_available_transforms].
max_num_transforms: int = 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: bool = False
tfs: dict[str, ImageTransformConfig] = field(
default_factory=lambda: {
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.8, 1.2)},
),
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.8, 1.2)},
),
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 1.5)},
),
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (-0.05, 0.05)},
),
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
}
)
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")
class ImageTransforms(Transform):
"""A class to compose image transforms based on configuration."""
def __init__(self, cfg: ImageTransformsConfig) -> None:
super().__init__()
self._cfg = cfg
self.weights = []
self.transforms = {}
for tf_name, tf_cfg in cfg.tfs.items():
if tf_cfg.weight <= 0.0:
continue
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
self.weights.append(tf_cfg.weight)
n_subset = min(len(self.transforms), cfg.max_num_transforms)
if n_subset == 0 or not cfg.enable:
self.tf = v2.Identity()
else:
self.tf = RandomSubsetApply(
transforms=list(self.transforms.values()),
p=self.weights,
n_subset=n_subset,
random_order=cfg.random_order,
)
def forward(self, *inputs: Any) -> Any:
return self.tf(*inputs)

View File

@@ -0,0 +1,860 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib.resources
import json
import logging
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any
import datasets
import jsonlines
import numpy as np
import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
from lerobot.datasets.backward_compatibility import (
V21_MESSAGE,
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.robots import Robot
from lerobot.utils.utils import is_valid_numpy_dtype_string
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DATASET_CARD_TEMPLATE = """
---
# Metadata will go there
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
## {}
"""
DEFAULT_FEATURES = {
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
"index": {"dtype": "int64", "shape": (1,), "names": None},
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
>>> print(flatten_dict(dct))
{"a/b": 1, "a/c/d": 2, "e": 3}
"""
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def unflatten_dict(d: dict, sep: str = "/") -> dict:
outdict = {}
for key, value in d.items():
parts = key.split(sep)
d = outdict
for part in parts[:-1]:
if part not in d:
d[part] = {}
d = d[part]
d[parts[-1]] = value
return outdict
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
split_keys = flattened_key.split(sep)
getter = obj[split_keys[0]]
if len(split_keys) == 1:
return getter
for key in split_keys[1:]:
getter = getter[key]
return getter
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {}
for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)):
serialized_dict[key] = value.tolist()
elif isinstance(value, np.generic):
serialized_dict[key] = value.item()
elif isinstance(value, (int, float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
return unflatten_dict(serialized_dict)
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format)
return dataset
def load_json(fpath: Path) -> Any:
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
def load_jsonlines(fpath: Path) -> list[Any]:
with jsonlines.open(fpath, "r") as reader:
return list(reader)
def write_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(data)
def append_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "a") as writer:
writer.write(data)
def write_info(info: dict, local_dir: Path):
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict:
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
return info
def write_stats(stats: dict, local_dir: Path):
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
return cast_stats_to_numpy(stats)
def write_task(task_index: int, task: dict, local_dir: Path):
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, local_dir / TASKS_PATH)
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index
def write_episode(episode: dict, local_dir: Path):
append_jsonlines(episode, local_dir / EPISODES_PATH)
def load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
def load_episodes_stats(local_dir: Path) -> dict:
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
return {
item["episode_index"]: cast_stats_to_numpy(item["stats"])
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
}
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
return dict.fromkeys(episodes, stats)
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1))
if np.issubdtype(dtype, np.floating):
img_array /= 255.0
return img_array
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
with channel first (c h w) of float32 type in range [0,1].
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict
def is_valid_version(version: str) -> bool:
try:
packaging.version.parse(version)
return True
except packaging.version.InvalidVersion:
return False
def check_version_compatibility(
repo_id: str,
version_to_check: str | packaging.version.Version,
current_version: str | packaging.version.Version,
enforce_breaking_major: bool = True,
) -> None:
v_check = (
packaging.version.parse(version_to_check)
if not isinstance(version_to_check, packaging.version.Version)
else version_to_check
)
v_current = (
packaging.version.parse(current_version)
if not isinstance(current_version, packaging.version.Version)
else current_version
)
if v_check.major < v_current.major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, v_check)
elif v_check.minor < v_current.minor:
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
"""Returns available valid versions (branches and tags) on given repo."""
api = HfApi()
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
repo_versions = []
for ref in repo_refs:
with contextlib.suppress(packaging.version.InvalidVersion):
repo_versions.append(packaging.version.parse(ref))
return repo_versions
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
"""
Returns the version if available on repo or the latest compatible one.
Otherwise, will throw a `CompatibilityError`.
"""
target_version = (
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
)
hub_versions = get_repo_versions(repo_id)
if not hub_versions:
raise RevisionNotFoundError(
f"""Your dataset must be tagged with a codebase version.
Assuming _version_ is the codebase_version value in the info.json, you can run this:
```python
from huggingface_hub import HfApi
hub_api = HfApi()
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
```
"""
)
if target_version in hub_versions:
return f"v{target_version}"
compatibles = [
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
]
if compatibles:
return_version = max(compatibles)
if return_version < target_version:
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
return f"v{return_version}"
lower_major = [v for v in hub_versions if v.major < target_version.major]
if lower_major:
raise BackwardCompatibilityError(repo_id, max(lower_major))
upper_versions = [v for v in hub_versions if v > target_version]
assert len(upper_versions) > 0
raise ForwardCompatibilityError(repo_id, min(upper_versions))
def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
elif ft["shape"] == (1,):
hf_features[key] = datasets.Value(dtype=ft["dtype"])
elif len(ft["shape"]) == 1:
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
elif len(ft["shape"]) == 2:
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 3:
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 4:
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 5:
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
else:
raise ValueError(f"Corresponding feature is not valid: {ft}")
return datasets.Features(hf_features)
def _validate_feature_names(features: dict[str, dict]) -> None:
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
if invalid_features:
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
def hw_to_dataset_features(
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
features = {}
joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == "action":
features[prefix] = {
"dtype": "float32",
"shape": (len(joint_fts),),
"names": list(joint_fts),
}
if joint_fts and prefix == "observation":
features[f"{prefix}.state"] = {
"dtype": "float32",
"shape": (len(joint_fts),),
"names": list(joint_fts),
}
for key, shape in cam_fts.items():
features[f"{prefix}.images.{key}"] = {
"dtype": "video" if use_video else "image",
"shape": shape,
"names": ["height", "width", "channels"],
}
_validate_feature_names(features)
return features
def build_dataset_frame(
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
frame = {}
for key, ft in ds_features.items():
if key in DEFAULT_FEATURES or not key.startswith(prefix):
continue
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]:
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
return frame
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
camera_ft = {}
if robot.cameras:
camera_ft = {
key: {"dtype": "video" if use_videos else "image", **ft}
for key, ft in robot.camera_features.items()
}
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
# TODO(aliberts): Implement "type" in dataset features and simplify this
policy_features = {}
for key, ft in features.items():
shape = ft["shape"]
if ft["dtype"] in ["image", "video"]:
type = FeatureType.VISUAL
if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == "observation.environment_state":
type = FeatureType.ENV
elif key.startswith("observation"):
type = FeatureType.STATE
elif key.startswith("action"):
type = FeatureType.ACTION
else:
continue
policy_features[key] = PolicyFeature(
type=type,
shape=shape,
)
return policy_features
def create_empty_dataset_info(
codebase_version: str,
fps: int,
features: dict,
use_videos: bool,
robot_type: str | None = None,
) -> dict:
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": 0,
"total_frames": 0,
"total_tasks": 0,
"total_videos": 0,
"total_chunks": 0,
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": fps,
"splits": {},
"data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
def get_episode_data_index(
episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
cumulative_lengths = list(accumulate(episode_lengths.values()))
return {
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
"to": torch.LongTensor(cumulative_lengths),
}
def check_timestamps_sync(
timestamps: np.ndarray,
episode_indices: np.ndarray,
episode_data_index: dict[str, np.ndarray],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
to account for possible numerical error.
Args:
timestamps (np.ndarray): Array of timestamps in seconds.
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
which identifies indices for the end of each episode.
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
raise_value_error (bool): Whether to raise a ValueError if the check fails.
Returns:
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
Raises:
ValueError: If the check fails and `raise_value_error` is True.
"""
if timestamps.shape != episode_indices.shape:
raise ValueError(
"timestamps and episode_indices should have the same shape. "
f"Found {timestamps.shape=} and {episode_indices.shape=}."
)
# Consecutive differences
diffs = np.diff(timestamps)
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
# Mask to ignore differences at the boundaries between episodes
mask = np.ones(len(diffs), dtype=bool)
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]
# Check if all remaining diffs are within tolerance
if not np.all(filtered_within_tolerance):
# Track original indices before masking
original_indices = np.arange(len(diffs))
filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
outside_tolerances = []
for idx in outside_tolerance_indices:
entry = {
"timestamps": [timestamps[idx], timestamps[idx + 1]],
"diff": diffs[idx],
"episode_index": episode_indices[idx].item()
if hasattr(episode_indices[idx], "item")
else episode_indices[idx],
}
outside_tolerances.append(entry)
if raise_value_error:
raise ValueError(
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
This might be due to synchronization issues during data collection.
\n{pformat(outside_tolerances)}"""
)
return False
return True
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
actual timestamps from the dataset.
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
if not all(within_tolerance):
outside_tolerance[key] = [
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
]
if len(outside_tolerance) > 0:
if raise_value_error:
raise ValueError(
f"""
The following delta_timestamps are found outside of tolerance range.
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
their values accordingly.
\n{pformat(outside_tolerance)}
"""
)
return False
return True
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = [round(d * fps) for d in delta_ts]
return delta_indices
def cycle(iterable):
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
"""
iterator = iter(iterable)
while True:
try:
yield next(iterator)
except StopIteration:
iterator = iter(iterable)
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
exists before creating it.
"""
api = HfApi()
branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
refs = [branch.ref for branch in branches]
ref = f"refs/heads/{branch}"
if ref in refs:
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
def create_lerobot_dataset_card(
tags: list | None = None,
dataset_info: dict | None = None,
**kwargs,
) -> DatasetCard:
"""
Keyword arguments will be used to replace values in src/lerobot/datasets/card_template.md.
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
"""
card_tags = ["LeRobot"]
if tags:
card_tags += tags
if dataset_info:
dataset_structure = "[meta/info.json](meta/info.json):\n"
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
kwargs = {**kwargs, "dataset_structure": dataset_structure}
card_data = DatasetCardData(
license=kwargs.get("license"),
tags=card_tags,
task_categories=["robotics"],
configs=[
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
],
)
card_template = (importlib.resources.files("lerobot.datasets") / "card_template.md").read_text()
return DatasetCard.from_template(
card_data=card_data,
template_str=card_template,
**kwargs,
)
class IterableNamespace(SimpleNamespace):
"""
A namespace object that supports both dictionary-like iteration and dot notation access.
Automatically converts nested dictionaries into IterableNamespaces.
This class extends SimpleNamespace to provide:
- Dictionary-style iteration over keys
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
- Dictionary-like methods: items(), keys(), values()
- Recursive conversion of nested dictionaries
Args:
dictionary: Optional dictionary to initialize the namespace
**kwargs: Additional keyword arguments passed to SimpleNamespace
Examples:
>>> data = {"name": "Alice", "details": {"age": 25}}
>>> ns = IterableNamespace(data)
>>> ns.name
'Alice'
>>> ns.details.age
25
>>> list(ns.keys())
['name', 'details']
>>> for key, value in ns.items():
... print(f"{key}: {value}")
name: Alice
details: IterableNamespace(age=25)
"""
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
super().__init__(**kwargs)
if dictionary is not None:
for key, value in dictionary.items():
if isinstance(value, dict):
setattr(self, key, IterableNamespace(value))
else:
setattr(self, key, value)
def __iter__(self) -> Iterator[str]:
return iter(vars(self))
def __getitem__(self, key: str) -> Any:
return vars(self)[key]
def items(self):
return vars(self).items()
def values(self):
return vars(self).values()
def keys(self):
return vars(self).keys()
def validate_frame(frame: dict, features: dict):
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)
error_message = validate_features_presence(actual_features, expected_features)
common_features = actual_features & expected_features
for name in common_features - {"task"}:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
raise ValueError(error_message)
def validate_features_presence(actual_features: set[str], expected_features: set[str]):
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - expected_features
if missing_features or extra_features:
error_message += "Feature mismatch in `frame` dictionary:\n"
if missing_features:
error_message += f"Missing features: {missing_features}\n"
if extra_features:
error_message += f"Extra features: {extra_features}\n"
return error_message
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
actual_shape = value.shape
if actual_dtype != np.dtype(expected_dtype):
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
if actual_shape != expected_shape:
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
else:
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
elif isinstance(value, PILImage.Image):
pass
else:
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_string(name: str, value: str):
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
if "size" not in episode_buffer:
raise ValueError("size key not found in episode_buffer")
if "task" not in episode_buffer:
raise ValueError("task key not found in episode_buffer")
if episode_buffer["episode_index"] != total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes already in the dataset. This is not supported for now."
)
if episode_buffer["size"] == 0:
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
if not buffer_keys == set(features):
raise ValueError(
f"Features from `episode_buffer` don't match the ones in `features`."
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)

View File

@@ -0,0 +1,884 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
lerobot/configs/robot/aloha.yaml before running this script.
"""
import traceback
from pathlib import Path
from textwrap import dedent
from lerobot import available_datasets
from lerobot.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
from lerobot.robots.aloha.configuration_aloha import AlohaRobotConfig
LOCAL_DIR = Path("data/")
# spellchecker:off
ALOHA_MOBILE_INFO = {
"robot_config": AlohaRobotConfig(),
"license": "mit",
"url": "https://mobile-aloha.github.io/",
"paper": "https://huggingface.co/papers/2401.02117",
"citation_bibtex": dedent(r"""
@inproceedings{fu2024mobile,
author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
booktitle = {arXiv},
year = {2024},
}""").lstrip(),
}
ALOHA_STATIC_INFO = {
"robot_config": AlohaRobotConfig(),
"license": "mit",
"url": "https://tonyzhaozh.github.io/aloha/",
"paper": "https://huggingface.co/papers/2304.13705",
"citation_bibtex": dedent(r"""
@article{Zhao2023LearningFB,
title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
journal={RSS},
year={2023},
volume={abs/2304.13705},
url={https://huggingface.co/papers/2304.13705}
}""").lstrip(),
}
PUSHT_INFO = {
"license": "mit",
"url": "https://diffusion-policy.cs.columbia.edu/",
"paper": "https://huggingface.co/papers/2303.04137",
"citation_bibtex": dedent(r"""
@article{chi2024diffusionpolicy,
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
journal = {The International Journal of Robotics Research},
year = {2024},
}""").lstrip(),
}
XARM_INFO = {
"license": "mit",
"url": "https://www.nicklashansen.com/td-mpc/",
"paper": "https://huggingface.co/papers/2203.04955",
"citation_bibtex": dedent(r"""
@inproceedings{Hansen2022tdmpc,
title={Temporal Difference Learning for Model Predictive Control},
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
booktitle={ICML},
year={2022}
}
"""),
}
UNITREEH_INFO = {
"license": "apache-2.0",
}
DATASETS = {
"aloha_mobile_cabinet": {
"single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_chair": {
"single_task": "Push the chairs in front of the desk to place them against it.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_elevator": {
"single_task": "Take the elevator to the 1st floor.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_shrimp": {
"single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_wash_pan": {
"single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_wipe_wine": {
"single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
**ALOHA_MOBILE_INFO,
},
"aloha_static_battery": {
"single_task": "Place the battery into the slot of the remote controller.",
**ALOHA_STATIC_INFO,
},
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
"aloha_static_coffee": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
**ALOHA_STATIC_INFO,
},
"aloha_static_coffee_new": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
**ALOHA_STATIC_INFO,
},
"aloha_static_cups_open": {
"single_task": "Pick up the plastic cup and open its lid.",
**ALOHA_STATIC_INFO,
},
"aloha_static_fork_pick_up": {
"single_task": "Pick up the fork and place it on the plate.",
**ALOHA_STATIC_INFO,
},
"aloha_static_pingpong_test": {
"single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
**ALOHA_STATIC_INFO,
},
"aloha_static_pro_pencil": {
"single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
**ALOHA_STATIC_INFO,
},
"aloha_static_screw_driver": {
"single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
**ALOHA_STATIC_INFO,
},
"aloha_static_tape": {
"single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
**ALOHA_STATIC_INFO,
},
"aloha_static_thread_velcro": {
"single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_towel": {
"single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup": {
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup_left": {
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_human_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_scripted": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_scripted_image": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_human": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_human_image": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
"unitreeh1_two_robot_greeting": {
"single_task": "Greet the other robot with a high five.",
**UNITREEH_INFO,
},
"unitreeh1_warehouse": {
"single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
**UNITREEH_INFO,
},
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"umi_cup_in_the_wild": {
"single_task": "Put the cup on the plate.",
"license": "apache-2.0",
},
"asu_table_top": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
"citation_bibtex": dedent(r"""
@inproceedings{zhou2023modularity,
title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
booktitle={Conference on Robot Learning},
pages={1684--1695},
year={2023},
organization={PMLR}
}
@article{zhou2023learning,
title={Learning modular language-conditioned robot policies through attention},
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
journal={Autonomous Robots},
pages={1--21},
year={2023},
publisher={Springer}
}""").lstrip(),
},
"austin_buds_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/BUDS-website/",
"paper": "https://huggingface.co/papers/2109.13841",
"citation_bibtex": dedent(r"""
@article{zhu2022bottom,
title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
journal={IEEE Robotics and Automation Letters},
volume={7},
number={2},
pages={4126--4133},
year={2022},
publisher={IEEE}
}""").lstrip(),
},
"austin_sailor_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/sailor/",
"paper": "https://huggingface.co/papers/2210.11435",
"citation_bibtex": dedent(r"""
@inproceedings{nasiriany2022sailor,
title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
booktitle={Conference on Robot Learning (CoRL)},
year={2022}
}""").lstrip(),
},
"austin_sirius_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/sirius/",
"paper": "https://huggingface.co/papers/2211.08416",
"citation_bibtex": dedent(r"""
@inproceedings{liu2022robot,
title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
booktitle = {Robotics: Science and Systems (RSS)},
year = {2023}
}""").lstrip(),
},
"berkeley_autolab_ur5": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://sites.google.com/view/berkeley-ur5/home",
"citation_bibtex": dedent(r"""
@misc{BerkeleyUR5Website,
title = {Berkeley {UR5} Demonstration Dataset},
author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
howpublished = {https://sites.google.com/view/berkeley-ur5/home},
}""").lstrip(),
},
"berkeley_cable_routing": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://sites.google.com/view/cablerouting/home",
"paper": "https://huggingface.co/papers/2307.08927",
"citation_bibtex": dedent(r"""
@article{luo2023multistage,
author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
journal = {arXiv pre-print},
year = {2023},
url = {https://huggingface.co/papers/2307.08927},
}""").lstrip(),
},
"berkeley_fanuc_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
"citation_bibtex": dedent(r"""
@article{fanuc_manipulation2023,
title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
year={2023},
}""").lstrip(),
},
"berkeley_gnm_cory_hall": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://huggingface.co/papers/1709.10489",
"citation_bibtex": dedent(r"""
@inproceedings{kahn2018self,
title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
pages={5129--5136},
year={2018},
organization={IEEE}
}""").lstrip(),
},
"berkeley_gnm_recon": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/recon-robot",
"paper": "https://huggingface.co/papers/2104.05859",
"citation_bibtex": dedent(r"""
@inproceedings{shah2021rapid,
title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
booktitle={5th Annual Conference on Robot Learning },
year={2021},
url={https://openreview.net/forum?id=d_SWJhyKfVw}
}""").lstrip(),
},
"berkeley_gnm_sac_son": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/SACSoN-review",
"paper": "https://huggingface.co/papers/2306.01874",
"citation_bibtex": dedent(r"""
@article{hirose2023sacson,
title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
journal={arXiv preprint arXiv:2306.01874},
year={2023}
}""").lstrip(),
},
"berkeley_mvp": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://huggingface.co/papers/2203.06173",
"citation_bibtex": dedent(r"""
@InProceedings{Radosavovic2022,
title = {Real-World Robot Learning with Masked Visual Pre-training},
author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
booktitle = {CoRL},
year = {2022}
}""").lstrip(),
},
"berkeley_rpt": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://huggingface.co/papers/2306.10007",
"citation_bibtex": dedent(r"""
@article{Radosavovic2023,
title={Robot Learning with Sensorimotor Pre-training},
author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
year={2023},
journal={arXiv:2306.10007}
}""").lstrip(),
},
"cmu_franka_exploration_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://human-world-model.github.io/",
"paper": "https://huggingface.co/papers/2308.10901",
"citation_bibtex": dedent(r"""
@inproceedings{mendonca2023structured,
title={Structured World Models from Human Videos},
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
journal={RSS},
year={2023}
}""").lstrip(),
},
"cmu_play_fusion": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://play-fusion.github.io/",
"paper": "https://huggingface.co/papers/2312.04549",
"citation_bibtex": dedent(r"""
@inproceedings{chen2023playfusion,
title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
booktitle={CoRL},
year={2023}
}""").lstrip(),
},
"cmu_stretch": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://robo-affordances.github.io/",
"paper": "https://huggingface.co/papers/2304.08488",
"citation_bibtex": dedent(r"""
@inproceedings{bahl2023affordances,
title={Affordances from Human Videos as a Versatile Representation for Robotics},
author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
booktitle={CVPR},
year={2023}
}
@article{mendonca2023structured,
title={Structured World Models from Human Videos},
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
journal={CoRL},
year={2023}
}""").lstrip(),
},
"columbia_cairlab_pusht_real": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://diffusion-policy.cs.columbia.edu/",
"paper": "https://huggingface.co/papers/2303.04137",
"citation_bibtex": dedent(r"""
@inproceedings{chi2023diffusionpolicy,
title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
booktitle={Proceedings of Robotics: Science and Systems (RSS)},
year={2023}
}""").lstrip(),
},
"conq_hose_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
"citation_bibtex": dedent(r"""
@misc{ConqHoseManipData,
author={Peter Mitrano and Dmitry Berenson},
title={Conq Hose Manipulation Dataset, v1.15.0},
year={2024},
howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
}""").lstrip(),
},
"dlr_edan_shared_control": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://ieeexplore.ieee.org/document/9341156",
"citation_bibtex": dedent(r"""
@inproceedings{vogel_edan_2020,
title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
language = {en},
booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
year = {2020}
}
@inproceedings{quere_shared_2020,
address = {Paris, France},
title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
language = {en},
booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
year = {2020},
pages = {7},
}""").lstrip(),
},
"dlr_sara_grid_clamp": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://www.researchsquare.com/article/rs-3289569/v1",
"citation_bibtex": dedent(r"""
@article{padalkar2023guided,
title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
journal={Research square preprint rs-3289569/v1},
year={2023}
}""").lstrip(),
},
"dlr_sara_pour": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
"citation_bibtex": dedent(r"""
@inproceedings{padalkar2023guiding,
title={Guiding Reinforcement Learning with Shared Control Templates},
author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
year={2023},
organization={IEEE}
}""").lstrip(),
},
"droid_100": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://droid-dataset.github.io/",
"paper": "https://huggingface.co/papers/2403.12945",
"citation_bibtex": dedent(r"""
@article{khazatsky2024droid,
title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
year = {2024},
}""").lstrip(),
},
"fmb": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://functional-manipulation-benchmark.github.io/",
"paper": "https://huggingface.co/papers/2401.08553",
"citation_bibtex": dedent(r"""
@article{luo2024fmb,
title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
journal={arXiv preprint arXiv:2401.08553},
year={2024}
}""").lstrip(),
},
"iamlab_cmu_pickup_insert": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://openreview.net/forum?id=WuBv9-IGDUA",
"paper": "https://huggingface.co/papers/2401.14502",
"citation_bibtex": dedent(r"""
@inproceedings{saxena2023multiresolution,
title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
booktitle={7th Annual Conference on Robot Learning},
year={2023},
url={https://openreview.net/forum?id=WuBv9-IGDUA}
}""").lstrip(),
},
"imperialcollege_sawyer_wrist_cam": {
"tasks_col": "language_instruction",
"license": "mit",
},
"jaco_play": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://github.com/clvrai/clvr_jaco_play_dataset",
"citation_bibtex": dedent(r"""
@software{dass2023jacoplay,
author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
title = {CLVR Jaco Play Dataset},
url = {https://github.com/clvrai/clvr_jaco_play_dataset},
version = {1.0.0},
year = {2023}
}""").lstrip(),
},
"kaist_nonprehensile": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
"citation_bibtex": dedent(r"""
@article{kimpre,
title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
year={2023},
organization={IEEE}
}""").lstrip(),
},
"nyu_door_opening_surprising_effectiveness": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://jyopari.github.io/VINN/",
"paper": "https://huggingface.co/papers/2112.01511",
"citation_bibtex": dedent(r"""
@misc{pari2021surprising,
title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
year={2021},
eprint={2112.01511},
archivePrefix={arXiv},
primaryClass={cs.RO}
}""").lstrip(),
},
"nyu_franka_play_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://play-to-policy.github.io/",
"paper": "https://huggingface.co/papers/2210.10047",
"citation_bibtex": dedent(r"""
@article{cui2022play,
title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
journal = {arXiv preprint arXiv:2210.10047},
year = {2022}
}""").lstrip(),
},
"nyu_rot_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://rot-robot.github.io/",
"paper": "https://huggingface.co/papers/2206.15469",
"citation_bibtex": dedent(r"""
@inproceedings{haldar2023watch,
title={Watch and match: Supercharging imitation with regularized optimal transport},
author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
booktitle={Conference on Robot Learning},
pages={32--43},
year={2023},
organization={PMLR}
}""").lstrip(),
},
"roboturk": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://roboturk.stanford.edu/dataset_real.html",
"paper": "PAPER",
"citation_bibtex": dedent(r"""
@inproceedings{mandlekar2019scaling,
title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
pages={1048--1055},
year={2019},
organization={IEEE}
}""").lstrip(),
},
"stanford_hydra_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/hydra-il-2023",
"paper": "https://huggingface.co/papers/2306.17237",
"citation_bibtex": dedent(r"""
@article{belkhale2023hydra,
title={HYDRA: Hybrid Robot Actions for Imitation Learning},
author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
journal={arxiv},
year={2023}
}""").lstrip(),
},
"stanford_kuka_multimodal_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/visionandtouch",
"paper": "https://huggingface.co/papers/1810.10191",
"citation_bibtex": dedent(r"""
@inproceedings{lee2019icra,
title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
year={2019},
url={https://huggingface.co/papers/1810.10191}
}""").lstrip(),
},
"stanford_robocook": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://hshi74.github.io/robocook/",
"paper": "https://huggingface.co/papers/2306.14447",
"citation_bibtex": dedent(r"""
@article{shi2023robocook,
title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
journal={arXiv preprint arXiv:2306.14447},
year={2023}
}""").lstrip(),
},
"taco_play": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
"paper": "https://huggingface.co/papers/2209.08959, https://huggingface.co/papers/2210.01911",
"citation_bibtex": dedent(r"""
@inproceedings{rosete2022tacorl,
author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
year = {2022}
}
@inproceedings{mees23hulc2,
title={Grounding Language with Visual Affordances over Unstructured Data},
author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
year={2023},
address = {London, UK}
}""").lstrip(),
},
"tokyo_u_lsmo": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "URL",
"paper": "https://huggingface.co/papers/2107.05842",
"citation_bibtex": dedent(r"""
@Article{Osa22,
author = {Takayuki Osa},
journal = {The International Journal of Robotics Research},
title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
year = {2022},
number = {3},
pages = {291--311},
volume = {41},
}""").lstrip(),
},
"toto": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://toto-benchmark.org/",
"paper": "https://huggingface.co/papers/2306.00942",
"citation_bibtex": dedent(r"""
@inproceedings{zhou2023train,
author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
title={Train Offline, Test Online: A Real Robot Learning Benchmark},
year={2023},
}""").lstrip(),
},
"ucsd_kitchen_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@ARTICLE{ucsd_kitchens,
author = {Ge Yan, Kris Wu, and Xiaolong Wang},
title = {{ucsd kitchens Dataset}},
year = {2023},
month = {August}
}""").lstrip(),
},
"ucsd_pick_and_place_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://owmcorl.github.io/#",
"paper": "https://huggingface.co/papers/2310.16029",
"citation_bibtex": dedent(r"""
@preprint{Feng2023Finetuning,
title={Finetuning Offline World Models in the Real World},
author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
year={2023}
}""").lstrip(),
},
"uiuc_d3field": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://robopil.github.io/d3fields/",
"paper": "https://huggingface.co/papers/2309.16118",
"citation_bibtex": dedent(r"""
@article{wang2023d3field,
title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
journal={arXiv preprint arXiv:},
year={2023},
}""").lstrip(),
},
"usc_cloth_sim": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://uscresl.github.io/dmfd/",
"paper": "https://huggingface.co/papers/2207.10148",
"citation_bibtex": dedent(r"""
@article{salhotra2022dmfd,
author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
journal={IEEE Robotics and Automation Letters},
title={Learning Deformable Object Manipulation From Expert Demonstrations},
year={2022},
volume={7},
number={4},
pages={8775-8782},
doi={10.1109/LRA.2022.3187843}
}""").lstrip(),
},
"utaustin_mutex": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/MUTEX/",
"paper": "https://huggingface.co/papers/2309.14320",
"citation_bibtex": dedent(r"""
@inproceedings{shah2023mutex,
title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
booktitle={7th Annual Conference on Robot Learning},
year={2023},
url={https://openreview.net/forum?id=PwqiqaaEzJ}
}""").lstrip(),
},
"utokyo_pr2_opening_fridge": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@misc{oh2023pr2utokyodatasets,
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
title={X-Embodiment U-Tokyo PR2 Datasets},
year={2023},
url={https://github.com/ojh6404/rlds_dataset_builder},
}""").lstrip(),
},
"utokyo_pr2_tabletop_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@misc{oh2023pr2utokyodatasets,
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
title={X-Embodiment U-Tokyo PR2 Datasets},
year={2023},
url={https://github.com/ojh6404/rlds_dataset_builder},
}""").lstrip(),
},
"utokyo_saytap": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://saytap.github.io/",
"paper": "https://huggingface.co/papers/2306.07580",
"citation_bibtex": dedent(r"""
@article{saytap2023,
author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
Tatsuya Harada},
title = {SayTap: Language to Quadrupedal Locomotion},
eprint = {arXiv:2306.07580},
url = {https://saytap.github.io},
note = {https://saytap.github.io},
year = {2023}
}""").lstrip(),
},
"utokyo_xarm_bimanual": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"citation_bibtex": dedent(r"""
@misc{matsushima2023weblab,
title={Weblab xArm Dataset},
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
year={2023},
}""").lstrip(),
},
"utokyo_xarm_pick_and_place": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"citation_bibtex": dedent(r"""
@misc{matsushima2023weblab,
title={Weblab xArm Dataset},
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
year={2023},
}""").lstrip(),
},
"viola": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/VIOLA/",
"paper": "https://huggingface.co/papers/2210.11339",
"citation_bibtex": dedent(r"""
@article{zhu2022viola,
title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
journal={6th Annual Conference on Robot Learning (CoRL)},
year={2022}
}""").lstrip(),
},
}
# spellchecker:on
def batch_convert():
status = {}
logfile = LOCAL_DIR / "conversion_log.txt"
assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
for num, (name, kwargs) in enumerate(DATASETS.items()):
repo_id = f"lerobot/{name}"
print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
print("---------------------------------------------------------")
try:
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
status = f"{repo_id}: success."
with open(logfile, "a") as file:
file.write(status + "\n")
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
file.write(status + "\n")
continue
if __name__ == "__main__":
batch_convert()

View File

@@ -0,0 +1,687 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
for each of the task performed in the dataset. This will allow to easily train models with task-conditioning.
We support 3 different scenarios for these tasks (see instructions below):
1. Single task dataset: all episodes of your dataset have the same single task.
2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
one episode to the next.
3. Multi task episodes: episodes of your dataset may each contain several different tasks.
Can you can also provide a robot config .yaml file (not mandatory) to this script via the option
'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was
recorded with. For now, only Aloha/Koch type robots are supported with this option.
# 1. Single task dataset
If your dataset contains a single task, you can simply provide it directly via the CLI with the
'--single-task' option.
Examples:
```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
--repo-id lerobot/aloha_sim_insertion_human_image \
--single-task "Insert the peg into the socket." \
--robot-config lerobot/configs/robot/aloha.yaml \
--local-dir data
```
```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
--repo-id aliberts/koch_tutorial \
--single-task "Pick the Lego block and drop it in the box on the right." \
--robot-config lerobot/configs/robot/koch.yaml \
--local-dir data
```
# 2. Single task episodes
If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
- If your dataset already contains a language instruction column in its parquet file, you can simply provide
this column's name with the '--tasks-col' arg.
Example:
```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
--repo-id lerobot/stanford_kuka_multimodal_dataset \
--tasks-col "language_instruction" \
--local-dir data
```
- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
'--tasks-path' arg. This file should have the following structure where keys correspond to each
episode_index in the dataset, and values are the language instruction for that episode.
Example:
```json
{
"0": "Do something",
"1": "Do something else",
"2": "Do something",
"3": "Go there",
...
}
```
# 3. Multi task episodes
If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
parquet file, and you must provide this column's name with the '--tasks-col' arg.
Example:
```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
--repo-id lerobot/stanford_kuka_multimodal_dataset \
--tasks-col "language_instruction" \
--local-dir data
```
"""
import argparse
import contextlib
import filecmp
import json
import logging
import math
import shutil
import subprocess
import tempfile
from pathlib import Path
import datasets
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
from safetensors.torch import load_file
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
create_branch,
create_lerobot_dataset_card,
flatten_dict,
get_safe_version,
load_json,
unflatten_dict,
write_json,
write_jsonlines,
)
from lerobot.datasets.video_utils import (
VideoFrame, # noqa: F401
get_image_pixel_channels,
get_video_info,
)
from lerobot.robots import RobotConfig
V16 = "v1.6"
V20 = "v2.0"
GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors"
def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
if robot_cfg.type in ["aloha", "koch"]:
state_names = [
f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
for arm in robot_cfg.follower_arms
for motor in robot_cfg.follower_arms[arm].motors
]
action_names = [
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
for arm in robot_cfg.leader_arms
for motor in robot_cfg.leader_arms[arm].motors
]
# elif robot_cfg["robot_type"] == "stretch3": TODO
else:
raise NotImplementedError(
"Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
)
return {
"robot_type": robot_cfg.type,
"names": {
"observation.state": state_names,
"observation.effort": state_names,
"action": action_names,
},
}
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
safetensor_path = v1_dir / V1_STATS_PATH
stats = load_file(safetensor_path)
serialized_stats = {key: value.tolist() for key, value in stats.items()}
serialized_stats = unflatten_dict(serialized_stats)
json_path = v2_dir / STATS_PATH
json_path.parent.mkdir(exist_ok=True, parents=True)
with open(json_path, "w") as f:
json.dump(serialized_stats, f, indent=4)
# Sanity check
with open(json_path) as f:
stats_json = json.load(f)
stats_json = flatten_dict(stats_json)
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
for key in stats:
torch.testing.assert_close(stats_json[key], stats[key])
def get_features_from_hf_dataset(
dataset: Dataset, robot_config: RobotConfig | None = None
) -> dict[str, list]:
robot_config = parse_robot_config(robot_config)
features = {}
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value):
dtype = ft.dtype
shape = (1,)
names = None
if isinstance(ft, datasets.Sequence):
assert isinstance(ft.feature, datasets.Value)
dtype = ft.feature.dtype
shape = (ft.length,)
motor_names = (
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
)
assert len(motor_names) == shape[0]
names = {"motors": motor_names}
elif isinstance(ft, datasets.Image):
dtype = "image"
image = dataset[0][key] # Assuming first row
channels = get_image_pixel_channels(image)
shape = (image.height, image.width, channels)
names = ["height", "width", "channels"]
elif ft._type == "VideoFrame":
dtype = "video"
shape = None # Add shape later
names = ["height", "width", "channels"]
features[key] = {
"dtype": dtype,
"shape": shape,
"names": names,
}
return features
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
df = dataset.to_pandas()
tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
features = dataset.features
features["task_index"] = datasets.Value(dtype="int64")
dataset = Dataset.from_pandas(df, features=features, split="train")
return dataset, tasks
def add_task_index_from_tasks_col(
dataset: Dataset, tasks_col: str
) -> tuple[Dataset, dict[str, list[str]], list[str]]:
df = dataset.to_pandas()
# HACK: This is to clean some of the instructions in our version of Open X datasets
prefix_to_clean = "tf.Tensor(b'"
suffix_to_clean = "', shape=(), dtype=string)"
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
# Create task_index col
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
tasks = df[tasks_col].unique().tolist()
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
# Build the dataset back from df
features = dataset.features
features["task_index"] = datasets.Value(dtype="int64")
dataset = Dataset.from_pandas(df, features=features, split="train")
dataset = dataset.remove_columns(tasks_col)
return dataset, tasks, tasks_by_episode
def split_parquet_by_episodes(
dataset: Dataset,
total_episodes: int,
total_chunks: int,
output_dir: Path,
) -> list:
table = dataset.data.table
episode_lengths = []
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
episode_chunk=ep_chunk, episode_index=ep_idx
)
pq.write_table(ep_table, output_file)
return episode_lengths
def move_videos(
repo_id: str,
video_keys: list[str],
total_episodes: int,
total_chunks: int,
work_dir: Path,
clean_gittatributes: Path,
branch: str = "main",
) -> None:
"""
HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
commands to fetch git lfs video files references to move them into subdirectories without having to
actually download them.
"""
_lfs_clone(repo_id, work_dir, branch)
videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
if len(video_files) == 0:
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
videos_moved = True # Videos have already been moved
assert len(video_files) == total_episodes * len(video_keys)
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
current_gittatributes = work_dir / ".gitattributes"
if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
if lfs_untracked_videos:
fix_lfs_video_files_tracking(work_dir, video_files)
if videos_moved:
return
video_dirs = sorted(work_dir.glob("videos*/"))
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
for vid_key in video_keys:
chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk, video_key=vid_key
)
(work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file
else:
for dir in video_dirs:
if (dir / video_file).is_file():
video_path = dir / video_file
break
video_path.rename(work_dir / target_path)
commit_message = "Move video files into chunk subdirectories"
subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
"""
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking.
"""
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
try:
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:")
print(e.stderr)
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
commit_message = "Track video files with git lfs"
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
shutil.copyfile(clean_gittatributes, current_gittatributes)
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
repo_url = f"https://huggingface.co/datasets/{repo_id}"
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
subprocess.run(
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
check=True,
env=env,
)
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
return [f for f in video_files if f not in lfs_tracked_files]
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
# Assumes first episode
video_files = [
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
for vid_key in video_keys
]
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
)
videos_info_dict = {}
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
return videos_info_dict
def convert_dataset(
repo_id: str,
local_dir: Path,
single_task: str | None = None,
tasks_path: Path | None = None,
tasks_col: Path | None = None,
robot_config: RobotConfig | None = None,
test_branch: str | None = None,
**card_kwargs,
):
v1 = get_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)
v20_dir.mkdir(parents=True, exist_ok=True)
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
)
branch = "main"
if test_branch:
branch = test_branch
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
features = get_features_from_hf_dataset(dataset, robot_config)
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
if single_task and "language_instruction" in dataset.column_names:
logging.warning(
"'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
)
single_task = None
tasks_col = "language_instruction"
# Episodes & chunks
episode_indices = sorted(dataset.unique("episode_index"))
total_episodes = len(episode_indices)
assert episode_indices == list(range(total_episodes))
total_videos = total_episodes * len(video_keys)
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
total_chunks += 1
# Tasks
if single_task:
tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_path:
tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_col:
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
else:
raise ValueError
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_jsonlines(tasks, v20_dir / TASKS_PATH)
features["task_index"] = {
"dtype": "int64",
"shape": (1,),
"names": None,
}
# Videos
if video_keys:
assert metadata_v1.get("video", False)
dataset = dataset.remove_columns(video_keys)
clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
)
).absolute()
with tempfile.TemporaryDirectory() as tmp_video_dir:
move_videos(
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
for key in video_keys:
features[key]["shape"] = (
videos_info[key].pop("video.height"),
videos_info[key].pop("video.width"),
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
if "encoding" in metadata_v1:
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
else:
assert metadata_v1.get("video", 0) == 0
videos_info = None
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
if robot_config is not None:
robot_type = robot_config.type
repo_tags = [robot_type]
else:
robot_type = "unknown"
repo_tags = None
# Episodes
episodes = [
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices
]
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
# Assemble metadata v2.0
metadata_v2_0 = {
"codebase_version": V20,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": len(dataset),
"total_tasks": len(tasks),
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"},
"data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
"features": features,
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="data",
folder_path=v20_dir / "data",
repo_type="dataset",
revision=branch,
)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="meta",
folder_path=v20_dir / "meta",
repo_type="dataset",
revision=branch,
)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
if not test_branch:
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
if robot_type == "aloha":
raise NotImplementedError # TODO
elif robot_type == "koch_follower":
from lerobot.robots.koch_follower import KochFollowerConfig
return KochFollowerConfig(**kwargs)
elif robot_type == "so100_follower":
from lerobot.robots.so100_follower import SO100FollowerConfig
return SO100FollowerConfig(**kwargs)
elif robot_type == "stretch":
from lerobot.robots.stretch3 import Stretch3RobotConfig
return Stretch3RobotConfig(**kwargs)
elif robot_type == "lekiwi":
from lerobot.robots.lekiwi import LeKiwiConfig
return LeKiwiConfig(**kwargs)
else:
raise ValueError(f"Robot type '{robot_type}' is not available.")
def main():
parser = argparse.ArgumentParser()
task_args = parser.add_mutually_exclusive_group(required=True)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
task_args.add_argument(
"--single-task",
type=str,
help="A short but accurate description of the single task performed in the dataset.",
)
task_args.add_argument(
"--tasks-col",
type=str,
help="The name of the column containing language instructions",
)
task_args.add_argument(
"--tasks-path",
type=Path,
help="The path to a .json file containing one language instruction for each episode_index",
)
parser.add_argument(
"--robot",
type=str,
default=None,
help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
)
parser.add_argument(
"--local-dir",
type=Path,
default=None,
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
)
parser.add_argument(
"--license",
type=str,
default="apache-2.0",
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
)
parser.add_argument(
"--test-branch",
type=str,
default=None,
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
)
args = parser.parse_args()
if not args.local_dir:
args.local_dir = Path("/tmp/lerobot_dataset_v2")
if args.robot is not None:
robot_config = make_robot_config(args.robot)
del args.robot
convert_dataset(**vars(args), robot_config=robot_config)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,87 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import traceback
from pathlib import Path
from datasets import get_dataset_config_info
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import INFO_PATH, write_info
from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
LOCAL_DIR = Path("data/")
hub_api = HfApi()
def fix_dataset(repo_id: str) -> str:
if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
return f"{repo_id}: skipped (not in {V20})."
dataset_info = get_dataset_config_info(repo_id, "default")
with SuppressWarnings():
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
parquet_features = set(dataset_info.features)
diff_parquet_meta = parquet_features - meta_features
diff_meta_parquet = meta_features - parquet_features
if diff_parquet_meta:
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
if not diff_meta_parquet:
return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet:
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root)
commit_info = hub_api.upload_file(
path_or_fileobj=lerobot_metadata.root / INFO_PATH,
path_in_repo=INFO_PATH,
repo_id=repo_id,
repo_type="dataset",
revision=V20,
commit_message="Remove 'language_instruction'",
create_pr=True,
)
return f"{repo_id}: success - PR: {commit_info.pr_url}"
def batch_fix():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "fix_features_v20.txt"
for num, repo_id in enumerate(available_datasets):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
status = fix_dataset(repo_id)
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
logging.info(status)
with open(logfile, "a") as file:
file.write(status + "\n")
if __name__ == "__main__":
batch_fix()

View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
"""
import traceback
from pathlib import Path
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
LOCAL_DIR = Path("data/")
def batch_convert():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "conversion_log_v21.txt"
hub_api = HfApi()
for num, repo_id in enumerate(available_datasets):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
status = f"{repo_id}: success (already in {V21})."
else:
convert_dataset(repo_id)
status = f"{repo_id}: success."
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
file.write(status + "\n")
if __name__ == "__main__":
batch_convert()

View File

@@ -0,0 +1,114 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
Usage:
```bash
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
--repo-id=aliberts/koch_tutorial
```
"""
import argparse
import logging
from huggingface_hub import HfApi
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
V20 = "v2.0"
V21 = "v2.1"
class SuppressWarnings:
def __enter__(self):
self.previous_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.ERROR)
def __exit__(self, exc_type, exc_val, exc_tb):
logging.getLogger().setLevel(self.previous_level)
def convert_dataset(
repo_id: str,
branch: str | None = None,
num_workers: int = 4,
):
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:
(dataset.root / STATS_PATH).unlink()
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
type=str,
default=None,
help="Repo branch to push your dataset. Defaults to the main branch.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of workers for parallelizing stats compute. Defaults to 4.",
)
args = parser.parse_args()
convert_dataset(**vars(args))

View File

@@ -0,0 +1,99 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from tqdm import tqdm
from lerobot.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import write_episode_stats
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
ep_len = dataset.meta.episodes[episode_index]["length"]
sampled_indices = sample_indices(ep_len)
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
video_frames = dataset._query_videos(query_timestamps, episode_index)
return video_frames[ft_key].numpy()
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
ep_stats = {}
for key, ft in dataset.features.items():
if ft["dtype"] == "video":
# We sample only for videos
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
else:
ep_ft_data = np.array(ep_data[key])
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
if ft["dtype"] in ["image", "video"]: # remove batch dim
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
dataset.meta.episodes_stats[ep_idx] = ep_stats
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
assert dataset.episodes is None
print("Computing episodes stats")
total_episodes = dataset.meta.total_episodes
if num_workers > 0:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
for ep_idx in range(total_episodes)
}
for future in tqdm(as_completed(futures), total=total_episodes):
future.result()
else:
for ep_idx in tqdm(range(total_episodes)):
convert_episode_stats(dataset, ep_idx)
for ep_idx in tqdm(range(total_episodes)):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def check_aggregate_stats(
dataset: LeRobotDataset,
reference_stats: dict[str, dict[str, np.ndarray]],
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
):
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
for key, ft in dataset.features.items():
# These values might need some fine-tuning
if ft["dtype"] == "video":
# to account for image sub-sampling
rtol, atol = video_rtol_atol
else:
rtol, atol = default_rtol_atol
for stat, val in agg_stats[key].items():
if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose(
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
)

View File

@@ -0,0 +1,453 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import importlib
import logging
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar
import av
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"):
return "torchcodec"
else:
logging.warning(
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
)
return "pyav"
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str | None = None,
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
Args:
video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
Returns:
torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav.
"""
if backend is None:
backend = get_safe_default_codec()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "pyav",
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video
The backend can be either "pyav" (default) or "video_reader".
"video_reader" requires installing torchvision from source, see:
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
(note that you need to compile against ffmpeg<4.3)
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
For more info on video decoding, see `benchmark/video/README.md`
See torchvision doc for more info on these two backends:
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
"""
video_path = str(video_path)
# set backend
keyframes_only = False
torchvision.set_video_backend(backend)
if backend == "pyav":
keyframes_only = True # pyav doesn't support accurate seek
# set a video stream reader
# TODO(rcadene): also load audio stream at the same time
reader = torchvision.io.VideoReader(video_path, "video")
# set the first and last requested timestamps
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
first_ts = min(timestamps)
last_ts = max(timestamps)
# access closest key frame of the first requested frame
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
reader.seek(first_ts, keyframes_only=keyframes_only)
# load all frames until last requested frame
loaded_frames = []
loaded_ts = []
for frame in reader:
current_ts = frame["pts"]
if log_loaded_timestamps:
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
loaded_frames.append(frame["data"])
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
if backend == "pyav":
reader.container.close()
reader = None
query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts)
# compute distances between each query timestamp and timestamps of all loaded frames
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
assert len(timestamps) == len(closest_frames)
return closest_frames
def decode_video_frames_torchcodec(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
device: str = "cpu",
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec.
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
"""
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
raise ImportError("torchcodec is required but not available.")
# initialize video decoder
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
loaded_frames = []
loaded_ts = []
# get metadata for frame information
metadata = decoder.metadata
average_fps = metadata.average_fps
# convert timestamps to frame indices
frame_indices = [round(ts * average_fps) for ts in timestamps]
# retrieve frames based on indices
frames_batch = decoder.get_frames_at(indices=frame_indices)
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
loaded_frames.append(frame)
loaded_ts.append(pts.item())
if log_loaded_timestamps:
logging.info(f"Frame loaded at timestamp={pts:.4f}")
query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts)
# compute distances between each query timestamp and loaded timestamps
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
# convert to float32 in [0,1] range (channel first)
closest_frames = closest_frames.type(torch.float32) / 255
assert len(timestamps) == len(closest_frames)
return closest_frames
def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
fast_decode: int = 0,
log_level: int | None = av.logging.ERROR,
overwrite: bool = False,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
# Check encoder availability
if vcodec not in ["h264", "hevc", "libsvtav1"]:
raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
video_path.parent.mkdir(parents=True, exist_ok=overwrite)
# Encoders/pixel formats incompatibility check
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
logging.warning(
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
)
pix_fmt = "yuv420p"
# Get input frames
template = "frame_" + ("[0-9]" * 6) + ".png"
input_list = sorted(
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("_")[-1].split(".")[0])
)
# Define video output frame size (assuming all input frames are the same size)
if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.")
dummy_image = Image.open(input_list[0])
width, height = dummy_image.size
# Define video codec options
video_options = {}
if g is not None:
video_options["g"] = str(g)
if crf is not None:
video_options["crf"] = str(crf)
if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
video_options[key] = value
# Set logging level
if log_level is not None:
# "While less efficient, it is generally preferable to modify logging with Pythons logging"
logging.getLogger("libav").setLevel(log_level)
# Create and open output file (overwrite by default)
with av.open(str(video_path), "w") as output:
output_stream = output.add_stream(vcodec, fps, options=video_options)
output_stream.pix_fmt = pix_fmt
output_stream.width = width
output_stream.height = height
# Loop through input frames and encode them
for input_data in input_list:
input_image = Image.open(input_data).convert("RGB")
input_frame = av.VideoFrame.from_image(input_image)
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)
# Flush the encoder
packet = output_stream.encode()
if packet:
output.mux(packet)
# Reset logging level
if log_level is not None:
av.logging.restore_default_callback()
if not video_path.exists():
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
@dataclass
class VideoFrame:
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
"""
Provides a type for a dataset containing video frames.
Example:
```python
data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
features = {"image": VideoFrame()}
Dataset.from_dict(data_dict, features=Features(features))
```
"""
pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
_type: str = field(default="VideoFrame", init=False, repr=False)
def __call__(self):
return self.pa_type
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
"'register_feature' is experimental and might be subject to breaking changes in the future.",
category=UserWarning,
)
# to make VideoFrame available in HuggingFace `datasets`
register_feature(VideoFrame, "VideoFrame")
def get_audio_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.ERROR)
# Getting audio stream information
audio_info = {}
with av.open(str(video_path), "r") as audio_file:
try:
audio_stream = audio_file.streams.audio[0]
except IndexError:
# Reset logging level
av.logging.restore_default_callback()
return {"has_audio": False}
audio_info["audio.channels"] = audio_stream.channels
audio_info["audio.codec"] = audio_stream.codec.canonical_name
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
audio_info["audio.bit_rate"] = audio_stream.bit_rate
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
# In an ideal loseless case : fixed number of bits per sample.
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
audio_info["audio.bit_depth"] = audio_stream.format.bits
audio_info["audio.channel_layout"] = audio_stream.layout.name
audio_info["has_audio"] = True
# Reset logging level
av.logging.restore_default_callback()
return audio_info
def get_video_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.ERROR)
# Getting video stream information
video_info = {}
with av.open(str(video_path), "r") as video_file:
try:
video_stream = video_file.streams.video[0]
except IndexError:
# Reset logging level
av.logging.restore_default_callback()
return {}
video_info["video.height"] = video_stream.height
video_info["video.width"] = video_stream.width
video_info["video.codec"] = video_stream.codec.canonical_name
video_info["video.pix_fmt"] = video_stream.pix_fmt
video_info["video.is_depth_map"] = False
# Calculate fps from r_frame_rate
video_info["video.fps"] = int(video_stream.base_rate)
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
video_info["video.channels"] = pixel_channels
# Reset logging level
av.logging.restore_default_callback()
# Adding audio stream information
video_info.update(**get_audio_info(video_path))
return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
return 4
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
return 3
else:
raise ValueError("Unknown format")
def get_image_pixel_channels(image: Image):
if image.mode == "L":
return 1 # Grayscale
elif image.mode == "LA":
return 2 # Grayscale + Alpha
elif image.mode == "RGB":
return 3 # RGB
elif image.mode == "RGBA":
return 4 # RGBA
else:
raise ValueError("Unknown format")

View File

@@ -0,0 +1,15 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401

273
src/lerobot/envs/configs.py Normal file
View File

@@ -0,0 +1,273 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import dataclass, field
from typing import Any, Optional
import draccus
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
@dataclass
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
task: str | None = None
fps: int = 30
features: dict[str, PolicyFeature] = field(default_factory=dict)
features_map: dict[str, str] = field(default_factory=dict)
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@property
@abc.abstractmethod
def gym_kwargs(self) -> dict:
raise NotImplementedError()
@EnvConfig.register_subclass("aloha")
@dataclass
class AlohaEnv(EnvConfig):
task: str = "AlohaInsertion-v0"
fps: int = 50
episode_length: int = 400
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_STATE,
"top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"max_episode_steps": self.episode_length,
}
@EnvConfig.register_subclass("pusht")
@dataclass
class PushtEnv(EnvConfig):
task: str = "PushT-v0"
fps: int = 10
episode_length: int = 300
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
visualization_width: int = 384
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_STATE,
"environment_state": OBS_ENV_STATE,
"pixels": OBS_IMAGE,
}
)
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
elif self.obs_type == "environment_state_agent_pos":
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"visualization_width": self.visualization_width,
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}
@EnvConfig.register_subclass("xarm")
@dataclass
class XarmEnv(EnvConfig):
task: str = "XarmLift-v0"
fps: int = 15
episode_length: int = 200
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
visualization_width: int = 384
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_STATE,
"pixels": OBS_IMAGE,
}
)
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"visualization_width": self.visualization_width,
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}
@dataclass
class VideoRecordConfig:
"""Configuration for video recording in ManiSkill environments."""
enabled: bool = False
record_dir: str = "videos"
trajectory_name: str = "trajectory"
@dataclass
class EnvTransformConfig:
"""Configuration for environment wrappers."""
# ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
control_mode: str = "gamepad"
display_cameras: bool = False
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None
resize_size: Optional[tuple[int, int]] = None
control_time_s: float = 20.0
fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0
use_gripper: bool = True
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):
"""Configuration for the HILSerlRobotEnv environment."""
robot: Optional[RobotConfig] = None
teleop: Optional[TeleoperatorConfig] = None
wrapper: Optional[EnvTransformConfig] = None
fps: int = 10
name: str = "real_robot"
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
task: str = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
reward_classifier_pretrained_path: Optional[str] = None
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
def gym_kwargs(self) -> dict:
return {}
@EnvConfig.register_subclass("hil")
@dataclass
class HILEnvConfig(EnvConfig):
"""Configuration for the HIL environment."""
type: str = "hil"
name: str = "PandaPickCube"
task: str = "PandaPickCubeKeyboard-v0"
use_viewer: bool = True
gripper_penalty: float = 0.0
use_gamepad: bool = True
state_dim: int = 18
action_dim: int = 4
fps: int = 100
episode_length: int = 100
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"observation.image": OBS_IMAGE,
"observation.state": OBS_STATE,
}
)
################# args from hilserlrobotenv
reward_classifier_pretrained_path: Optional[str] = None
robot_config: Optional[RobotConfig] = None
teleop_config: Optional[TeleoperatorConfig] = None
wrapper: Optional[EnvTransformConfig] = None
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
############################
@property
def gym_kwargs(self) -> dict:
return {
"use_viewer": self.use_viewer,
"use_gamepad": self.use_gamepad,
"gripper_penalty": self.gripper_penalty,
}

View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import gymnasium as gym
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
if env_type == "aloha":
return AlohaEnv(**kwargs)
elif env_type == "pusht":
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
elif env_type == "hil":
return HILEnvConfig(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
"""Makes a gym vector environment according to the config.
Args:
cfg (EnvConfig): the config of the environment to instantiate.
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False.
Raises:
ValueError: if n_envs < 1
ModuleNotFoundError: If the requested env package is not installed
Returns:
gym.vector.VectorEnv: The parallelized gym.env instance.
"""
if n_envs < 1:
raise ValueError("`n_envs must be at least 1")
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
gym_handle = f"{package_name}/{cfg.task}"
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
return env

136
src/lerobot/envs/utils.py Normal file
View File

@@ -0,0 +1,136 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any
import einops
import gymnasium as gym
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.utils.utils import get_channel_first_image_shape
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
# map to expected inputs for the policy
return_observations = {}
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
else:
imgs = {"observation.image": observations["pixels"]}
for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img)
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
# This is the case for human-in-the-loop RL where there is only one environment.
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
return_observations[imgkey] = img
if "environment_state" in observations:
env_state = torch.from_numpy(observations["environment_state"]).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
return_observations["observation.environment_state"] = env_state
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations["observation.state"] = agent_pos
return return_observations
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies)
policy_features = {}
for key, ft in env_cfg.features.items():
if ft.type is FeatureType.VISUAL:
if len(ft.shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
shape = get_channel_first_image_shape(ft.shape)
feature = PolicyFeature(type=ft.type, shape=shape)
else:
feature = ft
policy_key = env_cfg.features_map[key]
policy_features[policy_key] = feature
return policy_features
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
first_type = type(env.envs[0]) # Get type of first env
return all(type(e) is first_type for e in env.envs) # Fast type check
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
with warnings.catch_warnings():
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
warnings.warn(
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
UserWarning,
stacklevel=2,
)
if not are_all_envs_same_type(env):
warnings.warn(
"The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
UserWarning,
stacklevel=2,
)
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
"""Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"):
observation["task"] = env.call("task_description")
elif hasattr(env.envs[0], "task"):
observation["task"] = env.call("task")
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
num_envs = observation[list(observation.keys())[0]].shape[0]
observation["task"] = ["" for _ in range(num_envs)]
return observation

43
src/lerobot/errors.py Normal file
View File

@@ -0,0 +1,43 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class DeviceNotConnectedError(ConnectionError):
"""Exception raised when the device is not connected."""
def __init__(self, message="This device is not connected. Try calling `connect()` first."):
self.message = message
super().__init__(self.message)
class DeviceAlreadyConnectedError(ConnectionError):
"""Exception raised when the device is already connected."""
def __init__(
self,
message="This device is already connected. Try not calling `connect()` twice.",
):
self.message = message
super().__init__(self.message)
class InvalidActionError(ValueError):
"""Exception raised when an action is already invalid."""
def __init__(
self,
message="The action is invalid. Check the value follows what it is expected from the action space.",
):
self.message = message
super().__init__(self.message)

315
src/lerobot/find_cameras.py Normal file
View File

@@ -0,0 +1,315 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper to find the camera devices available in your system.
Example:
```shell
python -m lerobot.find_cameras
```
"""
# NOTE(Steven): RealSense can also be identified/opened as OpenCV cameras. If you know the camera is a RealSense, use the `lerobot.find_cameras realsense` flag to avoid confusion.
# NOTE(Steven): macOS cameras sometimes report different FPS at init time, not an issue here as we don't specify FPS when opening the cameras, but the information displayed might not be truthful.
import argparse
import concurrent.futures
import logging
import time
from pathlib import Path
from typing import Any, Dict, List
import numpy as np
from PIL import Image
from lerobot.cameras.configs import ColorMode
from lerobot.cameras.opencv.camera_opencv import OpenCVCamera
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.cameras.realsense.camera_realsense import RealSenseCamera
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig
logger = logging.getLogger(__name__)
def find_all_opencv_cameras() -> List[Dict[str, Any]]:
"""
Finds all available OpenCV cameras plugged into the system.
Returns:
A list of all available OpenCV cameras with their metadata.
"""
all_opencv_cameras_info: List[Dict[str, Any]] = []
logger.info("Searching for OpenCV cameras...")
try:
opencv_cameras = OpenCVCamera.find_cameras()
for cam_info in opencv_cameras:
all_opencv_cameras_info.append(cam_info)
logger.info(f"Found {len(opencv_cameras)} OpenCV cameras.")
except Exception as e:
logger.error(f"Error finding OpenCV cameras: {e}")
return all_opencv_cameras_info
def find_all_realsense_cameras() -> List[Dict[str, Any]]:
"""
Finds all available RealSense cameras plugged into the system.
Returns:
A list of all available RealSense cameras with their metadata.
"""
all_realsense_cameras_info: List[Dict[str, Any]] = []
logger.info("Searching for RealSense cameras...")
try:
realsense_cameras = RealSenseCamera.find_cameras()
for cam_info in realsense_cameras:
all_realsense_cameras_info.append(cam_info)
logger.info(f"Found {len(realsense_cameras)} RealSense cameras.")
except ImportError:
logger.warning("Skipping RealSense camera search: pyrealsense2 library not found or not importable.")
except Exception as e:
logger.error(f"Error finding RealSense cameras: {e}")
return all_realsense_cameras_info
def find_and_print_cameras(camera_type_filter: str | None = None) -> List[Dict[str, Any]]:
"""
Finds available cameras based on an optional filter and prints their information.
Args:
camera_type_filter: Optional string to filter cameras ("realsense" or "opencv").
If None, lists all cameras.
Returns:
A list of all available cameras matching the filter, with their metadata.
"""
all_cameras_info: List[Dict[str, Any]] = []
if camera_type_filter:
camera_type_filter = camera_type_filter.lower()
if camera_type_filter is None or camera_type_filter == "opencv":
all_cameras_info.extend(find_all_opencv_cameras())
if camera_type_filter is None or camera_type_filter == "realsense":
all_cameras_info.extend(find_all_realsense_cameras())
if not all_cameras_info:
if camera_type_filter:
logger.warning(f"No {camera_type_filter} cameras were detected.")
else:
logger.warning("No cameras (OpenCV or RealSense) were detected.")
else:
print("\n--- Detected Cameras ---")
for i, cam_info in enumerate(all_cameras_info):
print(f"Camera #{i}:")
for key, value in cam_info.items():
if key == "default_stream_profile" and isinstance(value, dict):
print(f" {key.replace('_', ' ').capitalize()}:")
for sub_key, sub_value in value.items():
print(f" {sub_key.capitalize()}: {sub_value}")
else:
print(f" {key.replace('_', ' ').capitalize()}: {value}")
print("-" * 20)
return all_cameras_info
def save_image(
img_array: np.ndarray,
camera_identifier: str | int,
images_dir: Path,
camera_type: str,
):
"""
Saves a single image to disk using Pillow. Handles color conversion if necessary.
"""
try:
img = Image.fromarray(img_array, mode="RGB")
safe_identifier = str(camera_identifier).replace("/", "_").replace("\\", "_")
filename_prefix = f"{camera_type.lower()}_{safe_identifier}"
filename = f"{filename_prefix}.png"
path = images_dir / filename
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path))
logger.info(f"Saved image: {path}")
except Exception as e:
logger.error(f"Failed to save image for camera {camera_identifier} (type {camera_type}): {e}")
def create_camera_instance(cam_meta: Dict[str, Any]) -> Dict[str, Any] | None:
"""Create and connect to a camera instance based on metadata."""
cam_type = cam_meta.get("type")
cam_id = cam_meta.get("id")
instance = None
logger.info(f"Preparing {cam_type} ID {cam_id} with default profile")
try:
if cam_type == "OpenCV":
cv_config = OpenCVCameraConfig(
index_or_path=cam_id,
color_mode=ColorMode.RGB,
)
instance = OpenCVCamera(cv_config)
elif cam_type == "RealSense":
rs_config = RealSenseCameraConfig(
serial_number_or_name=cam_id,
color_mode=ColorMode.RGB,
)
instance = RealSenseCamera(rs_config)
else:
logger.warning(f"Unknown camera type: {cam_type} for ID {cam_id}. Skipping.")
return None
if instance:
logger.info(f"Connecting to {cam_type} camera: {cam_id}...")
instance.connect(warmup=False)
return {"instance": instance, "meta": cam_meta}
except Exception as e:
logger.error(f"Failed to connect or configure {cam_type} camera {cam_id}: {e}")
if instance and instance.is_connected:
instance.disconnect()
return None
def process_camera_image(
cam_dict: Dict[str, Any], output_dir: Path, current_time: float
) -> concurrent.futures.Future | None:
"""Capture and process an image from a single camera."""
cam = cam_dict["instance"]
meta = cam_dict["meta"]
cam_type_str = str(meta.get("type", "unknown"))
cam_id_str = str(meta.get("id", "unknown"))
try:
image_data = cam.read()
return save_image(
image_data,
cam_id_str,
output_dir,
cam_type_str,
)
except TimeoutError:
logger.warning(
f"Timeout reading from {cam_type_str} camera {cam_id_str} at time {current_time:.2f}s."
)
except Exception as e:
logger.error(f"Error reading from {cam_type_str} camera {cam_id_str}: {e}")
return None
def cleanup_cameras(cameras_to_use: List[Dict[str, Any]]):
"""Disconnect all cameras."""
logger.info(f"Disconnecting {len(cameras_to_use)} cameras...")
for cam_dict in cameras_to_use:
try:
if cam_dict["instance"] and cam_dict["instance"].is_connected:
cam_dict["instance"].disconnect()
except Exception as e:
logger.error(f"Error disconnecting camera {cam_dict['meta'].get('id')}: {e}")
def save_images_from_all_cameras(
output_dir: Path,
record_time_s: float = 2.0,
camera_type: str | None = None,
):
"""
Connects to detected cameras (optionally filtered by type) and saves images from each.
Uses default stream profiles for width, height, and FPS.
Args:
output_dir: Directory to save images.
record_time_s: Duration in seconds to record images.
camera_type: Optional string to filter cameras ("realsense" or "opencv").
If None, uses all detected cameras.
"""
output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Saving images to {output_dir}")
all_camera_metadata = find_and_print_cameras(camera_type_filter=camera_type)
if not all_camera_metadata:
logger.warning("No cameras detected matching the criteria. Cannot save images.")
return
cameras_to_use = []
for cam_meta in all_camera_metadata:
camera_instance = create_camera_instance(cam_meta)
if camera_instance:
cameras_to_use.append(camera_instance)
if not cameras_to_use:
logger.warning("No cameras could be connected. Aborting image save.")
return
logger.info(f"Starting image capture for {record_time_s} seconds from {len(cameras_to_use)} cameras.")
start_time = time.perf_counter()
with concurrent.futures.ThreadPoolExecutor(max_workers=len(cameras_to_use) * 2) as executor:
try:
while time.perf_counter() - start_time < record_time_s:
futures = []
current_capture_time = time.perf_counter()
for cam_dict in cameras_to_use:
future = process_camera_image(cam_dict, output_dir, current_capture_time)
if future:
futures.append(future)
if futures:
concurrent.futures.wait(futures)
except KeyboardInterrupt:
logger.info("Capture interrupted by user.")
finally:
print("\nFinalizing image saving...")
executor.shutdown(wait=True)
cleanup_cameras(cameras_to_use)
print(f"Image capture finished. Images saved to {output_dir}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Unified camera utility script for listing cameras and capturing images."
)
parser.add_argument(
"camera_type",
type=str,
nargs="?",
default=None,
choices=["realsense", "opencv"],
help="Specify camera type to capture from (e.g., 'realsense', 'opencv'). Captures from all if omitted.",
)
parser.add_argument(
"--output-dir",
type=Path,
default="outputs/captured_images",
help="Directory to save images. Default: outputs/captured_images",
)
parser.add_argument(
"--record-time-s",
type=float,
default=6.0,
help="Time duration to attempt capturing frames. Default: 6 seconds.",
)
args = parser.parse_args()
save_images_from_all_cameras(**vars(args))

65
src/lerobot/find_port.py Normal file
View File

@@ -0,0 +1,65 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper to find the USB port associated with your MotorsBus.
Example:
```shell
python -m lerobot.find_port
```
"""
import platform
import time
from pathlib import Path
def find_available_ports():
from serial.tools import list_ports # Part of pyserial library
if platform.system() == "Windows":
# List COM ports using pyserial
ports = [port.device for port in list_ports.comports()]
else: # Linux/macOS
# List /dev/tty* ports for Unix-based systems
ports = [str(path) for path in Path("/dev").glob("tty*")]
return ports
def find_port():
print("Finding all available ports for the MotorsBus.")
ports_before = find_available_ports()
print("Ports before disconnecting:", ports_before)
print("Remove the USB cable from your MotorsBus and press Enter when done.")
input() # Wait for user to disconnect the device
time.sleep(0.5) # Allow some time for port to be released
ports_after = find_available_ports()
ports_diff = list(set(ports_before) - set(ports_after))
if len(ports_diff) == 1:
port = ports_diff[0]
print(f"The port of this MotorsBus is '{port}'")
print("Reconnect the USB cable.")
elif len(ports_diff) == 0:
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
else:
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
if __name__ == "__main__":
find_port()

View File

@@ -0,0 +1,483 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from numpy.typing import NDArray
from scipy.spatial.transform import Rotation
def skew_symmetric(w: NDArray[np.float32]) -> NDArray[np.float32]:
"""Creates the skew-symmetric matrix from a 3D vector."""
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
def rodrigues_rotation(w: NDArray[np.float32], theta: float) -> NDArray[np.float32]:
"""Computes the rotation matrix using Rodrigues' formula."""
w_hat = skew_symmetric(w)
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
def screw_axis_to_transform(s: NDArray[np.float32], theta: float) -> NDArray[np.float32]:
"""Converts a screw axis to a 4x4 transformation matrix."""
screw_axis_rot = s[:3]
screw_axis_trans = s[3:]
# Pure translation
if np.allclose(screw_axis_rot, 0) and np.linalg.norm(screw_axis_trans) == 1:
transform = np.eye(4)
transform[:3, 3] = screw_axis_trans * theta
# Rotation (and potentially translation)
elif np.linalg.norm(screw_axis_rot) == 1:
w_hat = skew_symmetric(screw_axis_rot)
rot_mat = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
t = (
np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat
) @ screw_axis_trans
transform = np.eye(4)
transform[:3, :3] = rot_mat
transform[:3, 3] = t
else:
raise ValueError("Invalid screw axis parameters")
return transform
def pose_difference_se3(pose1: NDArray[np.float32], pose2: NDArray[np.float32]) -> NDArray[np.float32]:
"""
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space,
combining rotation (SO(3)) and translation.
Each 4x4 matrix has the following structure:
[R11 R12 R13 tx]
[R21 R22 R23 ty]
[R31 R32 R33 tz]
[ 0 0 0 1]
where R is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector.
Args:
pose1: A 4x4 numpy array representing the first pose.
pose2: A 4x4 numpy array representing the second pose.
Returns:
A 6D numpy array concatenating translation and rotation differences.
First 3 elements are the translational difference (position).
Last 3 elements are the rotational difference in axis-angle representation.
"""
rot1 = pose1[:3, :3]
rot2 = pose2[:3, :3]
translation_diff = pose1[:3, 3] - pose2[:3, 3]
# Calculate rotational difference using scipy's Rotation library
rot_diff = Rotation.from_matrix(rot1 @ rot2.T)
rotation_diff = rot_diff.as_rotvec() # Axis-angle representation
return np.concatenate([translation_diff, rotation_diff])
def se3_error(target_pose: NDArray[np.float32], current_pose: NDArray[np.float32]) -> NDArray[np.float32]:
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
rot_target = target_pose[:3, :3]
rot_current = current_pose[:3, :3]
rot_error_mat = rot_target @ rot_current.T
rot_error = Rotation.from_matrix(rot_error_mat).as_rotvec()
return np.concatenate([pos_error, rot_error])
class RobotKinematics:
"""Robot kinematics class supporting multiple robot models."""
# Robot measurements dictionary
ROBOT_MEASUREMENTS = {
"koch": {
"gripper": [0.239, -0.001, 0.024],
"wrist": [0.209, 0, 0.024],
"forearm": [0.108, 0, 0.02],
"humerus": [0, 0, 0.036],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"moss": {
"gripper": [0.246, 0.013, 0.111],
"wrist": [0.245, 0.002, 0.064],
"forearm": [0.122, 0, 0.064],
"humerus": [0.001, 0.001, 0.063],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"so_old_calibration": {
"gripper": [0.320, 0, 0.050],
"wrist": [0.278, 0, 0.050],
"forearm": [0.143, 0, 0.044],
"humerus": [0.031, 0, 0.072],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"so_new_calibration": {
"gripper": [0.33, 0.0, 0.285],
"wrist": [0.30, 0.0, 0.267],
"forearm": [0.25, 0.0, 0.266],
"humerus": [0.06, 0.0, 0.264],
"shoulder": [0.0, 0.0, 0.238],
"base": [0.0, 0.0, 0.12],
},
}
def __init__(self, robot_type: str = "so100"):
"""Initialize kinematics for the specified robot type.
Args:
robot_type: String specifying the robot model ("koch", "so100", or "moss")
"""
if robot_type not in self.ROBOT_MEASUREMENTS:
raise ValueError(
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
)
self.robot_type = robot_type
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
# Initialize all transformation matrices and screw axes
self._setup_transforms()
def _create_translation_matrix(
self, x: float = 0.0, y: float = 0.0, z: float = 0.0
) -> NDArray[np.float32]:
"""Create a 4x4 translation matrix."""
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
def _setup_transforms(self):
"""Setup all transformation matrices and screw axes for the robot."""
# Set up rotation matrices (constant across robot types)
# Gripper orientation
self.gripper_X0 = np.array(
[
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, -1, 0, 0],
[0, 0, 0, 1],
],
dtype=np.float32,
)
# Wrist orientation
self.wrist_X0 = np.array(
[
[0, -1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
],
dtype=np.float32,
)
# Base orientation
self.base_X0 = np.array(
[
[0, 0, 1, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
],
dtype=np.float32,
)
# Gripper
# Screw axis of gripper frame wrt base frame
self.S_BG = np.array(
[
1,
0,
0,
0,
self.measurements["gripper"][2],
-self.measurements["gripper"][1],
],
dtype=np.float32,
)
# Gripper origin to centroid transform
self.X_GoGc = self._create_translation_matrix(x=0.07)
# Gripper origin to tip transform
self.X_GoGt = self._create_translation_matrix(x=0.12)
# 0-position gripper frame pose wrt base
self.X_BoGo = self._create_translation_matrix(
x=self.measurements["gripper"][0],
y=self.measurements["gripper"][1],
z=self.measurements["gripper"][2],
)
# Wrist
# Screw axis of wrist frame wrt base frame
self.S_BR = np.array(
[0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]], dtype=np.float32
)
# 0-position origin to centroid transform
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
# 0-position wrist frame pose wrt base
self.X_BR = self._create_translation_matrix(
x=self.measurements["wrist"][0],
y=self.measurements["wrist"][1],
z=self.measurements["wrist"][2],
)
# Forearm
# Screw axis of forearm frame wrt base frame
self.S_BF = np.array(
[
0,
1,
0,
-self.measurements["forearm"][2],
0,
self.measurements["forearm"][0],
],
dtype=np.float32,
)
# Forearm origin + centroid transform
self.X_ForearmFc = self._create_translation_matrix(x=0.036)
# 0-position forearm frame pose wrt base
self.X_BF = self._create_translation_matrix(
x=self.measurements["forearm"][0],
y=self.measurements["forearm"][1],
z=self.measurements["forearm"][2],
)
# Humerus
# Screw axis of humerus frame wrt base frame
self.S_BH = np.array(
[
0,
-1,
0,
self.measurements["humerus"][2],
0,
-self.measurements["humerus"][0],
],
dtype=np.float32,
)
# Humerus origin to centroid transform
self.X_HoHc = self._create_translation_matrix(x=0.0475)
# 0-position humerus frame pose wrt base
self.X_BH = self._create_translation_matrix(
x=self.measurements["humerus"][0],
y=self.measurements["humerus"][1],
z=self.measurements["humerus"][2],
)
# Shoulder
# Screw axis of shoulder frame wrt Base frame
self.S_BS = np.array([0, 0, -1, 0, 0, 0], dtype=np.float32)
# Shoulder origin to centroid transform
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
# 0-position shoulder frame pose wrt base
self.X_BS = self._create_translation_matrix(
x=self.measurements["shoulder"][0],
y=self.measurements["shoulder"][1],
z=self.measurements["shoulder"][2],
)
# Base
# Base origin to centroid transform
self.X_BoBc = self._create_translation_matrix(y=0.015)
# World to base transform
self.X_WoBo = self._create_translation_matrix(
x=self.measurements["base"][0],
y=self.measurements["base"][1],
z=self.measurements["base"][2],
)
# Pre-compute gripper post-multiplication matrix
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
def forward_kinematics(
self,
robot_pos_deg: NDArray[np.float32],
frame: str = "gripper_tip",
) -> NDArray[np.float32]:
"""Generic forward kinematics.
Args:
robot_pos_deg: Joint positions in degrees. Can be ``None`` when
computing the *base* frame as it does not depend on joint
angles.
frame: Target frame. One of
``{"base", "shoulder", "humerus", "forearm", "wrist", "gripper", "gripper_tip"}``.
Returns
-------
NDArray[np.float32]
4×4 homogeneous transformation matrix of the requested frame
expressed in the world coordinate system.
"""
frame = frame.lower()
if frame not in {
"base",
"shoulder",
"humerus",
"forearm",
"wrist",
"gripper",
"gripper_tip",
}:
raise ValueError(
f"Unknown frame '{frame}'. Valid options are base, shoulder, humerus, forearm, wrist, gripper, gripper_tip."
)
# Base frame does not rely on joint angles.
if frame == "base":
return self.X_WoBo @ self.X_BoBc @ self.base_X0
robot_pos_rad = robot_pos_deg / 180 * np.pi
# Extract joint angles (note the sign convention for shoulder lift).
theta_shoulder_pan = robot_pos_rad[0]
theta_shoulder_lift = -robot_pos_rad[1]
theta_elbow_flex = robot_pos_rad[2]
theta_wrist_flex = robot_pos_rad[3]
theta_wrist_roll = robot_pos_rad[4]
# Start with the world-to-base transform; incrementally add successive links.
transformation_matrix = self.X_WoBo @ screw_axis_to_transform(self.S_BS, theta_shoulder_pan)
if frame == "shoulder":
return transformation_matrix @ self.X_SoSc @ self.X_BS
transformation_matrix = transformation_matrix @ screw_axis_to_transform(
self.S_BH, theta_shoulder_lift
)
if frame == "humerus":
return transformation_matrix @ self.X_HoHc @ self.X_BH
transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BF, theta_elbow_flex)
if frame == "forearm":
return transformation_matrix @ self.X_ForearmFc @ self.X_BF
transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BR, theta_wrist_flex)
if frame == "wrist":
return transformation_matrix @ self.X_RoRc @ self.X_BR @ self.wrist_X0
transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BG, theta_wrist_roll)
if frame == "gripper":
return transformation_matrix @ self._fk_gripper_post
else: # frame == "gripper_tip"
return transformation_matrix @ self.X_GoGt @ self.X_BoGo @ self.gripper_X0
def compute_jacobian(
self, robot_pos_deg: NDArray[np.float32], frame: str = "gripper_tip"
) -> NDArray[np.float32]:
"""Finite differences to compute the Jacobian.
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
in the jth joint's velocity.
Args:
robot_pos_deg: Current joint positions in degrees
fk_func: Forward kinematics function to use (defaults to fk_gripper)
"""
eps = 1e-8
jac = np.zeros(shape=(6, 5))
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
for el_ix in range(len(robot_pos_deg[:-1])):
delta *= 0
delta[el_ix] = eps / 2
sdot = (
pose_difference_se3(
self.forward_kinematics(robot_pos_deg[:-1] + delta, frame),
self.forward_kinematics(robot_pos_deg[:-1] - delta, frame),
)
/ eps
)
jac[:, el_ix] = sdot
return jac
def compute_positional_jacobian(
self, robot_pos_deg: NDArray[np.float32], frame: str = "gripper_tip"
) -> NDArray[np.float32]:
"""Finite differences to compute the positional Jacobian.
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
in the jth joint's velocity.
Args:
robot_pos_deg: Current joint positions in degrees
fk_func: Forward kinematics function to use (defaults to fk_gripper)
"""
eps = 1e-8
jac = np.zeros(shape=(3, 5))
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
for el_ix in range(len(robot_pos_deg[:-1])):
delta *= 0
delta[el_ix] = eps / 2
sdot = (
self.forward_kinematics(robot_pos_deg[:-1] + delta, frame)[:3, 3]
- self.forward_kinematics(robot_pos_deg[:-1] - delta, frame)[:3, 3]
) / eps
jac[:, el_ix] = sdot
return jac
def ik(
self,
current_joint_pos: NDArray[np.float32],
desired_ee_pose: NDArray[np.float32],
position_only: bool = True,
frame: str = "gripper_tip",
max_iterations: int = 5,
learning_rate: float = 1,
) -> NDArray[np.float32]:
"""Inverse kinematics using gradient descent.
Args:
current_joint_state: Initial joint positions in degrees
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
position_only: If True, only match end-effector position, not orientation
frame: Target frame. One of
``{"base", "shoulder", "humerus", "forearm", "wrist", "gripper", "gripper_tip"}``.
max_iterations: Maximum number of iterations to run
learning_rate: Learning rate for gradient descent
Returns:
Joint positions in degrees that achieve the desired end-effector pose
"""
# Do gradient descent.
current_joint_state = current_joint_pos.copy()
for _ in range(max_iterations):
current_ee_pose = self.forward_kinematics(current_joint_state, frame)
if not position_only:
error = se3_error(desired_ee_pose, current_ee_pose)
jac = self.compute_jacobian(current_joint_state, frame)
else:
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
jac = self.compute_positional_jacobian(current_joint_state, frame)
delta_angles = np.linalg.pinv(jac) @ error
current_joint_state[:-1] += learning_rate * delta_angles
if np.linalg.norm(error) < 5e-3:
return current_joint_state
return current_joint_state

View File

@@ -0,0 +1 @@
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus

View File

@@ -0,0 +1,2 @@
from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode
from .tables import *

View File

@@ -0,0 +1,263 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(aliberts): Should we implement FastSyncRead/Write?
# https://github.com/ROBOTIS-GIT/DynamixelSDK/pull/643
# https://github.com/ROBOTIS-GIT/DynamixelSDK/releases/tag/3.8.2
# https://emanual.robotis.com/docs/en/dxl/protocol2/#fast-sync-read-0x8a
# -> Need to check compatibility across models
import logging
from copy import deepcopy
from enum import Enum
from lerobot.utils.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
AVAILABLE_BAUDRATES,
MODEL_BAUDRATE_TABLE,
MODEL_CONTROL_TABLE,
MODEL_ENCODING_TABLE,
MODEL_NUMBER_TABLE,
MODEL_RESOLUTION,
)
PROTOCOL_VERSION = 2.0
DEFAULT_BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
logger = logging.getLogger(__name__)
class OperatingMode(Enum):
# DYNAMIXEL only controls current(torque) regardless of speed and position. This mode is ideal for a
# gripper or a system that only uses current(torque) control or a system that has additional
# velocity/position controllers.
CURRENT = 0
# This mode controls velocity. This mode is identical to the Wheel Mode(endless) from existing DYNAMIXEL.
# This mode is ideal for wheel-type robots.
VELOCITY = 1
# This mode controls position. This mode is identical to the Joint Mode from existing DYNAMIXEL. Operating
# position range is limited by the Max Position Limit(48) and the Min Position Limit(52). This mode is
# ideal for articulated robots that each joint rotates less than 360 degrees.
POSITION = 3
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
# conveyer systems or a system that requires an additional reduction gear. Note that Max Position
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
EXTENDED_POSITION = 4
# This mode controls both position and current(torque). Up to 512 turns are supported (-256[rev] ~
# 256[rev]). This mode is ideal for a system that requires both position and current control such as
# articulated robots or grippers.
CURRENT_POSITION = 5
# This mode directly controls PWM output. (Voltage Control Mode)
PWM = 16
class DriveMode(Enum):
NON_INVERTED = 0
INVERTED = 1
class TorqueMode(Enum):
ENABLED = 1
DISABLED = 0
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
import dynamixel_sdk as dxl
if length == 1:
data = [value]
elif length == 2:
data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)]
elif length == 4:
data = [
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
]
return data
class DynamixelMotorsBus(MotorsBus):
"""
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
the motors. For more info, see the Dynamixel SDK Documentation:
https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20
"""
apply_drive_mode = False
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
default_baudrate = DEFAULT_BAUDRATE
default_timeout = DEFAULT_TIMEOUT_MS
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
model_encoding_table = deepcopy(MODEL_ENCODING_TABLE)
model_number_table = deepcopy(MODEL_NUMBER_TABLE)
model_resolution_table = deepcopy(MODEL_RESOLUTION)
normalized_data = deepcopy(NORMALIZED_DATA)
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
super().__init__(port, motors, calibration)
import dynamixel_sdk as dxl
self.port_handler = dxl.PortHandler(self.port)
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
self.sync_writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
self._comm_success = dxl.COMM_SUCCESS
self._no_error = 0x00
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
pass
def _handshake(self) -> None:
self._assert_motors_exist()
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
model = self.motors[motor].model
search_baudrates = (
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
)
for baudrate in search_baudrates:
self.set_baudrate(baudrate)
id_model = self.broadcast_ping()
if id_model:
found_id, found_model = next(iter(id_model.items()))
expected_model_nb = self.model_number_table[model]
if found_model != expected_model_nb:
raise RuntimeError(
f"Found one motor on {baudrate=} with id={found_id} but it has a "
f"model number '{found_model}' different than the one expected: '{expected_model_nb}'. "
f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')."
)
return baudrate, found_id
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
def configure_motors(self) -> None:
# By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
for motor in self.motors:
self.write("Return_Delay_Time", motor, 0)
@property
def is_calibrated(self) -> bool:
return self.calibration == self.read_calibration()
def read_calibration(self) -> dict[str, MotorCalibration]:
offsets = self.sync_read("Homing_Offset", normalize=False)
mins = self.sync_read("Min_Position_Limit", normalize=False)
maxes = self.sync_read("Max_Position_Limit", normalize=False)
drive_modes = self.sync_read("Drive_Mode", normalize=False)
calibration = {}
for motor, m in self.motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=drive_modes[motor],
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
)
return calibration
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
for motor, calibration in calibration_dict.items():
self.write("Homing_Offset", motor, calibration.homing_offset)
self.write("Min_Position_Limit", motor, calibration.range_min)
self.write("Max_Position_Limit", motor, calibration.range_max)
self.calibration = calibration_dict
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
model = self._id_to_model(id_)
encoding_table = self.model_encoding_table.get(model)
if encoding_table and data_name in encoding_table:
n_bytes = encoding_table[data_name]
ids_values[id_] = encode_twos_complement(ids_values[id_], n_bytes)
return ids_values
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
model = self._id_to_model(id_)
encoding_table = self.model_encoding_table.get(model)
if encoding_table and data_name in encoding_table:
n_bytes = encoding_table[data_name]
ids_values[id_] = decode_twos_complement(ids_values[id_], n_bytes)
return ids_values
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
"""
On Dynamixel Motors:
Present_Position = Actual_Position + Homing_Offset
"""
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
half_turn_homings[motor] = int(max_res / 2) - pos
return half_turn_homings
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
return _split_into_byte_chunks(value, length)
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
for n_try in range(1 + num_retry):
data_list, comm = self.packet_handler.broadcastPing(self.port_handler)
if self._is_comm_success(comm):
break
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
logger.debug(self.packet_handler.getTxRxResult(comm))
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return {id_: data[0] for id_, data in data_list.items()}

View File

@@ -0,0 +1,197 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(Steven): Consider doing the following:
# from enum import Enum
# class MyControlTableKey(Enum):
# ID = "ID"
# GOAL_SPEED = "Goal_Speed"
# ...
#
# MY_CONTROL_TABLE ={
# MyControlTableKey.ID.value: (5,1)
# MyControlTableKey.GOAL_SPEED.value: (46, 2)
# ...
# }
# This allows me do to:
# bus.write(MyControlTableKey.GOAL_SPEED, ...)
# Instead of:
# bus.write("Goal_Speed", ...)
# This is important for two reasons:
# 1. The linter will tell me if I'm trying to use an invalid key, instead of me realizing when I get the RunTimeError
# 2. We can change the value of the MyControlTableKey enums without impacting the client code
# {data_name: (address, size_byte)}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table
X_SERIES_CONTROL_TABLE = {
"Model_Number": (0, 2),
"Model_Information": (2, 4),
"Firmware_Version": (6, 1),
"ID": (7, 1),
"Baud_Rate": (8, 1),
"Return_Delay_Time": (9, 1),
"Drive_Mode": (10, 1),
"Operating_Mode": (11, 1),
"Secondary_ID": (12, 1),
"Protocol_Type": (13, 1),
"Homing_Offset": (20, 4),
"Moving_Threshold": (24, 4),
"Temperature_Limit": (31, 1),
"Max_Voltage_Limit": (32, 2),
"Min_Voltage_Limit": (34, 2),
"PWM_Limit": (36, 2),
"Current_Limit": (38, 2),
"Acceleration_Limit": (40, 4),
"Velocity_Limit": (44, 4),
"Max_Position_Limit": (48, 4),
"Min_Position_Limit": (52, 4),
"Shutdown": (63, 1),
"Torque_Enable": (64, 1),
"LED": (65, 1),
"Status_Return_Level": (68, 1),
"Registered_Instruction": (69, 1),
"Hardware_Error_Status": (70, 1),
"Velocity_I_Gain": (76, 2),
"Velocity_P_Gain": (78, 2),
"Position_D_Gain": (80, 2),
"Position_I_Gain": (82, 2),
"Position_P_Gain": (84, 2),
"Feedforward_2nd_Gain": (88, 2),
"Feedforward_1st_Gain": (90, 2),
"Bus_Watchdog": (98, 1),
"Goal_PWM": (100, 2),
"Goal_Current": (102, 2),
"Goal_Velocity": (104, 4),
"Profile_Acceleration": (108, 4),
"Profile_Velocity": (112, 4),
"Goal_Position": (116, 4),
"Realtime_Tick": (120, 2),
"Moving": (122, 1),
"Moving_Status": (123, 1),
"Present_PWM": (124, 2),
"Present_Current": (126, 2),
"Present_Velocity": (128, 4),
"Present_Position": (132, 4),
"Velocity_Trajectory": (136, 4),
"Position_Trajectory": (140, 4),
"Present_Input_Voltage": (144, 2),
"Present_Temperature": (146, 1),
}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#baud-rate8
X_SERIES_BAUDRATE_TABLE = {
9_600: 0,
57_600: 1,
115_200: 2,
1_000_000: 3,
2_000_000: 4,
3_000_000: 5,
4_000_000: 6,
}
# {data_name: size_byte}
X_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": X_SERIES_CONTROL_TABLE["Homing_Offset"][1],
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],
}
MODEL_ENCODING_TABLE = {
"x_series": X_SERIES_ENCODINGS_TABLE,
"xl330-m077": X_SERIES_ENCODINGS_TABLE,
"xl330-m288": X_SERIES_ENCODINGS_TABLE,
"xl430-w250": X_SERIES_ENCODINGS_TABLE,
"xm430-w350": X_SERIES_ENCODINGS_TABLE,
"xm540-w270": X_SERIES_ENCODINGS_TABLE,
"xc430-w150": X_SERIES_ENCODINGS_TABLE,
}
# {model: model_resolution}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#specifications
MODEL_RESOLUTION = {
"x_series": 4096,
"xl330-m077": 4096,
"xl330-m288": 4096,
"xl430-w250": 4096,
"xm430-w350": 4096,
"xm540-w270": 4096,
"xc430-w150": 4096,
}
# {model: model_number}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table-of-eeprom-area
MODEL_NUMBER_TABLE = {
"xl330-m077": 1190,
"xl330-m288": 1200,
"xl430-w250": 1060,
"xm430-w350": 1020,
"xm540-w270": 1120,
"xc430-w150": 1070,
}
# {model: available_operating_modes}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#operating-mode11
MODEL_OPERATING_MODES = {
"xl330-m077": [0, 1, 3, 4, 5, 16],
"xl330-m288": [0, 1, 3, 4, 5, 16],
"xl430-w250": [1, 3, 4, 16],
"xm430-w350": [0, 1, 3, 4, 5, 16],
"xm540-w270": [0, 1, 3, 4, 5, 16],
"xc430-w150": [1, 3, 4, 16],
}
MODEL_CONTROL_TABLE = {
"x_series": X_SERIES_CONTROL_TABLE,
"xl330-m077": X_SERIES_CONTROL_TABLE,
"xl330-m288": X_SERIES_CONTROL_TABLE,
"xl430-w250": X_SERIES_CONTROL_TABLE,
"xm430-w350": X_SERIES_CONTROL_TABLE,
"xm540-w270": X_SERIES_CONTROL_TABLE,
"xc430-w150": X_SERIES_CONTROL_TABLE,
}
MODEL_BAUDRATE_TABLE = {
"x_series": X_SERIES_BAUDRATE_TABLE,
"xl330-m077": X_SERIES_BAUDRATE_TABLE,
"xl330-m288": X_SERIES_BAUDRATE_TABLE,
"xl430-w250": X_SERIES_BAUDRATE_TABLE,
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
"xc430-w150": X_SERIES_BAUDRATE_TABLE,
}
AVAILABLE_BAUDRATES = [
9_600,
19_200,
38_400,
57_600,
115_200,
230_400,
460_800,
500_000,
576_000,
921_600,
1_000_000,
1_152_000,
2_000_000,
2_500_000,
3_000_000,
3_500_000,
4_000_000,
]

View File

@@ -0,0 +1,2 @@
from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode
from .tables import *

View File

@@ -0,0 +1,454 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from copy import deepcopy
from enum import Enum
from pprint import pformat
from lerobot.utils.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
FIRMWARE_MAJOR_VERSION,
FIRMWARE_MINOR_VERSION,
MODEL_BAUDRATE_TABLE,
MODEL_CONTROL_TABLE,
MODEL_ENCODING_TABLE,
MODEL_NUMBER,
MODEL_NUMBER_TABLE,
MODEL_PROTOCOL,
MODEL_RESOLUTION,
SCAN_BAUDRATES,
)
DEFAULT_PROTOCOL_VERSION = 0
DEFAULT_BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
logger = logging.getLogger(__name__)
class OperatingMode(Enum):
# position servo mode
POSITION = 0
# The motor is in constant speed mode, which is controlled by parameter 0x2e, and the highest bit 15 is
# the direction bit
VELOCITY = 1
# PWM open-loop speed regulation mode, with parameter 0x2c running time parameter control, bit11 as
# direction bit
PWM = 2
# In step servo mode, the number of step progress is represented by parameter 0x2a, and the highest bit 15
# is the direction bit
STEP = 3
class DriveMode(Enum):
NON_INVERTED = 0
INVERTED = 1
class TorqueMode(Enum):
ENABLED = 1
DISABLED = 0
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
import scservo_sdk as scs
if length == 1:
data = [value]
elif length == 2:
data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)]
elif length == 4:
data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
]
return data
def patch_setPacketTimeout(self, packet_length): # noqa: N802
"""
HACK: This patches the PortHandler behavior to set the correct packet timeouts.
It fixes https://gitee.com/ftservo/SCServoSDK/issues/IBY2S6
The bug is fixed on the official Feetech SDK repo (https://gitee.com/ftservo/FTServo_Python)
but because that version is not published on PyPI, we rely on the (unofficial) on that is, which needs
patching.
"""
self.packet_start_time = self.getCurrentTime()
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
class FeetechMotorsBus(MotorsBus):
"""
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
"""
apply_drive_mode = True
available_baudrates = deepcopy(SCAN_BAUDRATES)
default_baudrate = DEFAULT_BAUDRATE
default_timeout = DEFAULT_TIMEOUT_MS
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
model_encoding_table = deepcopy(MODEL_ENCODING_TABLE)
model_number_table = deepcopy(MODEL_NUMBER_TABLE)
model_resolution_table = deepcopy(MODEL_RESOLUTION)
normalized_data = deepcopy(NORMALIZED_DATA)
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
):
super().__init__(port, motors, calibration)
self.protocol_version = protocol_version
self._assert_same_protocol()
import scservo_sdk as scs
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
self.port_handler, scs.PortHandler
)
self.packet_handler = scs.PacketHandler(protocol_version)
self.sync_reader = scs.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
self.sync_writer = scs.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
self._comm_success = scs.COMM_SUCCESS
self._no_error = 0x00
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
raise ValueError(f"Some motors are incompatible with protocol_version={self.protocol_version}")
def _assert_same_protocol(self) -> None:
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
raise RuntimeError("Some motors use an incompatible protocol.")
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
if instruction_name == "sync_read" and self.protocol_version == 1:
raise NotImplementedError(
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' sequentially instead."
)
if instruction_name == "broadcast_ping" and self.protocol_version == 1:
raise NotImplementedError(
"'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead."
)
def _assert_same_firmware(self) -> None:
firmware_versions = self._read_firmware_version(self.ids, raise_on_error=True)
if len(set(firmware_versions.values())) != 1:
raise RuntimeError(
"Some Motors use different firmware versions:"
f"\n{pformat(firmware_versions)}\n"
"Update their firmware first using Feetech's software. "
"Visit https://www.feetechrc.com/software."
)
def _handshake(self) -> None:
self._assert_motors_exist()
self._assert_same_firmware()
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
if self.protocol_version == 0:
return self._find_single_motor_p0(motor, initial_baudrate)
else:
return self._find_single_motor_p1(motor, initial_baudrate)
def _find_single_motor_p0(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
model = self.motors[motor].model
search_baudrates = (
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
)
expected_model_nb = self.model_number_table[model]
for baudrate in search_baudrates:
self.set_baudrate(baudrate)
id_model = self.broadcast_ping()
if id_model:
found_id, found_model = next(iter(id_model.items()))
if found_model != expected_model_nb:
raise RuntimeError(
f"Found one motor on {baudrate=} with id={found_id} but it has a "
f"model number '{found_model}' different than the one expected: '{expected_model_nb}'. "
f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')."
)
return baudrate, found_id
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
import scservo_sdk as scs
model = self.motors[motor].model
search_baudrates = (
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
)
expected_model_nb = self.model_number_table[model]
for baudrate in search_baudrates:
self.set_baudrate(baudrate)
for id_ in range(scs.MAX_ID + 1):
found_model = self.ping(id_)
if found_model is not None:
if found_model != expected_model_nb:
raise RuntimeError(
f"Found one motor on {baudrate=} with id={id_} but it has a "
f"model number '{found_model}' different than the one expected: '{expected_model_nb}'. "
f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')."
)
return baudrate, id_
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
def configure_motors(self) -> None:
for motor in self.motors:
# By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
self.write("Return_Delay_Time", motor, 0)
# Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors.
# Note: this address is not in the official STS3215 Memory Table
self.write("Maximum_Acceleration", motor, 254)
self.write("Acceleration", motor, 254)
@property
def is_calibrated(self) -> bool:
motors_calibration = self.read_calibration()
if set(motors_calibration) != set(self.calibration):
return False
same_ranges = all(
self.calibration[motor].range_min == cal.range_min
and self.calibration[motor].range_max == cal.range_max
for motor, cal in motors_calibration.items()
)
if self.protocol_version == 1:
return same_ranges
same_offsets = all(
self.calibration[motor].homing_offset == cal.homing_offset
for motor, cal in motors_calibration.items()
)
return same_ranges and same_offsets
def read_calibration(self) -> dict[str, MotorCalibration]:
offsets, mins, maxes = {}, {}, {}
for motor in self.motors:
mins[motor] = self.read("Min_Position_Limit", motor, normalize=False)
maxes[motor] = self.read("Max_Position_Limit", motor, normalize=False)
offsets[motor] = (
self.read("Homing_Offset", motor, normalize=False) if self.protocol_version == 0 else 0
)
calibration = {}
for motor, m in self.motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
)
return calibration
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
for motor, calibration in calibration_dict.items():
if self.protocol_version == 0:
self.write("Homing_Offset", motor, calibration.homing_offset)
self.write("Min_Position_Limit", motor, calibration.range_min)
self.write("Max_Position_Limit", motor, calibration.range_max)
self.calibration = calibration_dict
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
"""
On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset
"""
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
half_turn_homings[motor] = pos - int(max_res / 2)
return half_turn_homings
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor_id, 0, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", motor, 1, num_retry=num_retry)
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
model = self._id_to_model(id_)
encoding_table = self.model_encoding_table.get(model)
if encoding_table and data_name in encoding_table:
sign_bit = encoding_table[data_name]
ids_values[id_] = encode_sign_magnitude(ids_values[id_], sign_bit)
return ids_values
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
model = self._id_to_model(id_)
encoding_table = self.model_encoding_table.get(model)
if encoding_table and data_name in encoding_table:
sign_bit = encoding_table[data_name]
ids_values[id_] = decode_sign_magnitude(ids_values[id_], sign_bit)
return ids_values
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
return _split_into_byte_chunks(value, length)
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list = {}
status_length = 6
rx_length = 0
wait_length = status_length * scs.MAX_ID
txpacket = [0] * 6
tx_time_per_byte = (1000.0 / self.port_handler.getBaudRate()) * 10.0
txpacket[scs.PKT_ID] = scs.BROADCAST_ID
txpacket[scs.PKT_LENGTH] = 2
txpacket[scs.PKT_INSTRUCTION] = scs.INST_PING
result = self.packet_handler.txPacket(self.port_handler, txpacket)
if result != scs.COMM_SUCCESS:
self.port_handler.is_using = False
return data_list, result
# set rx timeout
self.port_handler.setPacketTimeoutMillis((wait_length * tx_time_per_byte) + (3.0 * scs.MAX_ID) + 16.0)
rxpacket = []
while not self.port_handler.isPacketTimeout() and rx_length < wait_length:
rxpacket += self.port_handler.readPort(wait_length - rx_length)
rx_length = len(rxpacket)
self.port_handler.is_using = False
if rx_length == 0:
return data_list, scs.COMM_RX_TIMEOUT
while True:
if rx_length < status_length:
return data_list, scs.COMM_RX_CORRUPT
# find packet header
for idx in range(0, (rx_length - 1)):
if (rxpacket[idx] == 0xFF) and (rxpacket[idx + 1] == 0xFF):
break
if idx == 0: # found at the beginning of the packet
# calculate checksum
checksum = 0
for idx in range(2, status_length - 1): # except header & checksum
checksum += rxpacket[idx]
checksum = ~checksum & 0xFF
if rxpacket[status_length - 1] == checksum:
result = scs.COMM_SUCCESS
data_list[rxpacket[scs.PKT_ID]] = rxpacket[scs.PKT_ERROR]
del rxpacket[0:status_length]
rx_length = rx_length - status_length
if rx_length == 0:
return data_list, result
else:
result = scs.COMM_RX_CORRUPT
# remove header (0xFF 0xFF)
del rxpacket[0:2]
rx_length = rx_length - 2
else:
# remove unnecessary packets
del rxpacket[0:idx]
rx_length = rx_length - idx
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
self._assert_protocol_is_compatible("broadcast_ping")
for n_try in range(1 + num_retry):
ids_status, comm = self._broadcast_ping()
if self._is_comm_success(comm):
break
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
logger.debug(self.packet_handler.getTxRxResult(comm))
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:
display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()}
logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}")
return self._read_model_number(list(ids_status), raise_on_error)
def _read_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]:
firmware_versions = {}
for id_ in motor_ids:
firm_ver_major, comm, error = self._read(
*FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error
)
if not self._is_comm_success(comm) or self._is_error(error):
continue
firm_ver_minor, comm, error = self._read(
*FIRMWARE_MINOR_VERSION, id_, raise_on_error=raise_on_error
)
if not self._is_comm_success(comm) or self._is_error(error):
continue
firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}"
return firmware_versions
def _read_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
model_numbers = {}
for id_ in motor_ids:
model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error)
if not self._is_comm_success(comm) or self._is_error(error):
continue
model_numbers[id_] = model_nb
return model_numbers

View File

@@ -0,0 +1,252 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FIRMWARE_MAJOR_VERSION = (0, 1)
FIRMWARE_MINOR_VERSION = (1, 1)
MODEL_NUMBER = (3, 2)
# TODO(Steven): Consider doing the following:
# from enum import Enum
# class MyControlTableKey(Enum):
# ID = "ID"
# GOAL_SPEED = "Goal_Speed"
# ...
#
# MY_CONTROL_TABLE ={
# MyControlTableKey.ID.value: (5,1)
# MyControlTableKey.GOAL_SPEED.value: (46, 2)
# ...
# }
# This allows me do to:
# bus.write(MyControlTableKey.GOAL_SPEED, ...)
# Instead of:
# bus.write("Goal_Speed", ...)
# This is important for two reasons:
# 1. The linter will tell me if I'm trying to use an invalid key, instead of me realizing when I get the RunTimeError
# 2. We can change the value of the MyControlTableKey enums without impacting the client code
# data_name: (address, size_byte)
# http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SMS-STS-emanual-229f4476422d4059abfb1cb0
STS_SMS_SERIES_CONTROL_TABLE = {
# EPROM
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
"Model_Number": MODEL_NUMBER, # read-only
"ID": (5, 1),
"Baud_Rate": (6, 1),
"Return_Delay_Time": (7, 1),
"Response_Status_Level": (8, 1),
"Min_Position_Limit": (9, 2),
"Max_Position_Limit": (11, 2),
"Max_Temperature_Limit": (13, 1),
"Max_Voltage_Limit": (14, 1),
"Min_Voltage_Limit": (15, 1),
"Max_Torque_Limit": (16, 2),
"Phase": (18, 1),
"Unloading_Condition": (19, 1),
"LED_Alarm_Condition": (20, 1),
"P_Coefficient": (21, 1),
"D_Coefficient": (22, 1),
"I_Coefficient": (23, 1),
"Minimum_Startup_Force": (24, 2),
"CW_Dead_Zone": (26, 1),
"CCW_Dead_Zone": (27, 1),
"Protection_Current": (28, 2),
"Angular_Resolution": (30, 1),
"Homing_Offset": (31, 2),
"Operating_Mode": (33, 1),
"Protective_Torque": (34, 1),
"Protection_Time": (35, 1),
"Overload_Torque": (36, 1),
"Velocity_closed_loop_P_proportional_coefficient": (37, 1),
"Over_Current_Protection_Time": (38, 1),
"Velocity_closed_loop_I_integral_coefficient": (39, 1),
# SRAM
"Torque_Enable": (40, 1),
"Acceleration": (41, 1),
"Goal_Position": (42, 2),
"Goal_Time": (44, 2),
"Goal_Velocity": (46, 2),
"Torque_Limit": (48, 2),
"Lock": (55, 1),
"Present_Position": (56, 2), # read-only
"Present_Velocity": (58, 2), # read-only
"Present_Load": (60, 2), # read-only
"Present_Voltage": (62, 1), # read-only
"Present_Temperature": (63, 1), # read-only
"Status": (65, 1), # read-only
"Moving": (66, 1), # read-only
"Present_Current": (69, 2), # read-only
"Goal_Position_2": (71, 2), # read-only
# Factory
"Moving_Velocity": (80, 1),
"Moving_Velocity_Threshold": (80, 1),
"DTs": (81, 1), # (ms)
"Velocity_Unit_factor": (82, 1),
"Hts": (83, 1), # (ns) valid for firmware >= 2.54, other versions keep 0
"Maximum_Velocity_Limit": (84, 1),
"Maximum_Acceleration": (85, 1),
"Acceleration_Multiplier ": (86, 1), # Acceleration multiplier in effect when acceleration is 0
}
# http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SCSCL-emanual-cbcc8ab2e3384282a01d4bf3
SCS_SERIES_CONTROL_TABLE = {
# EPROM
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
"Model_Number": MODEL_NUMBER, # read-only
"ID": (5, 1),
"Baud_Rate": (6, 1),
"Return_Delay_Time": (7, 1),
"Response_Status_Level": (8, 1),
"Min_Position_Limit": (9, 2),
"Max_Position_Limit": (11, 2),
"Max_Temperature_Limit": (13, 1),
"Max_Voltage_Limit": (14, 1),
"Min_Voltage_Limit": (15, 1),
"Max_Torque_Limit": (16, 2),
"Phase": (18, 1),
"Unloading_Condition": (19, 1),
"LED_Alarm_Condition": (20, 1),
"P_Coefficient": (21, 1),
"D_Coefficient": (22, 1),
"I_Coefficient": (23, 1),
"Minimum_Startup_Force": (24, 2),
"CW_Dead_Zone": (26, 1),
"CCW_Dead_Zone": (27, 1),
"Protective_Torque": (37, 1),
"Protection_Time": (38, 1),
# SRAM
"Torque_Enable": (40, 1),
"Acceleration": (41, 1),
"Goal_Position": (42, 2),
"Running_Time": (44, 2),
"Goal_Velocity": (46, 2),
"Lock": (48, 1),
"Present_Position": (56, 2), # read-only
"Present_Velocity": (58, 2), # read-only
"Present_Load": (60, 2), # read-only
"Present_Voltage": (62, 1), # read-only
"Present_Temperature": (63, 1), # read-only
"Sync_Write_Flag": (64, 1), # read-only
"Status": (65, 1), # read-only
"Moving": (66, 1), # read-only
# Factory
"PWM_Maximum_Step": (78, 1),
"Moving_Velocity_Threshold*50": (79, 1),
"DTs": (80, 1), # (ms)
"Minimum_Velocity_Limit*50": (81, 1),
"Maximum_Velocity_Limit*50": (82, 1),
"Acceleration_2": (83, 1), # don't know what that is
}
STS_SMS_SERIES_BAUDRATE_TABLE = {
1_000_000: 0,
500_000: 1,
250_000: 2,
128_000: 3,
115_200: 4,
57_600: 5,
38_400: 6,
19_200: 7,
}
SCS_SERIES_BAUDRATE_TABLE = {
1_000_000: 0,
500_000: 1,
250_000: 2,
128_000: 3,
115_200: 4,
57_600: 5,
38_400: 6,
19_200: 7,
}
MODEL_CONTROL_TABLE = {
"sts_series": STS_SMS_SERIES_CONTROL_TABLE,
"scs_series": SCS_SERIES_CONTROL_TABLE,
"sms_series": STS_SMS_SERIES_CONTROL_TABLE,
"sts3215": STS_SMS_SERIES_CONTROL_TABLE,
"sts3250": STS_SMS_SERIES_CONTROL_TABLE,
"scs0009": SCS_SERIES_CONTROL_TABLE,
"sm8512bl": STS_SMS_SERIES_CONTROL_TABLE,
}
MODEL_RESOLUTION = {
"sts_series": 4096,
"sms_series": 4096,
"scs_series": 1024,
"sts3215": 4096,
"sts3250": 4096,
"sm8512bl": 65536,
"scs0009": 1024,
}
MODEL_BAUDRATE_TABLE = {
"sts_series": STS_SMS_SERIES_BAUDRATE_TABLE,
"sms_series": STS_SMS_SERIES_BAUDRATE_TABLE,
"scs_series": SCS_SERIES_BAUDRATE_TABLE,
"sm8512bl": STS_SMS_SERIES_BAUDRATE_TABLE,
"sts3215": STS_SMS_SERIES_BAUDRATE_TABLE,
"sts3250": STS_SMS_SERIES_BAUDRATE_TABLE,
"scs0009": SCS_SERIES_BAUDRATE_TABLE,
}
# Sign-Magnitude encoding bits
STS_SMS_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": 11,
"Goal_Velocity": 15,
"Present_Velocity": 15,
}
MODEL_ENCODING_TABLE = {
"sts_series": STS_SMS_SERIES_ENCODINGS_TABLE,
"sms_series": STS_SMS_SERIES_ENCODINGS_TABLE,
"scs_series": {},
"sts3215": STS_SMS_SERIES_ENCODINGS_TABLE,
"sts3250": STS_SMS_SERIES_ENCODINGS_TABLE,
"sm8512bl": STS_SMS_SERIES_ENCODINGS_TABLE,
"scs0009": {},
}
SCAN_BAUDRATES = [
4_800,
9_600,
14_400,
19_200,
38_400,
57_600,
115_200,
128_000,
250_000,
500_000,
1_000_000,
]
MODEL_NUMBER_TABLE = {
"sts3215": 777,
"sts3250": 2825,
"sm8512bl": 11272,
"scs0009": 1284,
}
MODEL_PROTOCOL = {
"sts_series": 0,
"sms_series": 0,
"scs_series": 1,
"sts3215": 0,
"sts3250": 0,
"sm8512bl": 0,
"scs0009": 1,
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,15 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .optimizers import OptimizerConfig as OptimizerConfig

View File

@@ -0,0 +1,40 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.configs.train import TrainPipelineConfig
from lerobot.policies.pretrained import PreTrainedPolicy
def make_optimizer_and_scheduler(
cfg: TrainPipelineConfig, policy: PreTrainedPolicy
) -> tuple[Optimizer, LRScheduler | None]:
"""Generates the optimizer and scheduler based on configs.
Args:
cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs
policy (PreTrainedPolicy): The policy config from which parameters and presets must be taken from.
Returns:
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
"""
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
optimizer = cfg.optimizer.build(params)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
return optimizer, lr_scheduler

View File

@@ -0,0 +1,230 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
import draccus
import torch
from safetensors.torch import load_file, save_file
from lerobot.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
)
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.utils.io_utils import deserialize_json_into_object
@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
lr: float
weight_decay: float
grad_clip_norm: float
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@classmethod
def default_choice_name(cls) -> str | None:
return "adam"
@abc.abstractmethod
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
"""
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
NOTE: Multiple optimizers are useful when you have different models to optimize.
For example, you can have one optimizer for the policy and another one for the value function
in reinforcement learning settings.
Returns:
The optimizer or a dictionary of optimizers.
"""
raise NotImplementedError
@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
lr: float = 1e-3
betas: tuple[float, float] = (0.9, 0.999)
eps: float = 1e-8
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.Adam(params, **kwargs)
@OptimizerConfig.register_subclass("adamw")
@dataclass
class AdamWConfig(OptimizerConfig):
lr: float = 1e-3
betas: tuple[float, float] = (0.9, 0.999)
eps: float = 1e-8
weight_decay: float = 1e-2
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.AdamW(params, **kwargs)
@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
lr: float = 1e-3
momentum: float = 0.0
dampening: float = 0.0
nesterov: bool = False
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.SGD(params, **kwargs)
@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):
"""Configuration for multiple Adam optimizers with different parameter groups.
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
Args:
lr: Default learning rate (used if not specified for a group)
weight_decay: Default weight decay (used if not specified for a group)
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
grad_clip_norm: Gradient clipping norm
"""
lr: float = 1e-3
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
"""Build multiple Adam optimizers.
Args:
params_dict: Dictionary mapping parameter group names to lists of parameters
The keys should match the keys in optimizer_groups
Returns:
Dictionary mapping parameter group names to their optimizers
"""
optimizers = {}
for name, params in params_dict.items():
# Get group-specific hyperparameters or use defaults
group_config = self.optimizer_groups.get(name, {})
# Create optimizer with merged parameters (defaults + group-specific)
optimizer_kwargs = {
"lr": group_config.get("lr", self.lr),
"betas": group_config.get("betas", (0.9, 0.999)),
"eps": group_config.get("eps", 1e-5),
"weight_decay": group_config.get("weight_decay", self.weight_decay),
}
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
return optimizers
def save_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> None:
"""Save optimizer state to disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to save the optimizer state.
"""
if isinstance(optimizer, dict):
# Handle dictionary of optimizers
for name, opt in optimizer.items():
optimizer_dir = save_dir / name
optimizer_dir.mkdir(exist_ok=True, parents=True)
_save_single_optimizer_state(opt, optimizer_dir)
else:
# Handle single optimizer
_save_single_optimizer_state(optimizer, save_dir)
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
"""Save a single optimizer's state to disk."""
state = optimizer.state_dict()
param_groups = state.pop("param_groups")
flat_state = flatten_dict(state)
save_file(flat_state, save_dir / OPTIMIZER_STATE)
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
def load_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
"""Load optimizer state from disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to load the optimizer state from.
Returns:
The updated optimizer(s) with loaded state.
"""
if isinstance(optimizer, dict):
# Handle dictionary of optimizers
loaded_optimizers = {}
for name, opt in optimizer.items():
optimizer_dir = save_dir / name
if optimizer_dir.exists():
loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
else:
loaded_optimizers[name] = opt
return loaded_optimizers
else:
# Handle single optimizer
return _load_single_optimizer_state(optimizer, save_dir)
def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
"""Load a single optimizer's state from disk."""
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
# Handle case where 'state' key might not exist (for newly created optimizers)
if "state" in state:
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
else:
loaded_state_dict = {"state": {}}
if "param_groups" in current_state_dict:
param_groups = deserialize_json_into_object(
save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
)
loaded_state_dict["param_groups"] = param_groups
optimizer.load_state_dict(loaded_state_dict)
return optimizer

View File

@@ -0,0 +1,122 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import math
from dataclasses import asdict, dataclass
from pathlib import Path
import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from lerobot.constants import SCHEDULER_STATE
from lerobot.datasets.utils import write_json
from lerobot.utils.io_utils import deserialize_json_into_object
@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
num_warmup_steps: int
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@abc.abstractmethod
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
raise NotImplementedError
@LRSchedulerConfig.register_subclass("diffuser")
@dataclass
class DiffuserSchedulerConfig(LRSchedulerConfig):
name: str = "cosine"
num_warmup_steps: int | None = None
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
from diffusers.optimization import get_scheduler
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
return get_scheduler(**kwargs)
@LRSchedulerConfig.register_subclass("vqbet")
@dataclass
class VQBeTSchedulerConfig(LRSchedulerConfig):
num_warmup_steps: int
num_vqvae_training_steps: int
num_cycles: float = 0.5
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step):
if current_step < self.num_vqvae_training_steps:
return float(1)
else:
adjusted_step = current_step - self.num_vqvae_training_steps
if adjusted_step < self.num_warmup_steps:
return float(adjusted_step) / float(max(1, self.num_warmup_steps))
progress = float(adjusted_step - self.num_warmup_steps) / float(
max(1, num_training_steps - self.num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Used by Physical Intelligence to train Pi0"""
num_warmup_steps: int
num_decay_steps: int
peak_lr: float
decay_lr: float
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
del num_training_steps
def lr_lambda(current_step):
def linear_warmup_schedule(current_step):
if current_step <= 0:
return 1 / (self.num_warmup_steps + 1)
frac = 1 - current_step / self.num_warmup_steps
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
def cosine_decay_schedule(current_step):
step = min(current_step, self.num_decay_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha
return decayed
if current_step < self.num_warmup_steps:
return linear_warmup_schedule(current_step)
return cosine_decay_schedule(current_step)
return LambdaLR(optimizer, lr_lambda, -1)
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
state_dict = scheduler.state_dict()
write_json(state_dict, save_dir / SCHEDULER_STATE)
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
scheduler.load_state_dict(state_dict)
return scheduler

View File

@@ -0,0 +1,20 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig

View File

@@ -0,0 +1,186 @@
#!/usr/bin/env python
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
@PreTrainedConfig.register_subclass("act")
@dataclass
class ACTConfig(PreTrainedConfig):
"""Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs:
- Either:
- At least one key starting with "observation.image is required as an input.
AND/OR
- The key "observation.environment_state" is required as input.
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
views. Right now we only support all images having the same shape.
- May optionally work without an "observation.state" key for the proprioceptive robot state.
- "action" is required as an output key.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
chunk_size: The size of the action prediction "chunks" in units of environment steps.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
convolution.
pre_norm: Whether to use "pre-norm" in the transformer blocks.
dim_model: The transformer blocks' main hidden dimension.
n_heads: The number of heads to use in the transformer blocks' multi-head attention.
dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
layers.
feedforward_activation: The activation to use in the transformer block's feed-forward layers.
n_encoder_layers: The number of transformer layers to use for the transformer encoder.
n_decoder_layers: The number of transformer layers to use for the transformer decoder.
use_vae: Whether to use a variational objective during training. This introduces another transformer
which is used as the VAE's encoder (not to be confused with the transformer encoder - see
documentation in the policy class).
latent_dim: The VAE's latent dimension.
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
1 when using this feature, as inference needs to happen at every step to form an ensemble. For
more information on how ensembling works, please see `ACTTemporalEnsembler`.
dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
"""
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 100
n_action_steps: int = 100
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Architecture.
# Vision backbone.
vision_backbone: str = "resnet18"
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
replace_final_stride_with_dilation: int = False
# Transformer layers.
pre_norm: bool = False
dim_model: int = 512
n_heads: int = 8
dim_feedforward: int = 3200
feedforward_activation: str = "relu"
n_encoder_layers: int = 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: int = 1
# VAE.
use_vae: bool = True
latent_dim: int = 32
n_vae_encoder_layers: int = 4
# Inference.
# Note: the value used in ACT when temporal ensembling is enabled is 0.01.
temporal_ensemble_coeff: float | None = None
# Training and loss computation.
dropout: float = 0.1
kl_weight: float = 10.0
# Training preset
optimizer_lr: float = 1e-5
optimizer_weight_decay: float = 1e-4
optimizer_lr_backbone: float = 1e-5
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
raise NotImplementedError(
"`n_action_steps` must be 1 when using temporal ensembling. This is "
"because the policy needs to be queried every step to compute the ensembled action."
)
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,769 @@
#!/usr/bin/env python
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action Chunking Transformer Policy
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://huggingface.co/papers/2304.13705).
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
"""
import math
from collections import deque
from itertools import chain
from typing import Callable
import einops
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.constants import ACTION, OBS_IMAGES
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
class ACTPolicy(PreTrainedPolicy):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://huggingface.co/papers/2304.13705, code: https://github.com/tonyzhaozh/act)
"""
config_class = ACTConfig
name = "act"
def __init__(
self,
config: ACTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = ACT(config)
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset()
def get_optim_params(self) -> dict:
# TODO(aliberts, rcadene): As of now, lr_backbone == lr
# Should we remove this and just `return self.parameters()`?
return [
{
"params": [
p
for n, p in self.named_parameters()
if not n.startswith("model.backbone") and p.requires_grad
]
},
{
"params": [
p
for n, p in self.named_parameters()
if n.startswith("model.backbone") and p.requires_grad
],
"lr": self.config.optimizer_lr_backbone,
},
]
def reset(self):
"""This should be called whenever the environment is reset."""
if self.config.temporal_ensemble_coeff is not None:
self.temporal_ensembler.reset()
else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed
if self.config.temporal_ensemble_coeff is not None:
actions = self.predict_action_chunk(batch)
action = self.temporal_ensembler.update(actions)
return action
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
actions = self.model(batch)[0]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://huggingface.co/papers/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss = l1_loss + mean_kld * self.config.kl_weight
else:
loss = l1_loss
return loss, loss_dict
class ACTTemporalEnsembler:
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
"""Temporal ensembling as described in Algorithm 2 of https://huggingface.co/papers/2304.13705.
The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
coefficient works:
- Setting it to 0 uniformly weighs all actions.
- Setting it positive gives more weight to older actions.
- Setting it negative gives more weight to newer actions.
NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
results in older actions being weighed more highly than newer actions (the experiments documented in
https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
detrimental: doing so aggressively may diminish the benefits of action chunking).
Here we use an online method for computing the average rather than caching a history of actions in
order to compute the average offline. For a simple 1D sequence it looks something like:
```
import torch
seq = torch.linspace(8, 8.5, 100)
print(seq)
m = 0.01
exp_weights = torch.exp(-m * torch.arange(len(seq)))
print(exp_weights)
# Calculate offline
avg = (exp_weights * seq).sum() / exp_weights.sum()
print("offline", avg)
# Calculate online
for i, item in enumerate(seq):
if i == 0:
avg = item
continue
avg *= exp_weights[:i].sum()
avg += item * exp_weights[i]
avg /= exp_weights[:i+1].sum()
print("online", avg)
```
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()
def reset(self):
"""Resets the online computation variables."""
self.ensembled_actions = None
# (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
self.ensembled_actions_count = None
def update(self, actions: Tensor) -> Tensor:
"""
Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all
time steps, and pop/return the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
self.ensembled_actions = actions.clone()
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions_count = torch.cat(
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
self.ensembled_actions[:, 0],
self.ensembled_actions[:, 1:],
self.ensembled_actions_count[1:],
)
return action
class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
model that encodes the target data (a sequence of actions), and the condition (the robot
joint-space).
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
have an option to train this model without the variational objective (in which case we drop the
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
Transformer
Used alone for inference
(acts as VAE decoder
during training)
┌───────────────────────┐
│ Outputs │
│ ▲ │
│ ┌─────►┌───────┐ │
┌──────┐ │ │ │Transf.│ │
│ │ │ ├─────►│decoder│ │
┌────┴────┐ │ │ │ │ │ │
│ │ │ │ ┌───┴───┬─►│ │ │
│ VAE │ │ │ │ │ └───────┘ │
│ encoder │ │ │ │Transf.│ │
│ │ │ │ │encoder│ │
└───▲─────┘ │ │ │ │ │
│ │ │ └▲──▲─▲─┘ │
│ │ │ │ │ │ │
inputs └─────┼──┘ │ image emb. │
│ state emb. │
└───────────────────────┘
"""
def __init__(self, config: ACTConfig):
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
super().__init__()
self.config = config
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.config.robot_state_feature:
self.vae_encoder_robot_state_input_proj = nn.Linear(
self.config.robot_state_feature.shape[0], config.dim_model
)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(
self.config.action_feature.shape[0],
config.dim_model,
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
if self.config.robot_state_feature:
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
)
# Backbone for image feature extraction.
if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
self.decoder = ACTDecoder(config)
# Transformer encoder input projections. The tokens will be structured like
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.config.robot_state_feature:
self.encoder_robot_state_input_proj = nn.Linear(
self.config.robot_state_feature.shape[0], config.dim_model
)
if self.config.env_state_feature:
self.encoder_env_state_input_proj = nn.Linear(
self.config.env_state_feature.shape[0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
if self.config.image_features:
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
n_1d_tokens = 1 # for the latent
if self.config.robot_state_feature:
n_1d_tokens += 1
if self.config.env_state_feature:
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.config.image_features:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
self._reset_parameters()
def _reset_parameters(self):
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
{
[robot_state_feature] (optional): (B, state_dim) batch of robot states.
[image_features]: (B, n_cameras, C, H, W) batch of images.
AND/OR
[env_state_feature]: (B, env_dim) batch of environment states.
[action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
}
Returns:
(B, chunk_size, action_dim) batch of action sequences
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
latent dimension.
"""
if self.config.use_vae and self.training:
assert "action" in batch, (
"actions must be provided when using the variational objective in training mode."
)
if "observation.images" in batch:
batch_size = batch["observation.images"][0].shape[0]
else:
batch_size = batch["observation.environment_state"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
# Prepare fixed positional embedding.
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
# sequence depending whether we use the input states or not (cls and robot state)
# False means not a padding token.
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.config.robot_state_feature else 1),
False,
device=batch["observation.state"].device,
)
key_padding_mask = torch.cat(
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
) # (bs, seq+1 or 2)
# Forward pass through VAE encoder to get the latent PDF parameters.
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2),
pos_embed=pos_embed.permute(1, 0, 2),
key_padding_mask=key_padding_mask,
)[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.config.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation.
log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
# Sample the latent with the reparameterization trick.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
else:
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
# Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.config.env_state_feature:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings.
if self.config.image_features:
all_cam_features = []
all_cam_pos_embeds = []
# For a list of images, the H and W may vary but H*W is constant.
for img in batch["observation.images"]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features)
# Rearrange features to (sequence, batch, dim).
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model),
dtype=encoder_in_pos_embed.dtype,
device=encoder_in_pos_embed.device,
)
decoder_out = self.decoder(
decoder_in,
encoder_out,
encoder_pos_embed=encoder_in_pos_embed,
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
)
# Move back to (B, S, C).
decoder_out = decoder_out.transpose(0, 1)
actions = self.action_head(decoder_out)
return actions, (mu, log_sigma_x2)
class ACTEncoder(nn.Module):
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
super().__init__()
self.is_vae_encoder = is_vae_encoder
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
x = self.norm(x)
return x
class ACTEncoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
self.dropout = nn.Dropout(config.dropout)
self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
self.norm1 = nn.LayerNorm(config.dim_model)
self.norm2 = nn.LayerNorm(config.dim_model)
self.dropout1 = nn.Dropout(config.dropout)
self.dropout2 = nn.Dropout(config.dropout)
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
q = k = x if pos_embed is None else x + pos_embed
x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)
x = x[0] # note: [0] to select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
x = self.norm2(x)
else:
x = self.norm1(x)
skip = x
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = skip + self.dropout2(x)
if not self.pre_norm:
x = self.norm2(x)
return x
class ACTDecoder(nn.Module):
def __init__(self, config: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
self.norm = nn.LayerNorm(config.dim_model)
def forward(
self,
x: Tensor,
encoder_out: Tensor,
decoder_pos_embed: Tensor | None = None,
encoder_pos_embed: Tensor | None = None,
) -> Tensor:
for layer in self.layers:
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
)
if self.norm is not None:
x = self.norm(x)
return x
class ACTDecoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
self.dropout = nn.Dropout(config.dropout)
self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
self.norm1 = nn.LayerNorm(config.dim_model)
self.norm2 = nn.LayerNorm(config.dim_model)
self.norm3 = nn.LayerNorm(config.dim_model)
self.dropout1 = nn.Dropout(config.dropout)
self.dropout2 = nn.Dropout(config.dropout)
self.dropout3 = nn.Dropout(config.dropout)
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
return tensor if pos_embed is None else tensor + pos_embed
def forward(
self,
x: Tensor,
encoder_out: Tensor,
decoder_pos_embed: Tensor | None = None,
encoder_pos_embed: Tensor | None = None,
) -> Tensor:
"""
Args:
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
cross-attending with.
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
Returns:
(DS, B, C) tensor of decoder output features.
"""
skip = x
if self.pre_norm:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
x = self.norm2(x)
else:
x = self.norm1(x)
skip = x
x = self.multihead_attn(
query=self.maybe_add_pos_embed(x, decoder_pos_embed),
key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
value=encoder_out,
)[0] # select just the output, not the attention weights
x = skip + self.dropout2(x)
if self.pre_norm:
skip = x
x = self.norm3(x)
else:
x = self.norm2(x)
skip = x
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = skip + self.dropout3(x)
if not self.pre_norm:
x = self.norm3(x)
return x
def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
"""1D sinusoidal positional embeddings as in Attention is All You Need.
Args:
num_positions: Number of token positions required.
Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension).
"""
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float()
class ACTSinusoidalPositionEmbedding2d(nn.Module):
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
for the vertical direction, and 1/W for the horizontal direction.
"""
def __init__(self, dimension: int):
"""
Args:
dimension: The desired dimension of the embeddings.
"""
super().__init__()
self.dimension = dimension
self._two_pi = 2 * math.pi
self._eps = 1e-6
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
self._temperature = 10000
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
Returns:
A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
"""
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
y_range = not_mask.cumsum(1, dtype=torch.float32)
x_range = not_mask.cumsum(2, dtype=torch.float32)
# "Normalize" the position index such that it ranges in [0, 2π].
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
# are non-zero by construction. This is an artifact of the original code.
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
)
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
return pos_embed
def get_activation_fn(activation: str) -> Callable:
"""Return an activation function given a string."""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

View File

@@ -0,0 +1,237 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("diffusion")
@dataclass
class DiffusionConfig(PreTrainedConfig):
"""Configuration class for DiffusionPolicy.
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- Either:
- At least one key starting with "observation.image is required as an input.
AND/OR
- The key "observation.environment_state" is required as input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views. Right now we only support all images having the same shape.
- "action" is required as an output key.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
You may provide a variable number of dimensions, therefore also controlling the degree of
downsampling.
kernel_size: The convolutional kernel size of the diffusion modeling Unet.
n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
network. This is the output dimension of that network, i.e., the embedding dimension.
use_film_scale_modulation: FiLM (https://huggingface.co/papers/1709.07871) is used for the Unet conditioning.
Bias modulation is used be default, while this parameter indicates whether to also use scale
modulation.
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
beta_start: Beta value for the first forward-diffusion step.
beta_end: Beta value for the last forward-diffusion step.
prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
"epsilon" has been shown to work better in many deep neural network settings.
clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
denoising step at inference time. WARNING: you will need to make sure your action-space is
normalized to fit within this range.
clip_sample_range: The magnitude of the clipping range as described above.
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
to False as the original Diffusion Policy implementation does the same.
"""
# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 16
n_action_steps: int = 8
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
use_separate_rgb_encoder_per_camera: bool = False
# Unet.
down_dims: tuple[int, ...] = (512, 1024, 2048)
kernel_size: int = 5
n_groups: int = 8
diffusion_step_embed_dim: int = 128
use_film_scale_modulation: bool = True
# Noise scheduler.
noise_scheduler_type: str = "DDPM"
num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001
beta_end: float = 0.02
prediction_type: str = "epsilon"
clip_sample: bool = True
clip_sample_range: float = 1.0
# Inference
num_inference_steps: int | None = None
# Loss computation
do_mask_loss_for_padding: bool = False
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-6
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 500
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:
raise ValueError(
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
)
supported_noise_schedulers = ["DDPM", "DDIM"]
if self.noise_scheduler_type not in supported_noise_schedulers:
raise ValueError(
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
f"Got {self.noise_scheduler_type}."
)
# Check that the horizon size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims)
if self.horizon % downsampling_factor != 0:
raise ValueError(
"The horizon should be an integer multiple of the downsampling factor (which is determined "
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
)
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
)
def validate_features(self) -> None:
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
)
# Check that all input images have the same shape.
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
)
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,767 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
TODO(alexander-soare):
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
"""
import math
from collections import deque
from typing import Callable
import einops
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
get_output_shape,
populate_queues,
)
class DiffusionPolicy(PreTrainedPolicy):
"""
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
(paper: https://huggingface.co/papers/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
"""
config_class = DiffusionConfig
name = "diffusion"
def __init__(
self,
config: DiffusionConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
self.diffusion = DiffusionModel(config)
self.reset()
def get_optim_params(self) -> dict:
return self.diffusion.parameters()
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the
underlying diffusion model. Here's how it works:
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
copied `n_obs_steps` times to fill the cache).
- The diffusion model generates `horizon` steps worth of actions.
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
Schematically this looks like:
----------------------------------------------------------------------------------------------
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
----------------------------------------------------------------------------------------------
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
return loss, None
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
"""
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
if name == "DDPM":
return DDPMScheduler(**kwargs)
elif name == "DDIM":
return DDIMScheduler(**kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {name}")
class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
self.config = config
# Build observation encoders (depending on which observations are provided).
global_cond_dim = self.config.robot_state_feature.shape[0]
if self.config.image_features:
num_images = len(self.config.image_features)
if self.config.use_separate_rgb_encoder_per_camera:
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
self.rgb_encoder = nn.ModuleList(encoders)
global_cond_dim += encoders[0].feature_dim * num_images
else:
self.rgb_encoder = DiffusionRgbEncoder(config)
global_cond_dim += self.rgb_encoder.feature_dim * num_images
if self.config.env_state_feature:
global_cond_dim += self.config.env_state_feature.shape[0]
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start,
beta_end=config.beta_end,
beta_schedule=config.beta_schedule,
clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range,
prediction_type=config.prediction_type,
)
if config.num_inference_steps is None:
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
else:
self.num_inference_steps = config.num_inference_steps
# ========= inference ============
def conditional_sample(
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Sample prior.
sample = torch.randn(
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
dtype=dtype,
device=device,
generator=generator,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
# Predict model output.
model_output = self.unet(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
return sample
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
global_cond_feats = [batch[OBS_STATE]]
# Extract image features.
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
img_features_list = torch.cat(
[
encoder(images)
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
]
)
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
global_cond_feats.append(img_features)
if self.config.env_state_feature:
global_cond_feats.append(batch[OBS_ENV_STATE])
# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have:
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, environment_dim)
}
"""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# run sampling
actions = self.conditional_sample(batch_size, global_cond=global_cond)
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.config.n_action_steps
actions = actions[:, start:end]
return actions
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, environment_dim)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
}
"""
# Input validation.
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
assert "observation.images" in batch or "observation.environment_state" in batch
n_obs_steps = batch["observation.state"].shape[1]
horizon = batch["action"].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# Forward diffusion.
trajectory = batch["action"]
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(trajectory.shape[0],),
device=trajectory.device,
).long()
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
# Compute the loss.
# The target is either the original trajectory, or the noise.
if self.config.prediction_type == "epsilon":
target = eps
elif self.config.prediction_type == "sample":
target = batch["action"]
else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if self.config.do_mask_loss_for_padding:
if "action_is_pad" not in batch:
raise ValueError(
"You need to provide 'action_is_pad' in the batch when "
f"{self.config.do_mask_loss_for_padding=}."
)
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
return loss.mean()
class SpatialSoftmax(nn.Module):
"""
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
(https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation.
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
-----------------------------------------------------
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
| ... | ... | ... | ... |
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
-----------------------------------------------------
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
linear mapping (in_channels, H, W) -> (num_kp, H, W).
"""
def __init__(self, input_shape, num_kp=None):
"""
Args:
input_shape (list): (C, H, W) input feature map shape.
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
"""
super().__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._out_c = num_kp
else:
self.nets = None
self._out_c = self._in_c
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
def forward(self, features: Tensor) -> Tensor:
"""
Args:
features: (B, C, H, W) input feature maps.
Returns:
(B, K, 2) image-space coordinates of keypoints.
"""
if self.nets is not None:
features = self.nets(features)
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
features = features.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(features, dim=-1)
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
expected_xy = attention @ self.pos_grid
# reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
return feature_keypoints
class DiffusionRgbEncoder(nn.Module):
"""Encodes an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first.
"""
def __init__(self, config: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if config.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
# Set up backbone.
backbone_model = getattr(torchvision.models, config.vision_backbone)(
weights=config.pretrained_backbone_weights
)
# Note: This assumes that the layer4 feature map is children()[-3]
# TODO(alexander-soare): Use a safer alternative.
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if config.use_group_norm:
if config.pretrained_backbone_weights:
raise ValueError(
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
)
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.image_features`.
# Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B, C, H, W) image tensor with pixel values in [0, 1].
Returns:
(B, D) image feature.
"""
# Preprocess: maybe crop (if it was set up in the __init__).
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
else:
# Always use center crop for eval.
x = self.center_crop(x)
# Extract backbone feature.
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
# Final linear layer with non-linearity.
x = self.relu(self.out(x))
return x
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
root_module: The module for which the submodules need to be replaced
predicate: Takes a module as an argument and must return True if the that module is to be replaced.
func: Takes a module as an argument and returns a new module to replace it with.
Returns:
The root module with its submodules replaced.
"""
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
parent_module = root_module.get_submodule(".".join(parents))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
class DiffusionSinusoidalPosEmb(nn.Module):
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class DiffusionConv1dBlock(nn.Module):
"""Conv1d --> GroupNorm --> Mish"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class DiffusionConditionalUnet1d(nn.Module):
"""A 1D convolutional UNet with FiLM modulation for conditioning.
Note: this removes local conditioning as compared to the original diffusion policy code.
"""
def __init__(self, config: DiffusionConfig, global_cond_dim: int):
super().__init__()
self.config = config
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
)
# The FiLM conditioning dimension.
cond_dim = config.diffusion_step_embed_dim + global_cond_dim
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
)
# Unet encoder.
common_res_block_kwargs = {
"cond_dim": cond_dim,
"kernel_size": config.kernel_size,
"n_groups": config.n_groups,
"use_film_scale_modulation": config.use_film_scale_modulation,
}
self.down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
self.down_modules.append(
nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
]
)
)
# Processing in the middle of the auto-encoder.
self.mid_modules = nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
),
]
)
# Unet decoder.
self.up_modules = nn.ModuleList([])
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
self.up_modules.append(
nn.ModuleList(
[
# dim_in * 2, because it takes the encoder's skip connection as well
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
]
)
)
self.final_conv = nn.Sequential(
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
)
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
"""
Args:
x: (B, T, input_dim) tensor for input to the Unet.
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
global_cond: (B, global_cond_dim)
output: (B, T, input_dim)
Returns:
(B, T, input_dim) diffusion model prediction.
"""
# For 1D convolutions we'll need feature dimension first.
x = einops.rearrange(x, "b t d -> b d t")
timesteps_embed = self.diffusion_step_encoder(timestep)
# If there is a global conditioning feature, concatenate it to the timestep embedding.
if global_cond is not None:
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
else:
global_feature = timesteps_embed
# Run encoder, keeping track of skip features to pass to the decoder.
encoder_skip_features: list[Tensor] = []
for resnet, resnet2, downsample in self.down_modules:
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
encoder_skip_features.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
# Run decoder, using the skip features from the encoder.
for resnet, resnet2, upsample in self.up_modules:
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, "b d t -> b t d")
return x
class DiffusionConditionalResidualBlock1d(nn.Module):
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
def __init__(
self,
in_channels: int,
out_channels: int,
cond_dim: int,
kernel_size: int = 3,
n_groups: int = 8,
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
# FiLM just modulates bias).
use_film_scale_modulation: bool = False,
):
super().__init__()
self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://huggingface.co/papers/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
"""
Args:
x: (B, in_channels, T)
cond: (B, cond_dim)
Returns:
(B, out_channels, T)
"""
out = self.conv1(x)
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
if self.use_film_scale_modulation:
# Treat the embedding as a list of scales and biases.
scale = cond_embed[:, : self.out_channels]
bias = cond_embed[:, self.out_channels :]
out = scale * out + bias
else:
# Treat the embedding as biases.
out = out + cond_embed
out = self.conv2(out)
out = out + self.residual_conv(x)
return out

View File

@@ -0,0 +1,178 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from torch import nn
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
def get_policy_class(name: str) -> PreTrainedPolicy:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy
elif name == "diffusion":
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy
elif name == "act":
from lerobot.policies.act.modeling_act import ACTPolicy
return ACTPolicy
elif name == "vqbet":
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy
elif name == "pi0":
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "pi0fast":
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy
return SACPolicy
elif name == "reward_classifier":
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
return Classifier
elif name == "smolvla":
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
if policy_type == "tdmpc":
return TDMPCConfig(**kwargs)
elif policy_type == "diffusion":
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy:
"""Make an instance of a policy class.
This function exists because (for now) we need to parse features from either a dataset or an environment
in order to properly dimension and instantiate a policy for that dataset or environment.
Args:
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
be loaded with the weights from that path.
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
provided if ds_meta is not. Defaults to None.
Raises:
ValueError: Either ds_meta or env and env_cfg must be provided.
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
Returns:
PreTrainedPolicy: _description_
"""
if bool(ds_meta) == bool(env_cfg):
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
# you want this op to be added in priority during the prototype phase of this feature, please comment on
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
# slower than running natively on MPS.
if cfg.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
)
policy_cls = get_policy_class(cfg.type)
kwargs = {}
if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
kwargs["dataset_stats"] = ds_meta.stats
else:
if not cfg.pretrained_path:
logging.warning(
"You are instantiating a policy from scratch and its features are parsed from an environment "
"rather than a dataset. Normalization modules inside the policy will have infinite values "
"by default without stats from a dataset."
)
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
policy = policy_cls.from_pretrained(**kwargs)
else:
# Make a fresh policy.
policy = policy_cls(**kwargs)
policy.to(cfg.device)
assert isinstance(policy, nn.Module)
# policy = torch.compile(policy, mode="reduce-overhead")
return policy

View File

@@ -0,0 +1,420 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
def create_stats_buffers(
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
statistics.
Args: (see Normalize and Unnormalize)
Returns:
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
"""
stats_buffers = {}
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
assert isinstance(norm_mode, NormalizationMode)
shape = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# sanity checks
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
c, h, w = shape
assert c < h and c < w, f"{key} is not channel first ({shape=})"
# override image shape to be invariant to height and width
shape = (c, 1, 1)
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
# we assert they are not infinity anymore.
buffer = {}
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"mean": nn.Parameter(mean, requires_grad=False),
"std": nn.Parameter(std, requires_grad=False),
}
)
elif norm_mode is NormalizationMode.MIN_MAX:
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"min": nn.Parameter(min, requires_grad=False),
"max": nn.Parameter(max, requires_grad=False),
}
)
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
stats_buffers[key] = buffer
return stats_buffers
def _no_stats_error_str(name: str) -> str:
return (
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
"pretrained model."
)
class Normalize(nn.Module):
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.features = features
self.norm_map = norm_map
self.stats = stats
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# TODO: Remove this shallow copy
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent fail!
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min + 1e-8)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(norm_mode)
return batch
class Unnormalize(nn.Module):
"""
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
original range used by the environment.
"""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.features = features
self.norm_map = norm_map
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(norm_mode)
return batch
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
module: nn.Module,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> None:
"""Register statistics buffers (mean/std or min/max) on the given *module*.
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
but is factored out so it can be reused by both classes and stay in sync.
"""
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
shape: tuple[int, ...] = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# reduce spatial dimensions, keep channel dimension only
c, *_ = shape
shape = (c, 1, 1)
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.full(shape, torch.inf, dtype=torch.float32)
std = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
mean_data = stats[key]["mean"]
std_data = stats[key]["std"]
if isinstance(mean_data, torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
mean = mean_data.clone().to(dtype=torch.float32)
std = std_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_mean", mean)
module.register_buffer(f"{prefix}_std", std)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
min_data = stats[key]["min"]
max_data = stats[key]["max"]
if isinstance(min_data, torch.Tensor):
min_val = min_data.clone().to(dtype=torch.float32)
max_val = max_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_min", min_val)
module.register_buffer(f"{prefix}_max", max_val)
continue
raise ValueError(norm_mode)
class NormalizeBuffer(nn.Module):
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
batch[key] = batch[key] * 2 - 1
continue
raise ValueError(norm_mode)
return batch
class UnnormalizeBuffer(nn.Module):
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max_val - min_val) + min_val
continue
raise ValueError(norm_mode)
return batch

View File

@@ -0,0 +1,149 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0Config(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (224, 224)
# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Projector
proj_width: int = 1024
# Decoding
num_steps: int = 10
# Attention utils
use_cache: bool = True
attention_implementation: str = "eager" # or fa2, flex
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = False
train_state_proj: bool = True
# Training presets
optimizer_lr: float = 2.5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
# TODO: Add EMA
def __post_init__(self):
super().__post_init__()
# TODO(Steven): Validate device and amp? in all policy configs?
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
)
def validate_features(self) -> None:
# TODO: implement value error
# if not self.image_features and not self.env_state_feature:
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,82 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy
torch.backends.cudnn.benchmark = True
def main():
device = "cuda"
dataset_repo_id = "danaaubakirova/koch_test"
# model_name = "pi0_base"
# ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_torch_dir = "lerobot/pi0"
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=1,
)
batch = next(iter(dataloader))
# To device
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(device=device, dtype=torch.float32)
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, ds_meta=dataset.meta)
# policy = torch.compile(policy, mode="reduce-overhead")
warmup_iters = 10
benchmark_iters = 30
# Warmup
for _ in range(warmup_iters):
torch.cuda.synchronize()
policy.select_action(batch)
policy.reset()
torch.cuda.synchronize()
# Benchmark
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(benchmark_iters):
policy.select_action(batch)
policy.reset()
end_event.record()
# Synchronize and measure time
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
avg_time_per_iter = elapsed_time_ms / benchmark_iters
print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
if __name__ == "__main__":
with torch.inference_mode():
main()

View File

@@ -0,0 +1,131 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pickle
from pathlib import Path
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy
def display(tensor: torch.Tensor):
if tensor.dtype == torch.bool:
tensor = tensor.float()
print(f"Shape: {tensor.shape}")
print(f"Mean: {tensor.mean().item()}")
print(f"Std: {tensor.std().item()}")
print(f"Min: {tensor.min().item()}")
print(f"Max: {tensor.max().item()}")
def main():
num_motors = 14
device = "cuda"
# model_name = "pi0_aloha_towel"
model_name = "pi0_aloha_sim"
if model_name == "pi0_aloha_towel":
dataset_repo_id = "lerobot/aloha_static_towel"
else:
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
save_dir = Path(f"../openpi/data/{model_name}/save")
with open(save_dir / "example.pkl", "rb") as f:
example = pickle.load(f)
with open(save_dir / "outputs.pkl", "rb") as f:
outputs = pickle.load(f)
with open(save_dir / "noise.pkl", "rb") as f:
noise = pickle.load(f)
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
norm_stats = json.load(f)
# Override stats
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
)
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
)
# Create LeRobot batch from Jax
batch = {}
for cam_key, uint_chw_array in example["images"].items():
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
batch["observation.state"] = torch.from_numpy(example["state"])
batch["action"] = torch.from_numpy(outputs["actions"])
batch["task"] = example["prompt"]
if model_name == "pi0_aloha_towel":
del batch["observation.images.cam_low"]
elif model_name == "pi0_aloha_sim":
batch["observation.images.top"] = batch["observation.images.cam_high"]
del batch["observation.images.cam_high"]
# Batchify
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].unsqueeze(0)
elif isinstance(batch[key], str):
batch[key] = [batch[key]]
else:
raise ValueError(f"{key}, {batch[key]}")
# To device
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(device=device, dtype=torch.float32)
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
from lerobot import policies # noqa
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, dataset_meta)
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
# loss_dict["loss"].backward()
# print("losses")
# display(loss_dict["losses_after_forward"])
# print("pi_losses")
# display(pi_losses)
actions = []
for _ in range(50):
action = policy.select_action(batch, noise=noise)
actions.append(action)
actions = torch.stack(actions, dim=1)
pi_actions = batch["action"]
print("actions")
display(actions)
print()
print("pi_actions")
display(pi_actions)
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,84 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import GemmaConfig, PaliGemmaConfig
def get_paligemma_config(precision: str):
config = {
"image_token_index": None,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
}
# image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
image_size = 224 # image_sizes[variant]
patch_size = 14
num_image_tokens = (image_size**2) // (patch_size**2)
config["image_token_index"] = 257152
text_config = {
"vocab_size": 257152,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"head_dim": 256,
"torch_dtype": precision,
"hidden_size": 2048,
"hidden_activation": "gelu_pytorch_tanh",
"num_attention_heads": 8,
"intermediate_size": 16384,
"is_encoder_decoder": False,
}
vision_config = {
"torch_dtype": precision,
"image_size": image_size,
"patch_size": patch_size,
"num_image_tokens": num_image_tokens,
"hidden_size": 1152,
"intermediate_size": 4304,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"projector_hidden_act": "gelu_fast",
"vision_use_head": False,
}
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
return final_config
def get_gemma_config(precision: str):
config = {
"image_token_index": None,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
}
config["image_token_index"] = 257152
text_config = {
"vocab_size": 257152,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"head_dim": 256,
"torch_dtype": precision,
"hidden_size": 1024,
"hidden_activation": "gelu_pytorch_tanh",
"num_attention_heads": 8,
"intermediate_size": 4096,
"is_encoder_decoder": False,
}
final_config = GemmaConfig()
final_config.update(text_config)
return final_config

View File

@@ -0,0 +1,437 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert pi0 parameters from Jax to Pytorch
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
and install the required libraries.
```bash
cd ~/code/openpi
source .venv/bin/activate
```
Example downloading parameters:
```bash
python
>>> import openpi.shared.download as download
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
>>> download.maybe_download(path)
```
Converting pi0_base:
```python
python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
```
```python
python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
```
"""
import argparse
import pathlib
import jax
import numpy as np
import orbax.checkpoint as ocp
import torch
from jax.sharding import SingleDeviceSharding
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.conversion_scripts.conversion_utils import (
get_gemma_config,
get_paligemma_config,
)
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
def slice_paligemma_state_dict(state_dict, config):
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
# fmt: off
# patch embeddings
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
3, 2, 0, 1
)
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
# positional embeddings
state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
-1, config.vision_config.hidden_size
)
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
for i in range(config.vision_config.num_hidden_layers):
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
# multimodal projector
state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
# text decoder (gemma)
embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
# TODO verify correctness of layer norm loading
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
for i in range(config.text_config.num_hidden_layers):
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
# output projection.
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
# mlp layers
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
# fmt: on
expert_dict = {}
final_state_dict = {}
for key, value in state_dict.items():
if key not in [
f"llm/final_norm_1/scale{suffix}",
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
f"llm/layers/attn/kv_einsum_1/w{suffix}",
f"llm/layers/attn/q_einsum_1/w{suffix}",
f"llm/layers/mlp_1/gating_einsum{suffix}",
f"llm/layers/mlp_1/linear{suffix}",
f"llm/layers/pre_attention_norm_1/scale{suffix}",
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
]:
final_state_dict[key] = torch.from_numpy(value)
else:
expert_dict[key] = value
return final_state_dict, expert_dict
def slice_gemma_state_dict(state_dict, config, num_expert=1):
# fmt: off
# text decoder (gemma)
# no embedding vector, the expert just has the decoder layers
embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
# TODO verify correctness of layer norm loading
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
for i in range(config.num_hidden_layers):
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
# output projection.
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
# mlp layers
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
# fmt: on
final_state_dict = {}
for key, value in state_dict.items():
if not isinstance(value, torch.Tensor):
final_state_dict[key] = torch.from_numpy(value)
else:
final_state_dict[key] = value
return final_state_dict
def flatten_for_memory(tree, parent_key=""):
out = {}
for k, v in tree.items():
new_key = f"{parent_key}/{k}" if parent_key else k
if isinstance(v, dict):
out.update(flatten_for_memory(v, new_key))
else:
out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
return out
def flatten_for_npz(tree, parent_key=""):
out = {}
for k, v in tree.items():
new_key = f"{parent_key}/{k}" if parent_key else k
if isinstance(v, dict):
out.update(flatten_for_npz(v, new_key))
else:
# bf16/f32 here?
out[new_key] = np.array(v)
return out
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
params_path = pathlib.Path(checkpoint_dir).resolve()
checkpointer = ocp.PyTreeCheckpointer()
metadata = checkpointer.metadata(params_path)
print("Metadata keys:", list(metadata.keys()))
params_name = "params"
item = {params_name: metadata[params_name]}
device = jax.local_devices()[0] # Use the first local device
sharding = SingleDeviceSharding(device)
restored = checkpointer.restore(
params_path,
ocp.args.PyTreeRestore(
item=item,
restore_args=jax.tree_util.tree_map(
lambda _: ocp.ArrayRestoreArgs(
restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
sharding=sharding,
),
item,
),
transforms={},
),
)
params = restored[params_name]
# get params for PaliGemma
pali_params = params["PaliGemma"]
del params["PaliGemma"]
pali_params_flat = flatten_for_npz(pali_params)
return {"paligemma_params": pali_params_flat, "projection_params": params}
def update_keys_with_prefix(d: dict, prefix: str) -> dict:
"""Update dictionary keys by adding a prefix."""
return {f"{prefix}{key}": value for key, value in d.items()}
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
# Break down orbax ckpts - they are in OCDBT
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
# process projection params
keys = [
"state_proj",
"action_in_proj",
"action_out_proj",
"action_time_mlp_in",
"action_time_mlp_out",
]
projection_params = {}
for key in keys:
kernel_params = initial_params["projection_params"][key]["kernel"]
bias_params = initial_params["projection_params"][key]["bias"]
if isinstance(kernel_params, dict):
weight = kernel_params["value"]
bias = bias_params["value"]
else:
weight = kernel_params
bias = bias_params
projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
# Process PaliGemma weights
paligemma_config = get_paligemma_config(precision)
paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
initial_params["paligemma_params"], paligemma_config
)
# Process Gemma weights (at this stage they are unused)
gemma_config = get_gemma_config(precision)
gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
# Instantiate model from configs
if "pi0_aloha_sim" in checkpoint_dir:
pi0_config = PI0Config(
empty_cameras=2,
adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=False,
)
elif "pi0_aloha_towel" in checkpoint_dir:
pi0_config = PI0Config(
adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=True,
)
elif "pi0_base" in checkpoint_dir:
pi0_config = PI0Config(
empty_cameras=0,
adapt_to_pi_aloha=False,
use_delta_joint_actions_aloha=False,
)
else:
raise ValueError()
# gemma_config=gemma_config, paligemma_config=paligemma_config)
pi0_model = PI0Policy(pi0_config)
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
projection_params = update_keys_with_prefix(projection_params, "model.")
# load state dict
torch_dtype = PRECISIONS[precision]
pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
pi0_model = pi0_model.to(torch_dtype)
# pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
pi0_model.save_pretrained(output_path, safe_serialization=True)
# pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
# assert that model loads properly
del pi0_model
PI0Policy.from_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_dir",
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
type=str,
help="Path to the ocdbt checkpoint",
)
parser.add_argument(
"--precision",
choices=["float32", "bfloat16", "float16"],
default="float32",
type=str,
help="Precision identifier for model conversion - should match the base checkpoint precision.",
)
# tokenizer is identical to paligemma, it appears
parser.add_argument(
"--tokenizer_hub_id",
default="google/paligemma-3b-pt-224",
type=str,
help="Hub path to the tokenizer to save",
)
parser.add_argument(
"--output_path",
required=True,
type=str,
help="Path to save converted weights to",
)
args = parser.parse_args()
convert_pi0_checkpoint(
checkpoint_dir=args.checkpoint_dir,
precision=args.precision,
tokenizer_id=args.tokenizer_hub_id,
output_path=args.output_path,
)

View File

@@ -0,0 +1,141 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F # noqa: N812
from packaging.version import Version
if Version(torch.__version__) > Version("2.5.0"):
# Ffex attention is only available from torch 2.5 onwards
from torch.nn.attention.flex_attention import (
_mask_mod_signature,
_round_up_to_multiple,
create_block_mask,
create_mask,
flex_attention,
)
# @torch.compile(dynamic=False)
def flex_attention_forward(
attention_mask: torch.Tensor,
batch_size: int,
head_dim: int,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
scaling=None,
):
"""
This is defined out of classes to make compile happy.
"""
original_dtype = query_states.dtype
num_att_heads = 8
num_key_value_heads = 1
num_key_value_groups = num_att_heads // num_key_value_heads
key_states = key_states[:, :, :, None, :]
key_states = key_states.expand(
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :]
value_states = value_states.expand(
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states = query_states.to(torch.float32)
key_states = key_states.to(torch.float32)
value_states = value_states.to(torch.float32)
causal_mask = attention_mask
if causal_mask is not None:
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
def mask_mod(b, h, q_idx, kv_idx):
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
return precomputed_mask[b][h][q_idx][kv_idx]
return mask_mod
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
block_size = 128
q_len_rounded = _round_up_to_multiple(q_len, block_size)
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
# *CRITICAL* we do need to expand here, else we get a CUDA index error
pad_q = q_len_rounded - q_len
pad_k = kv_len_rounded - kv_len
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
mask_4d = create_mask(
mod_fn=mask_mod_fn_orig,
B=b_mask,
H=h_mask,
Q_LEN=q_len_rounded,
KV_LEN=kv_len_rounded,
device=causal_mask.device,
_compile=False,
)
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
block_mask = create_block_mask(
mask_mod=mask_mod_fn_padded,
B=b_mask,
H=h_mask,
Q_LEN=q_len_rounded,
KV_LEN=kv_len_rounded,
BLOCK_SIZE=block_size,
device=causal_mask.device,
_compile=False,
)
# mask is applied inside the kernel, ideally more efficiently than score_mod.
attn_output, attention_weights = flex_attention(
query_states,
key_states,
value_states,
block_mask=block_mask,
enable_gqa=True, # because we shaped query/key states for GQA
scale=head_dim**-0.5 if scaling is None else scaling,
return_lse=True,
)
attn_output = attn_output.to(dtype=original_dtype)
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
attn_output = attn_output.reshape(
batch_size,
-1,
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
)
return attn_output

View File

@@ -0,0 +1,737 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
π0: A Vision-Language-Action Flow Model for General Robot Control
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Install pi0 extra dependencies:
```bash
pip install -e ".[pi0]"
```
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
```bash
python -m lerobot.scripts.train \
--policy.path=lerobot/pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
pretrained with VLM default parameters before pi0 finetuning:
```bash
python -m lerobot.scripts.train \
--policy.type=pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of using the pi0 pretrained model outside LeRobot training framework:
```python
policy = Pi0Policy.from_pretrained("lerobot/pi0")
```
"""
import math
from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.paligemma_with_expert import (
PaliGemmaWithExpertConfig,
PaliGemmaWithExpertModel,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.utils import get_safe_dtype
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb
def sample_beta(alpha, beta, bsize, device):
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
return gamma1 / (gamma1 + gamma2)
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
def resize_with_pad(img, width, height, pad_value=-1):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] == new_dim:
return vector
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
new_vector[..., :current_dim] = vector
return new_vector
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with pi0 which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class PI0Policy(PreTrainedPolicy):
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
config_class = PI0Config
name = "pi0"
def __init__(
self,
config: PI0Config,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FlowMatching(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def get_optim_params(self) -> dict:
return self.parameters()
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0")
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("action_is_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
# For backward pass
loss = losses.mean()
# For logging
loss_dict["l2_loss"] = loss.item()
return loss, loss_dict
def prepare_images(self, batch):
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
"""
images = []
img_masks = []
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
images.append(img)
img_masks.append(mask)
# Create image features not present in the batch
# as fully 0 padded images.
for num_empty_cameras in range(len(missing_img_keys)):
if num_empty_cameras >= self.config.empty_cameras:
break
img = torch.ones_like(img) * -1
mask = torch.zeros_like(mask)
images.append(img)
img_masks.append(mask)
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
# PaliGemma prompt has to end with a new line
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding="max_length",
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
def prepare_state(self, batch):
"""Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state
def prepare_action(self, batch):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
class PI0FlowMatching(nn.Module):
"""
π0: A Vision-Language-Action Flow Model for General Robot Control
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
┌──────────────────────────────┐
│ actions │
│ ▲ │
│ ┌┴─────┐ │
│ kv cache │Gemma │ │
│ ┌──────────►│Expert│ │
│ │ │ │ │
│ ┌┴────────┐ │x 10 │ │
│ │ │ └▲──▲──┘ │
│ │PaliGemma│ │ │ │
│ │ │ │ robot state │
│ │ │ noise │
│ └▲──▲─────┘ │
│ │ │ │
│ │ image(s) │
│ language tokens │
└──────────────────────────────┘
"""
def __init__(self, config):
super().__init__()
self.config = config
paligemma_with_export_config = PaliGemmaWithExpertConfig(
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only,
attention_implementation=self.config.attention_implementation,
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
# Projections are float32
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
self.set_requires_grad()
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
def sample_noise(self, shape, device):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for PaliGemma transformer processing.
"""
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
embs = []
pad_masks = []
att_masks = []
# TODO: remove for loop
for (
img,
img_mask,
) in zip(images, img_masks, strict=False):
img_emb = self.paligemma_with_expert.embed_image(img)
img_emb = img_emb.to(dtype=torch.bfloat16)
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
bsize, num_img_embs = img_emb.shape[:2]
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
embs.append(img_emb)
pad_masks.append(img_mask)
# Create attention masks so that image tokens attend to each other
att_masks += [0] * num_img_embs
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
# Normalize language embeddings
lang_emb_dim = lang_emb.shape[-1]
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
embs.append(lang_emb)
pad_masks.append(lang_masks)
# full attention between image and language inputs
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def embed_suffix(self, state, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
# Embed state
state_emb = self.state_proj(state)
state_emb = state_emb.to(dtype=torch.bfloat16)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
dtype = state_emb.dtype
device = state_emb.device
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1]
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
)
time_emb = time_emb.type(dtype=dtype)
# Fuse timestep + action information using an MLP
action_emb = self.action_in_proj(noisy_actions)
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
action_time_emb = F.silu(action_time_emb) # swish == silu
action_time_emb = self.action_time_mlp_out(action_time_emb)
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
fill_kv_cache=False,
)
suffix_out = suffix_out[:, -self.config.n_action_steps :]
# Original openpi code, upcast attention output
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
if noise is None:
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step
x_t += dt * v_t
time += dt
return x_t
def denoise_step(
self,
state,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=self.config.use_cache,
fill_kv_cache=False,
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.n_action_steps :]
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
return v_t

View File

@@ -0,0 +1,421 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
import torch.version
from pytest import Cache
from torch import nn
from transformers import (
AutoConfig,
GemmaForCausalLM,
PaliGemmaForConditionalGeneration,
PretrainedConfig,
PreTrainedModel,
)
from transformers.models.auto import CONFIG_MAPPING
from lerobot.policies.pi0.flex_attention import flex_attention_forward
def apply_rope(x, positions, max_wavelength=10_000):
"""
Applies RoPE positions [B, L] to x [B, L, H, D].
"""
d_half = x.shape[-1] // 2
device = x.device
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
sin = torch.sin(radians) # .to(dtype=dtype)
cos = torch.cos(radians) # .to(dtype=dtype)
x1, x2 = x.split(d_half, dim=-1)
res = torch.empty_like(x)
res[..., :d_half] = x1 * cos - x2 * sin
res[..., d_half:] = x2 * cos + x1 * sin
return res.to(dtype)
class PaliGemmaWithExpertConfig(PretrainedConfig):
model_type = "PaliGemmaWithExpertModel"
sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
def __init__(
self,
paligemma_config: dict | None = None,
gemma_expert_config: dict | None = None,
freeze_vision_encoder: bool = True,
train_expert_only: bool = True,
attention_implementation: str = "eager",
**kwargs,
):
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.attention_implementation = attention_implementation
if paligemma_config is None:
# Default config from Pi0
self.paligemma_config = CONFIG_MAPPING["paligemma"](
transformers_version="4.48.1",
_vocab_size=257152,
bos_token_id=2,
eos_token_id=1,
hidden_size=2048,
image_token_index=257152,
model_type="paligemma",
pad_token_id=0,
projection_dim=2048,
text_config={
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2048,
"intermediate_size": 16384,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_image_tokens": 256,
"num_key_value_heads": 1,
"torch_dtype": "float32",
"vocab_size": 257152,
},
vision_config={
"hidden_size": 1152,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_image_tokens": 256,
"patch_size": 14,
"projection_dim": 2048,
"projector_hidden_act": "gelu_fast",
"torch_dtype": "float32",
"vision_use_head": False,
},
)
elif isinstance(self.paligemma_config, dict):
# Override Pi0 default config for PaliGemma
if "model_type" not in gemma_expert_config:
paligemma_config["model_type"] = "paligemma"
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
self.paligemma_config = cfg_cls(**paligemma_config)
if gemma_expert_config is None:
# Default config from Pi0
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
attention_bias=False,
attention_dropout=0.0,
bos_token_id=2,
eos_token_id=1,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation="gelu_pytorch_tanh",
hidden_size=1024,
initializer_range=0.02,
intermediate_size=4096,
max_position_embeddings=8192,
model_type="gemma",
num_attention_heads=8,
num_hidden_layers=18,
num_key_value_heads=1,
pad_token_id=0,
rms_norm_eps=1e-06,
rope_theta=10000.0,
torch_dtype="float32",
transformers_version="4.48.1",
use_cache=True,
vocab_size=257152,
)
elif isinstance(self.gemma_expert_config, dict):
# Override Pi0 default config for Gemma Expert
if "model_type" not in gemma_expert_config:
gemma_expert_config["model_type"] = "gemma"
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
super().__init__(**kwargs)
def __post_init__(self):
super().__post_init__()
if self.train_expert_only and not self.freeze_vision_encoder:
raise ValueError(
"You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
)
if self.attention_implementation not in ["eager", "fa2", "flex"]:
raise ValueError(
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
)
class PaliGemmaWithExpertModel(PreTrainedModel):
config_class = PaliGemmaWithExpertConfig
def __init__(self, config: PaliGemmaWithExpertConfig):
super().__init__(config=config)
self.config = config
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
# Remove unused embed_tokens
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_like_physical_intelligence()
self.set_requires_grad()
def set_requires_grad(self):
if self.config.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for params in self.paligemma.vision_tower.parameters():
params.requires_grad = False
if self.config.train_expert_only:
self.paligemma.eval()
for params in self.paligemma.parameters():
params.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.config.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
if self.config.train_expert_only:
self.paligemma.eval()
def to_bfloat16_like_physical_intelligence(self):
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
params_to_change_dtype = [
"language_model.model.layers",
"gemma_expert.model.layers",
"vision_tower",
"multi_modal",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch.bfloat16)
def embed_image(self, image: torch.Tensor):
# Handle different transformers versions
if hasattr(self.paligemma, "get_image_features"):
return self.paligemma.get_image_features(image)
else:
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
# TODO: break down this huge forward into modules or functions
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
inputs_embeds: List[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
fill_kv_cache: Optional[bool] = None,
):
models = [self.paligemma.language_model, self.gemma_expert.model]
for hidden_states in inputs_embeds:
# TODO this is very inefficient
# dtype is always the same, batch size too (if > 1 len)
# device could be trickier in multi gpu edge cases but that's it
if hidden_states is None:
continue
batch_size = hidden_states.shape[0]
# RMSNorm
num_layers = self.paligemma.config.text_config.num_hidden_layers
head_dim = self.paligemma.config.text_config.head_dim
for layer_idx in range(num_layers):
query_states = []
key_states = []
value_states = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is None:
continue
layer = models[i].layers[layer_idx]
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer
hidden_states = layer.input_layernorm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=torch.bfloat16)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# B,L,H,D with L sequence length, H number of heads, D head dim
# concatenate on the number of embeddings/tokens
query_states = torch.cat(query_states, dim=1)
key_states = torch.cat(key_states, dim=1)
value_states = torch.cat(value_states, dim=1)
query_states = apply_rope(query_states, position_ids)
key_states = apply_rope(key_states, position_ids)
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states], dim=1
)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_output = att_output.to(dtype=torch.bfloat16)
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
outputs_embeds = []
start = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
if hidden_states is not None:
end = start + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start:end])
# TODO: first dropout (by default 0.0)
# first residual
out_emb += hidden_states
after_first_residual = out_emb.clone()
out_emb = layer.post_attention_layernorm(out_emb)
out_emb = layer.mlp(out_emb)
# TODO: second dropout (by default 0.0)
# second residual
out_emb += after_first_residual
outputs_embeds.append(out_emb)
start = end
else:
outputs_embeds.append(None)
inputs_embeds = outputs_embeds
# final norm
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is not None:
out_emb = models[i].norm(hidden_states)
outputs_embeds.append(out_emb)
else:
outputs_embeds.append(None)
return outputs_embeds, past_key_values
def get_attention_interface(self):
if self.config.attention_implementation == "fa2":
attention_interface = self.flash_attention_forward
elif self.config.attention_implementation == "flex":
attention_interface = flex_attention_forward
else:
attention_interface = self.eager_attention_forward
return attention_interface
def flash_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
raise NotImplementedError("FA2 is not implemented (yet)")
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
# query_states: batch_size, sequence_length, num_att_head, head_dim
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
# Attention here is upcasted to float32 to match the original eager implementation.
query_states = query_states.to(dtype=torch.float32)
key_states = key_states.to(dtype=torch.float32)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
att_weights *= head_dim**-0.5
big_neg = -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
# probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
# value_states: batch_size, sequence_length, num_att_heads, head_dim
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output

View File

@@ -0,0 +1,136 @@
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
@PreTrainedConfig.register_subclass("pi0fast")
@dataclass
class PI0FASTConfig(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 10
n_action_steps: int = 5
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32 # 32
max_action_dim: int = 32 # 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (224, 224)
interpolate_like_pi: bool = False
# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Projector
proj_width: int = 1024
# Decoding
max_decoding_steps: int = 256
fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
max_input_seq_len: int = 256 # 512
# Utils
use_cache: bool = True
# Frozen parameters
freeze_vision_encoder: bool = True
freeze_lm_head: bool = True
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
checkpoint_path: str = None
padding_side: str = "right"
precision: str = "bfloat16"
grad_clip_norm: float = 1
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
relaxed_action_decoding: bool = True
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
def validate_features(self) -> None:
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,982 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
[Paper](https://huggingface.co/papers/2501.09747)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
```bash
python -m lerobot.scripts.train \
--policy.path=lerobot/pi0fast_base \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of training the pi0+FAST neural network with from scratch:
```bash
python -m lerobot.scripts.train \
--policy.type=pi0fast \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of using the pi0 pretrained model outside LeRobot training framework:
```python
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
```
"""
from collections import deque
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from scipy.fft import idct
from torch import Tensor, nn
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
PRECISION = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with pi0 which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class PI0FASTPolicy(PreTrainedPolicy):
"""Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
config_class = PI0FASTConfig
name = "pi0fast"
def __init__(
self,
config: PI0FASTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def get_optim_params(self) -> dict:
return self.parameters()
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0FAST")
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
actions = self.model.generate_actions(batch)
actions = actions[:, : self.config.n_action_steps]
original_action_dim = self.config.action_feature.shape[
0
] # self.config.max_action_dim # self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss_dict = self.model.forward(batch)
return loss_dict["loss"], loss_dict
def block_causal_update_causal_mask(
attention_mask,
token_type_ids=None,
past_key_values=None,
cache_position=None,
input_tensor=None,
attn_implementation: str = "eager",
dtype: torch.dtype = "float32",
):
"""
Update the causal mask during training and generation. It can be customized to different attention masks.
"""
if attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(dtype).min
if input_tensor is None:
input_tensor = attention_mask
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
if using_static_cache or isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else cache_position[0] + sequence_length + 1
)
# Handle precomputed attention masks
if attention_mask is not None and attention_mask.dim() == 4:
return attention_mask
# Causal mask initialization
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
# Standard causal masking (triu ensures tokens can only attend to past)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
# Apply block causal mask
if token_type_ids is not None:
token_type_ids = token_type_ids.to(causal_mask.device).bool()
cumsum = torch.cumsum(token_type_ids, dim=1)
block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
# Combine causal_mask with block-wise attention mask
causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
causal_mask = causal_mask[:, None, :, :]
else:
# Apply past cache position constraint
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
else:
# Apply past cache position constraint
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits
mask_length = attention_mask.shape[-1]
# Apply padding mask
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def prepare_inputs_for_generation(
# self,
input_ids,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
pixel_values=None,
attention_mask=None,
token_type_ids=None,
use_cache=True,
num_logits_to_keep=None,
labels=None,
self=None,
**kwargs,
):
# create block causal attention
if cache_position[0] > 0 and input_ids.shape[1] > 0:
input_tensor = input_ids[:, -1:]
new_positions = (
torch.ones(
(position_ids.shape[0], input_ids.shape[1]),
dtype=position_ids.dtype,
device=position_ids.device,
).cumsum(-1)
+ position_ids[:, -1:]
)
position_ids = torch.cat([position_ids, new_positions], dim=-1)
else:
input_tensor = inputs_embeds
attention_mask = block_causal_update_causal_mask(
attention_mask=attention_mask,
past_key_values=past_key_values,
cache_position=cache_position,
input_tensor=input_tensor,
token_type_ids=token_type_ids,
dtype=self.dtype,
attn_implementation=self.config.text_config._attn_implementation,
)
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
use_cache=use_cache,
num_logits_to_keep=num_logits_to_keep,
token_type_ids=token_type_ids,
**kwargs,
)
# Position_ids in Paligemma are 1-indexed
if model_inputs.get("position_ids") is not None:
model_inputs["position_ids"] += 1
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
class PI0FAST(nn.Module):
def __init__(self, config: PI0FASTConfig):
super().__init__()
self.config = config
# TODO: move tokenizers in Policy
fast_tokenizer_path = "physical-intelligence/fast"
pi0_paligemma_path = "google/paligemma-3b-pt-224"
self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
self.fast_skip_tokens = self.config.fast_skip_tokens
self.max_input_seq_len = self.config.max_input_seq_len
self.action_horizon = self.config.chunk_size
self.action_dim = self.config.action_feature.shape[
0
] # self.config.max_action_dim # self.config.action_feature.shape[0]
precision = config.precision
torch_precision = PRECISION.get(precision, torch.float32)
self.pad_token_id = (
self.paligemma_tokenizer.pad_token_id
if hasattr(self.paligemma_tokenizer, "pad_token_id")
else self.paligemma_tokenizer.eos_token_id
)
paligemma_config = CONFIG_MAPPING["paligemma"](
transformers_version="4.48.1",
_vocab_size=257152,
bos_token_id=2,
eos_token_id=1,
hidden_size=2048,
image_token_index=257152,
model_type="paligemma",
pad_token_id=0,
projection_dim=2048,
text_config={
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2048,
"intermediate_size": 16384,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_image_tokens": 256,
"num_key_value_heads": 1,
"torch_dtype": precision,
"vocab_size": 257152,
"_attn_implementation": "eager",
},
vision_config={
"hidden_size": 1152,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_image_tokens": 256,
"patch_size": 14,
"projection_dim": 2048,
"projector_hidden_act": "gelu_pytorch_tanh",
"torch_dtype": precision,
"vision_use_head": False,
},
)
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
self.pi0_paligemma.prepare_inputs_for_generation = partial(
prepare_inputs_for_generation, self=self.pi0_paligemma
)
# change important stuff in bf16
params_to_change_dtype = [
"language_model",
"vision_tower",
"multi_modal",
]
for name, param in self.pi0_paligemma.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch_precision)
self.set_requires_grad()
self.image_keys = self.config.image_features.keys()
self.ignore_index = self.pi0_paligemma.config.ignore_index
self.padding_side = self.config.padding_side
def set_requires_grad(self):
if self.config.freeze_vision_encoder:
self.pi0_paligemma.vision_tower.eval()
for params in self.pi0_paligemma.vision_tower.parameters():
params.requires_grad = False
# To avoid unused params issue with distributed training
if self.config.freeze_lm_head:
for name, params in self.pi0_paligemma.named_parameters():
if "embed_tokens" in name: # lm heads and embedding layer are tied
params.requires_grad = False
def embed_tokens(self, tokens: torch.Tensor):
return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
def prepare_inputs_for_generation(self, *args, **kwargs):
return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
def prepare_images(self, batch):
"""Preprocess LeRobot batch into Pi0 inputs"""
images = []
img_masks = []
present_img_keys = [key for key in self.image_keys if key in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
num_empty_cameras = 0
for key in self.image_keys:
if key in present_img_keys:
img = batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(
img,
*self.config.resize_imgs_with_padding,
pad_value=0,
interpolate_like_pi=self.config.interpolate_like_pi,
)
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
else:
if num_empty_cameras >= self.config.empty_cameras:
continue
img = torch.ones_like(img) * -1
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
num_empty_cameras += 1
images.append(img)
img_masks.append(mask)
return images, img_masks
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
mins = actions.amin(dim=(1, 2), keepdim=True) # [0]
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
return out
def fast_tokenizer_wrapper(self, actions_norm):
"""
A wrapper for self.fast_tokenizer that ensures batch processing,
conversion to PyTorch tensors, and returns a dictionary without padding.
"""
batch_tokens = self.fast_tokenizer(actions_norm)
fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
return fast_out
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
# Compute cumulative sum mask
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
# Suffix block (everything after prefix_len)
suffix_mask = cumsum_mask > prefix_len
token_type_ids = suffix_mask
return token_type_ids
def create_input_tokens(self, state, lang_text, actions=None):
bsize = state.shape[0]
device = state.device
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
discretized = torch.bucketize(state, bins) - 1
discretized = discretized[:, :32]
prefix_texts = []
state_text = []
for txt, disc in zip(lang_text, discretized, strict=False):
cleaned = txt.lower().strip().replace("_", " ")
state_str = " ".join(str(val.item()) for val in disc)
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
state_text.append(f"State: {state_str};\n")
prefix_out = self.paligemma_tokenizer(
prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
)
prefix_ids = prefix_out["input_ids"].to(device)
prefix_mask = prefix_out["attention_mask"].to(device)
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
if actions is not None:
actions_norm = self.normalize_actions(actions)
actions_pad = F.pad(
actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
)[:, :, : self.config.max_action_dim]
fast_out = self.fast_tokenizer_wrapper(
actions_pad.cpu(),
)
act_ids = fast_out["input_ids"]
act_mask = fast_out["attention_mask"].to(device)
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
# Replace action with 0 to pad tokens
act_ids = torch.where(
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
self.pad_token_id,
act_ids,
)
eos_token = torch.tensor(
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
).expand(bsize, -1)
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
act_mask = act_mask.to(device)
else:
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
# Use tokenizer pad function
padded_output = self.paligemma_tokenizer.pad(
batch_inputs, padding="longest", max_length=180, return_tensors="pt"
)
padded_mask = padded_output["attention_mask"]
# define tensor of padding lengths
att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
padded_output["padded_mask"] = padded_output.pop("attention_mask")
padded_output["attention_mask"] = att_mask
# loss is computed not on prefix, and not on padding
padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
padded_output["token_type_ids"] = token_type_ids
return padded_output
def shift_padding_side(
self,
tokens: torch.Tensor,
ar_mask: torch.Tensor,
padding_mask: torch.Tensor,
loss_mask: torch.Tensor,
targets: torch.Tensor,
token_type_ids: torch.Tensor,
padding_side: str = "right",
) -> tuple[torch.Tensor]:
if padding_side not in ["right", "left"]:
return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
new_tokens = torch.empty_like(tokens)
new_ar_masks = torch.empty_like(ar_mask)
new_padding_mask = torch.empty_like(padding_mask)
new_loss_mask = torch.empty_like(loss_mask)
new_targets = torch.empty_like(targets)
new_token_type_ids = torch.empty_like(token_type_ids)
batch_size = tokens.shape[0]
for i in range(batch_size):
padding_indices = torch.where(padding_mask[i] == 0)[0]
non_padding_indices = torch.where(padding_mask[i] == 1)[0]
if padding_side == "left":
new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
else:
new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
new_tokens[i] = tokens[i].index_select(0, new_indices)
new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
new_targets[i] = targets[i].index_select(0, new_indices)
new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
def forward(self, batch: dict[str, Tensor]):
device = batch[OBS_STATE].device
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(
state=batch[OBS_STATE],
lang_text=batch["task"],
actions=batch[ACTION],
)
embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,
padded_outs["input_ids"],
padded_outs["padded_mask"],
padded_outs["attention_mask"],
padded_outs["loss_mask"],
padded_outs["token_type_ids"],
padding_side=self.padding_side,
)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
token_type_ids = token_type_ids.to(dtype=torch.int64)
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
pad_masks = block_causal_update_causal_mask(
attention_mask=pad_masks,
past_key_values=None,
cache_position=cache_position,
input_tensor=embs,
token_type_ids=token_type_ids,
dtype=self.pi0_paligemma.dtype,
attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
)
outputs = self.pi0_paligemma.forward(
input_ids=None,
token_type_ids=None,
attention_mask=pad_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=embs,
use_cache=False,
labels=None,
)
logits = outputs.logits
loss_fct = nn.CrossEntropyLoss(reduction="none")
# Shift left for next-step prediction
logits = logits[:, :-1, :]
targets = targets[:, 1:].to(device) # Shift targets
loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape
# Compute per-token loss
token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
# Apply loss mask
token_loss = token_loss * loss_mask.reshape(-1)
# Compute final loss
loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
# Return loss dictionary
loss_dict = {"ce_loss": loss.item(), "loss": loss}
return loss_dict
def decode_actions_with_fast(
self,
tokens: list[list[int]],
*,
time_horizon: int | None = None,
action_dim: int | None = None,
relaxed_decoding: bool = True,
) -> np.array:
"""
Adapt original decoding in FAST to always return actions instead of zeros.
"""
self.time_horizon = (
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
)
self.action_dim = (
action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
)
# Cache the time horizon and action dimension for the next call
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert self.time_horizon is not None and self.action_dim is not None, (
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
)
decoded_actions = []
for token in tokens:
try:
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
if relaxed_decoding:
# Expected sequence length
expected_seq_len = self.time_horizon * self.action_dim
diff = expected_seq_len - decoded_dct_coeff.shape[0]
# Apply truncation if too long
if diff < 0:
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
# Apply padding if too short
elif diff > 0:
decoded_dct_coeff = np.pad(
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
)
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert decoded_dct_coeff.shape == (
self.time_horizon,
self.action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
"""
Extracts actions from predicted output tokens using the FAST model.
Args:
tokens (torch.Tensor): The input tensor of tokenized outputs.
action_horizon (int): The number of timesteps for actions.
action_dim (int): The dimensionality of each action.
Returns:
torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
"""
# Decode predicted output tokens
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
cleaned_tokens = [
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
for tokens_sequence in decoded_tokens
]
raw_action_tokens = [
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
for sample_tokens in cleaned_tokens
] # something like this should be robust #looks good
action_tokens = [
self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
]
# returns the tensor of decoded actions per sample in a list
decoded_actions = [
torch.tensor(
self.decode_actions_with_fast(
tok.tolist(),
time_horizon=action_horizon,
action_dim=action_dim,
relaxed_decoding=self.config.relaxed_action_decoding,
),
device=tokens.device,
).squeeze(0)
for tok in action_tokens
]
return torch.stack(
decoded_actions,
dim=0,
)
def generate_actions(self, batch: dict[str, Tensor]):
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None)
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,
padded_outs["input_ids"],
padded_outs["padded_mask"],
padded_outs["attention_mask"],
padded_outs["loss_mask"],
padded_outs["token_type_ids"],
padding_side="left",
)
token_type_ids = token_type_ids.to(dtype=torch.int64)
prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
output_tokens = self.pi0_paligemma.generate(
input_ids=None,
attention_mask=pad_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=embs,
use_cache=self.config.use_cache,
max_new_tokens=self.config.max_decoding_steps,
do_sample=False,
num_beams=1,
token_type_ids=token_type_ids,
)
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
return actions
def embed_image(self, image: torch.Tensor):
# Handle different transformers versions
if hasattr(self.pi0_paligemma, "get_image_features"):
return self.pi0_paligemma.get_image_features(image)
else:
return self.pi0_paligemma.model.get_image_features(image)
def embed_inputs(
self,
images,
img_masks,
tokens,
pad_mask,
ar_mask,
loss_mask,
token_type_ids,
padding_side: str = "right",
):
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
# images are a list of same size
# vectorizing everything!
device = images[0].device
image_embedding_dim = images[0].shape[-1] # TODO should be from self.config
all_images = torch.stack(images, dim=1).to(device)
b, n, c, h, w = all_images.shape
all_images = all_images.view(b * n, c, h, w)
embedded = self.embed_image(all_images).to(device)
b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions
m = b_n // b # Compute the number of images per sample dynamically
# Reshape dynamically
embedded = embedded.view(b, m, p, image_embedding_dim)
tokens_embs = self.embed_tokens(tokens.to(device))
img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
num_img_emb = embedded.shape[2]
img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
image_target_tokens = (
torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
).reshape(b, -1)
image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D)
embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
# Shift pad tokens to the left (.generate()) or right (.train())
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
)
targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
if interpolate_like_pi:
img = (img * 255.0).to(dtype=torch.uint8)
img = img.permute(0, 2, 3, 1)
original_device = img.device
img = img.to(device="cpu").numpy()
imgs = []
for sub_img in img:
sub_img = Image.fromarray(sub_img)
resized_img = sub_img.resize((resized_width, resized_height), resample=2)
resized_img = torch.from_numpy(np.array(resized_img))
imgs.append(resized_img)
img = torch.stack(imgs, dim=0)
img = img.permute(0, 3, 1, 2)
resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
else:
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img

View File

@@ -0,0 +1,243 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Type, TypeVar
import packaging
import safetensors
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="PreTrainedPolicy")
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
Base class for policy models.
"""
config_class: None
name: None
def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
super().__init__()
if not isinstance(config, PreTrainedConfig):
raise ValueError(
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
"`PreTrainedConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.config = config
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not getattr(cls, "config_class", None):
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'")
def _save_pretrained(self, save_directory: Path) -> None:
self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
@classmethod
def from_pretrained(
cls: Type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = False,
**kwargs,
) -> T:
"""
The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are
deactivated). To train it, you should first set it back in training mode with `policy.train()`.
"""
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
model_id = str(pretrained_name_or_path)
instance = cls(config, **kwargs)
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
else:
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
policy.to(config.device)
policy.eval()
return policy
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
load_model_as_safetensor(model, model_file, strict=strict)
if map_location != "cpu":
logging.warning(
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
" This means that the model is loaded on 'cpu' first and then copied to the device."
" This leads to a slower loading time."
" Please update safetensors to version 0.4.3 or above for improved performance."
)
model.to(map_location)
else:
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return model
@abc.abstractmethod
def get_optim_params(self) -> dict:
"""
Returns the policy-specific parameters dict to be passed on to the optimizer.
"""
raise NotImplementedError
@abc.abstractmethod
def reset(self):
"""To be called whenever the environment is reset.
Does things like clearing caches.
"""
raise NotImplementedError
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
@abc.abstractmethod
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""_summary_
Args:
batch (dict[str, Tensor]): _description_
Returns:
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
is a Tensor, all other items should be logging-friendly, native Python types.
"""
raise NotImplementedError
@abc.abstractmethod
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
Child classes using action chunking should use this method within `select_action` to form the action chunk
cached for selection.
"""
raise NotImplementedError
@abc.abstractmethod
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
raise NotImplementedError
def push_model_to_hub(
self,
cfg: TrainPipelineConfig,
):
api = HfApi()
repo_id = api.create_repo(
repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
).repo_id
# Push the files to the repo in a single commit
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
saved_path = Path(tmp) / repo_id
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
)
card.save(str(saved_path / "README.md"))
cfg.save_pretrained(saved_path) # Calls _save_pretrained and stores train config
commit_info = api.upload_folder(
repo_id=repo_id,
repo_type="model",
folder_path=saved_path,
commit_message="Upload policy weights, train config and readme",
allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.md"],
ignore_patterns=["*.tmp", "*.log"],
)
logging.info(f"Model pushed to {commit_info.repo_url.url}")
def generate_model_card(
self, dataset_repo_id: str, model_type: str, license: str | None, tags: List[str] | None
) -> ModelCard:
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
card_data = ModelCardData(
license=license or "apache-2.0",
library_name="lerobot",
pipeline_tag="robotics",
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
model_name=model_type,
datasets=dataset_repo_id,
base_model=base_model,
)
template_card = files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text()
card = ModelCard.from_template(card_data, template_str=template_card)
card.validate()
return card

View File

@@ -0,0 +1,245 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.optim.optimizers import MultiAdamConfig
def is_image_feature(key: str) -> bool:
"""Check if a feature key represents an image feature.
Args:
key: The feature key to check
Returns:
True if the key represents an image feature, False otherwise
"""
return key.startswith(OBS_IMAGE)
@dataclass
class ConcurrencyConfig:
"""Configuration for the concurrency of the actor and learner.
Possible values are:
- "threads": Use threads for the actor and learner.
- "processes": Use processes for the actor and learner.
"""
actor: str = "threads"
learner: str = "threads"
@dataclass
class ActorLearnerConfig:
learner_host: str = "127.0.0.1"
learner_port: int = 50051
policy_parameters_push_frequency: int = 4
queue_get_timeout: float = 2
@dataclass
class CriticNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
final_activation: str | None = None
@dataclass
class ActorNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
@dataclass
class PolicyConfig:
use_tanh_squash: bool = True
std_min: float = 1e-5
std_max: float = 10.0
init_final: float = 0.05
@PreTrainedConfig.register_subclass("sac")
@dataclass
class SACConfig(PreTrainedConfig):
"""Soft Actor-Critic (SAC) configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
reinforcement learning framework. It learns a policy and a Q-function simultaneously
using experience collected from the environment.
This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific
hyperparameters.
"""
# Mapping of feature types to normalization modes
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ENV": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Statistics for normalizing different types of inputs
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
OBS_STATE: {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
ACTION: {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
}
)
# Architecture specifics
# Device to run the model on (e.g., "cuda", "cpu")
device: str = "cpu"
# Device to store the model on
storage_device: str = "cpu"
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
vision_encoder_name: str | None = None
# Whether to freeze the vision encoder during training
freeze_vision_encoder: bool = True
# Hidden dimension size for the image encoder
image_encoder_hidden_dim: int = 32
# Whether to use a shared encoder for actor and critic
shared_encoder: bool = True
# Number of discrete actions, eg for gripper actions
num_discrete_actions: int | None = None
# Dimension of the image embedding pooling
image_embedding_pooling_dim: int = 8
# Training parameter
# Number of steps for online training
online_steps: int = 1000000
# Seed for the online environment
online_env_seed: int = 10000
# Capacity of the online replay buffer
online_buffer_capacity: int = 100000
# Capacity of the offline replay buffer
offline_buffer_capacity: int = 100000
# Whether to use asynchronous prefetching for the buffers
async_prefetch: bool = False
# Number of steps before learning starts
online_step_before_learning: int = 100
# Frequency of policy updates
policy_update_freq: int = 1
# SAC algorithm parameters
# Discount factor for the SAC algorithm
discount: float = 0.99
# Initial temperature value
temperature_init: float = 1.0
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Weight for the critic target update
critic_target_update_weight: float = 0.005
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
utd_ratio: int = 1
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256
# Target entropy for the SAC algorithm
target_entropy: float | None = None
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Network configuration
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for actor-learner architecture
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations
use_torch_compile: bool = True
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration
def get_optimizer_preset(self) -> MultiAdamConfig:
return MultiAdamConfig(
weight_decay=0.0,
optimizer_groups={
"actor": {"lr": self.actor_lr},
"critic": {"lr": self.critic_lr},
"temperature": {"lr": self.temperature_lr},
},
)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
has_image = any(is_image_feature(key) for key in self.input_features)
has_state = OBS_STATE in self.input_features
if not (has_state or has_image):
raise ValueError(
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
)
if "action" not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
def image_features(self) -> list[str]:
return [key for key in self.input_features if is_image_feature(key)]
@property
def observation_delta_indices(self) -> list:
return None
@property
def action_delta_indices(self) -> list:
return None # SAC typically predicts one action at a time
@property
def reward_delta_indices(self) -> None:
return None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,76 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
@PreTrainedConfig.register_subclass(name="reward_classifier")
@dataclass
class RewardClassifierConfig(PreTrainedConfig):
"""Configuration for the Reward Classifier model."""
name: str = "reward_classifier"
num_classes: int = 2
hidden_dim: int = 256
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
learning_rate: float = 1e-4
weight_decay: float = 0.01
grad_clip_norm: float = 1.0
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
}
)
@property
def observation_delta_indices(self) -> list | None:
return None
@property
def action_delta_indices(self) -> list | None:
return None
@property
def reward_delta_indices(self) -> list | None:
return None
def get_optimizer_preset(self) -> OptimizerConfig:
return AdamWConfig(
lr=self.learning_rate,
weight_decay=self.weight_decay,
grad_clip_norm=self.grad_clip_norm,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return None
def validate_features(self) -> None:
"""Validate feature configurations."""
has_image = any(key.startswith("observation.image") for key in self.input_features)
if not has_image:
raise ValueError(
"You must provide an image observation (key starting with 'observation.image') in the input features"
)

View File

@@ -0,0 +1,323 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
from torch import Tensor, nn
from lerobot.constants import OBS_IMAGE, REWARD
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self,
logits: Tensor,
probabilities: Tensor | None = None,
hidden_states: Tensor | None = None,
):
self.logits = logits
self.probabilities = probabilities
self.hidden_states = hidden_states
def __repr__(self):
return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})"
)
class SpatialLearnedEmbeddings(nn.Module):
def __init__(self, height, width, channel, num_features=8):
"""
PyTorch implementation of learned spatial embeddings
Args:
height: Spatial height of input features
width: Spatial width of input features
channel: Number of input channels
num_features: Number of output embedding dimensions
"""
super().__init__()
self.height = height
self.width = width
self.channel = channel
self.num_features = num_features
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
def forward(self, features):
"""
Forward pass for spatial embedding
Args:
features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch
Returns:
Output tensor of shape [B, C*F] or [C*F] if no batch
"""
features = features.last_hidden_state
original_shape = features.shape
if features.dim() == 3:
features = features.unsqueeze(0) # Add batch dim
features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1]
kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F]
# Element-wise multiplication and spatial reduction
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W
# Reshape to combine channel and feature dimensions
output = output.view(output.size(0), -1) # [B, C*F]
# Remove batch dim
if len(original_shape) == 3:
output = output.squeeze(0)
return output
class Classifier(PreTrainedPolicy):
"""Image classifier built on top of a pre-trained encoder."""
name = "reward_classifier"
config_class = RewardClassifierConfig
def __init__(
self,
config: RewardClassifierConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
from transformers import AutoModel
super().__init__(config)
self.config = config
# Initialize normalization (standardized with the policy framework)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# Set up encoder
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
self.encoder = encoder.vision_model
self.vision_config = encoder.config.vision_config
else:
self.encoder = encoder
self.vision_config = getattr(encoder, "config", None)
# Model type from config
self.is_cnn = self.config.model_type == "cnn"
# For CNNs, initialize backbone
if self.is_cnn:
self._setup_cnn_backbone()
self._freeze_encoder()
# Extract image keys from input_features
self.image_keys = [
key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE)
]
if self.is_cnn:
self.encoders = nn.ModuleDict()
for image_key in self.image_keys:
encoder = self._create_single_encoder()
self.encoders[image_key] = encoder
self._build_classifier_head()
def _setup_cnn_backbone(self):
"""Set up CNN encoder"""
if hasattr(self.encoder, "fc"):
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
for param in self.encoder.parameters():
param.requires_grad = False
def _create_single_encoder(self):
encoder = nn.Sequential(
self.encoder,
SpatialLearnedEmbeddings(
height=4,
width=4,
channel=self.feature_dim,
num_features=self.config.image_embedding_pooling_dim,
),
nn.Dropout(self.config.dropout_rate),
nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
nn.Tanh(),
)
return encoder
def _build_classifier_head(self) -> None:
"""Initialize the classifier head architecture."""
# Get input dimension based on model type
if self.is_cnn:
input_dim = self.config.latent_dim
else: # Transformer models
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(
self.config.hidden_dim,
1 if self.config.num_classes == 2 else self.config.num_classes,
),
)
def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""
with torch.no_grad():
if self.is_cnn:
# The HF ResNet applies pooling internally
outputs = self.encoders[image_key](x)
return outputs
else: # Transformer models
outputs = self.encoder(x)
return outputs.last_hidden_state[:, 0, :]
def extract_images_and_labels(self, batch: dict[str, Tensor]) -> tuple[list, Tensor]:
"""Extract image tensors and label tensors from batch."""
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
labels = batch[REWARD]
return images, labels
def predict(self, xs: list) -> ClassifierOutput:
"""Forward pass of the classifier for inference."""
encoder_outputs = torch.hstack(
[self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)]
)
logits = self.classifier_head(encoder_outputs)
if self.config.num_classes == 2:
logits = logits.squeeze(-1)
probabilities = torch.sigmoid(logits)
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
"""Standard forward pass for training compatible with train.py."""
# Normalize inputs if needed
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images and labels
images, labels = self.extract_images_and_labels(batch)
# Get predictions
outputs = self.predict(images)
# Calculate loss
if self.config.num_classes == 2:
# Binary classification
loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
else:
# Multi-class classification
loss = nn.functional.cross_entropy(outputs.logits, labels.long())
predictions = torch.argmax(outputs.logits, dim=1)
# Calculate accuracy for logging
correct = (predictions == labels).sum().item()
total = labels.size(0)
accuracy = 100 * correct / total
# Return loss and metrics for logging
output_dict = {
"accuracy": accuracy,
"correct": correct,
"total": total,
}
return loss, output_dict
def predict_reward(self, batch, threshold=0.5):
"""Eval method. Returns predicted reward with the decision threshold as argument."""
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images from batch dict
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
if self.config.num_classes == 2:
probs = self.predict(images).probabilities
logging.debug(f"Predicted reward images: {probs}")
return (probs > threshold).float()
else:
return torch.argmax(self.predict(images).probabilities, dim=1)
def get_optim_params(self):
"""Return optimizer parameters for the policy."""
return self.parameters()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not select actions.
"""
raise NotImplementedError("Reward classifiers do not select actions")
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not produce action chunks.
"""
raise NotImplementedError("Reward classifiers do not predict action chunks")
def reset(self):
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not select actions.
"""
pass

View File

@@ -0,0 +1,154 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
@PreTrainedConfig.register_subclass("smolvla")
@dataclass
class SmolVLAConfig(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (512, 512)
# Add empty images. Used by smolvla_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Decoding
num_steps: int = 10
# Attention utils
use_cache: bool = True
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = True
train_state_proj: bool = True
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
attention_mode: str = "cross_attn"
prefix_length: int = -1
pad_language_to: str = "longest" # "max_length"
num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers.
num_vlm_layers: int = 16 # Number of layers used in the VLM (first num_vlm_layers layers)
self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers
expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM)
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
)
def validate_features(self) -> None:
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list:
return [0]
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,940 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SmolVLA:
[Paper](https://huggingface.co/papers/2506.01844)
Designed by Hugging Face.
Install smolvla extra dependencies:
```bash
pip install -e ".[smolvla]"
```
Example of finetuning the smolvla pretrained model (`smolvla_base`):
```bash
python -m lerobot.scripts.train \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
and an action expert.
```bash
python -m lerobot.scripts.train \
--policy.type=smolvla \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
Example of using the smolvla pretrained model outside LeRobot training framework:
```python
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
```
"""
import math
import os
import re
from collections import deque
import safetensors
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoProcessor
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import (
Normalize,
Unnormalize,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
populate_queues,
)
from lerobot.utils.utils import get_safe_dtype
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
def canonicalise(k: str) -> str:
"""
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
normalisation-buffer key.
"""
return _VARIANT_RE.sub(".buffer_", k)
def standardise_state_dict(
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]:
"""
• Re-keys `checkpoint ` so that every entry matches the *reference* key set.
• If several variant keys collapse to the same canonical name we keep the
first one and log the collision.
• Returns the new dict + a list of entries that could not be matched.
"""
out, collisions, unmatched = {}, {}, []
for k, v in checkpoint.items():
canon = canonicalise(k)
if canon in ref_keys:
if canon in out: # duplicate after collapsing
collisions.setdefault(canon, []).append(k)
else:
out[canon] = v
else:
unmatched.append(k)
if verbose:
for canon, variants in collisions.items():
print(f"[standardise_state_dict] '{canon}'{variants}")
if unmatched:
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
out.update({k: checkpoint[k] for k in unmatched})
return out, unmatched
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
"""
Renames keys in a checkpoint dictionary based on the given rename string.
Args:
checkpoint (dict): The checkpoint dictionary.
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
Returns:
dict: The modified checkpoint with renamed keys.
"""
rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
new_checkpoint = {}
for k, v in checkpoint.items():
for old_key, new_key in rename_dict.items():
if old_key in k:
k = k.replace(old_key, new_key)
new_checkpoint[k] = v
return new_checkpoint
def load_smolvla(
model: torch.nn.Module,
filename: str | os.PathLike,
*,
device: str = "cpu",
checkpoint_keys_mapping: str = "",
) -> torch.nn.Module:
state_dict = safetensors.torch.load_file(filename, device=device)
# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
raise RuntimeError(
"SmolVLA %d missing / %d unexpected keys",
len(missing),
len(unexpected),
)
return model
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb
def sample_beta(alpha, beta, bsize, device):
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
return gamma1 / (gamma1 + gamma2)
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
def resize_with_pad(img, width, height, pad_value=-1):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] == new_dim:
return vector
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
new_vector[..., :current_dim] = vector
return new_vector
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with smolvla which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by smolvla to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class SmolVLAPolicy(PreTrainedPolicy):
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
config_class = SmolVLAConfig
name = "smolvla"
def __init__(
self,
config: SmolVLAConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
self.model = VLAFlowMatching(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps),
}
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
@classmethod
def _load_as_safetensor(
cls,
model: "SmolVLAPolicy",
model_file: str,
map_location: str,
strict: bool,
):
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return load_smolvla(
model,
model_file,
device=map_location,
checkpoint_keys_mapping="model._orig_mod.//model.",
)
def get_optim_params(self) -> dict:
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
return actions
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
return batch
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
actions = self._get_action_chunk(batch, noise)
return actions
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._queues[ACTION]) == 0:
actions = self._get_action_chunk(batch, noise)
# `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
# For backward pass
loss = losses.mean()
# For backward pass
loss_dict["loss"] = loss.item()
return loss, loss_dict
def prepare_images(self, batch):
"""Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
"""
images = []
img_masks = []
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
if f"{key}_padding_mask" in batch:
mask = batch[f"{key}_padding_mask"].bool()
else:
mask = torch.ones(bsize, dtype=torch.bool, device=device)
images.append(img)
img_masks.append(mask)
# Create image features not present in the batch
# as fully 0 padded images.
for num_empty_cameras in range(len(missing_img_keys)):
if num_empty_cameras >= self.config.empty_cameras:
break
img = torch.ones_like(img) * -1
mask = torch.zeros_like(mask)
images.append(img)
img_masks.append(mask)
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]
if len(tasks) == 1:
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding=self.config.pad_language_to,
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
def prepare_state(self, batch):
"""Pad state"""
state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE]
state = pad_vector(state, self.config.max_state_dim)
return state
def prepare_action(self, batch):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
def pad_tensor(tensor, max_len, pad_value=0):
"""
Efficiently pads a tensor along sequence dimension to match max_len.
Args:
tensor (torch.Tensor): Shape (B, L, ...) or (B, L).
max_len (int): Fixed sequence length.
pad_value (int/float): Value for padding.
Returns:
torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
"""
b, d = tensor.shape[:2]
# Create a padded tensor of max_len and copy the existing values
padded_tensor = torch.full(
(b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, :d] = tensor # Efficient in-place copy
return padded_tensor
class VLAFlowMatching(nn.Module):
"""
SmolVLA
[Paper]()
Designed by Hugging Face.
┌──────────────────────────────┐
│ actions │
│ ▲ │
│ ┌─────────┐ ┌─|────┐ │
│ | │────► │ │ │
│ | │ kv │ │ │
│ | │────► │Action│ │
│ | VLM │cache │Expert│ |
│ │ │────► | │ │
│ │ │ │ │ │
│ └▲──▲───▲─┘ └───▲──┘ |
│ │ | | │ |
│ | | | noise │
│ │ │ state │
│ │ language tokens │
│ image(s) │
└──────────────────────────────┘
"""
def __init__(self, config):
super().__init__()
self.config = config
self.vlm_with_expert = SmolVLMWithExpertModel(
model_id=self.config.vlm_model_name,
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only,
load_vlm_weights=self.config.load_vlm_weights,
attention_mode=self.config.attention_mode,
num_expert_layers=self.config.num_expert_layers,
num_vlm_layers=self.config.num_vlm_layers,
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
self.action_time_mlp_in = nn.Linear(
self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
)
self.action_time_mlp_out = nn.Linear(
self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
)
self.set_requires_grad()
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
self.global_image_start_token = torch.tensor(
[self.fake_image_token, self.global_image_token], dtype=torch.long
)
self.add_image_special_tokens = self.config.add_image_special_tokens
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
self.prefix_length = self.config.prefix_length
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
def sample_noise(self, shape, device):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for SmolVLM transformer processing.
"""
embs = []
pad_masks = []
att_masks = []
for _img_idx, (
img,
img_mask,
) in enumerate(zip(images, img_masks, strict=False)):
if self.add_image_special_tokens:
image_start_token = (
self.vlm_with_expert.embed_language_tokens(
self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
image_start_mask = torch.ones_like(
image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
)
att_masks += [0] * (image_start_mask.shape[-1])
embs.append(image_start_token)
pad_masks.append(image_start_mask)
img_emb = self.vlm_with_expert.embed_image(img)
img_emb = img_emb
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
bsize, num_img_embs = img_emb.shape[:2]
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
embs.append(img_emb)
pad_masks.append(img_mask)
att_masks += [0] * (num_img_embs)
if self.add_image_special_tokens:
image_end_token = (
self.vlm_with_expert.embed_language_tokens(
self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
image_end_mask = torch.ones_like(
image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
)
embs.append(image_end_token)
pad_masks.append(image_end_mask)
att_masks += [0] * (image_end_mask.shape[1])
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
# Normalize language embeddings
lang_emb_dim = lang_emb.shape[-1]
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
embs.append(lang_emb)
pad_masks.append(lang_masks)
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
state_emb = self.state_proj(state)
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
embs.append(state_emb)
bsize = state_emb.shape[0]
device = state_emb.device
states_seq_len = state_emb.shape[1]
state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1] * (states_seq_len)
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
att_masks = att_masks[None, :]
seq_len = pad_masks.shape[1]
if seq_len < self.prefix_length:
embs = pad_tensor(embs, self.prefix_length, pad_value=0)
pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0)
att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0)
att_masks = att_masks.expand(bsize, -1)
return embs, pad_masks, att_masks
def embed_suffix(self, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
# Fuse timestep + action information using an MLP
action_emb = self.action_in_proj(noisy_actions)
device = action_emb.device
bsize = action_emb.shape[0]
dtype = action_emb.dtype
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep,
self.vlm_with_expert.expert_hidden_size,
self.config.min_period,
self.config.max_period,
device=device,
)
time_emb = time_emb.type(dtype=dtype)
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
action_time_emb = F.silu(action_time_emb) # swish == silu
action_time_emb = self.action_time_mlp_out(action_time_emb)
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] * self.config.chunk_size
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
(_, suffix_out), _ = self.vlm_with_expert.forward(
attention_mask=att_2d_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
fill_kv_cache=False,
)
suffix_out = suffix_out[:, -self.config.chunk_size :]
# Original openpi code, upcast attention output
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
if noise is None:
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
_, past_key_values = self.vlm_with_expert.forward(
attention_mask=prefix_att_2d_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step
x_t += dt * v_t
time += dt
return x_t
def denoise_step(
self,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
outputs_embeds, _ = self.vlm_with_expert.forward(
attention_mask=full_att_2d_masks,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=self.config.use_cache,
fill_kv_cache=False,
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
return v_t

View File

@@ -0,0 +1,550 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import List, Optional
import torch
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForImageTextToText,
AutoProcessor,
SmolVLMForConditionalGeneration,
)
def apply_rope(x, positions, max_wavelength=10_000):
"""
Applies RoPE positions [B, L] to x [B, L, H, D].
"""
d_half = x.shape[-1] // 2
device = x.device
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
sin = torch.sin(radians) # .to(dtype=dtype)
cos = torch.cos(radians) # .to(dtype=dtype)
x1, x2 = x.split(d_half, dim=-1)
res = torch.empty_like(x)
res[..., :d_half] = x1 * cos - x2 * sin
res[..., d_half:] = x2 * cos + x1 * sin
return res.to(dtype)
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
return hidden_dim
class SmolVLMWithExpertModel(nn.Module):
def __init__(
self,
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
load_vlm_weights: bool = True,
train_expert_only: bool = True,
freeze_vision_encoder: bool = False,
attention_mode: str = "self_attn",
num_expert_layers: int = -1,
num_vlm_layers: int = -1,
self_attn_every_n_layers: int = -1,
expert_width_multiplier: float = 0.5,
):
super().__init__()
if load_vlm_weights:
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map="auto",
torch_dtype="bfloat16",
low_cpu_mem_usage=True,
)
config = self.vlm.config
else:
config = AutoConfig.from_pretrained(model_id)
self.vlm = SmolVLMForConditionalGeneration(config=config)
self.processor = AutoProcessor.from_pretrained(model_id)
if num_vlm_layers > 0:
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
self.config = config
# Smaller lm expert
lm_expert_config = copy.deepcopy(config.text_config)
hidden_size = lm_expert_config.hidden_size
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
lm_expert_config.num_hidden_layers = self.num_vlm_layers
if num_expert_layers > 0:
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
)
lm_expert_config.num_hidden_layers = num_expert_layers
self.lm_expert = AutoModel.from_config(lm_expert_config)
self.num_expert_layers = len(self.lm_expert.layers)
self.self_attn_every_n_layers = self_attn_every_n_layers
if "cross" in attention_mode:
# Reshape qkv projections to have the same input dimension as the vlm
for layer_idx in range(len(self.lm_expert.layers)):
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
continue
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
# Remove unused embed_tokens
self.lm_expert.embed_tokens = None
self.num_attention_heads = self.config.text_config.num_attention_heads
self.num_key_value_heads = self.config.text_config.num_key_value_heads
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.attention_mode = attention_mode
self.expert_hidden_size = lm_expert_config.hidden_size
self.set_requires_grad()
def get_vlm_model(self):
return self.vlm.model
def set_requires_grad(self):
if self.freeze_vision_encoder:
self.get_vlm_model().vision_model.eval()
for params in self.get_vlm_model().vision_model.parameters():
params.requires_grad = False
if self.train_expert_only:
self.vlm.eval()
for params in self.vlm.parameters():
params.requires_grad = False
else:
# To avoid unused params issue with distributed training
last_layers = [self.num_vlm_layers - 1]
if (
self.num_vlm_layers != self.num_expert_layers
and self.num_vlm_layers % self.num_expert_layers == 0
):
last_layers.append(self.num_vlm_layers - 2)
frozen_layers = [
"lm_head",
"text_model.model.norm.weight",
]
for layer in last_layers:
frozen_layers.append(f"text_model.model.layers.{layer}.")
for name, params in self.vlm.named_parameters():
if any(k in name for k in frozen_layers):
params.requires_grad = False
# To avoid unused params issue with distributed training
for name, params in self.lm_expert.named_parameters():
if "lm_head" in name:
params.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.get_vlm_model().vision_model.eval()
if self.train_expert_only:
self.vlm.eval()
def embed_image(self, image: torch.Tensor):
patch_attention_mask = None
# Get sequence from the vision encoder
image_hidden_states = (
self.get_vlm_model()
.vision_model(
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
patch_attention_mask=patch_attention_mask,
)
.last_hidden_state
)
# Modality projection & resampling
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
return image_hidden_states
def embed_language_tokens(self, tokens: torch.Tensor):
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
def forward_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
query_states = []
key_states = []
value_states = []
for i, hidden_states in enumerate(inputs_embeds):
layer = model_layers[i][layer_idx]
if hidden_states is None or layer is None:
continue
hidden_states = layer.input_layernorm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# B,L,H,D with L sequence length, H number of heads, D head dim
# concatenate on the number of embeddings/tokens
query_states = torch.cat(query_states, dim=1)
key_states = torch.cat(key_states, dim=1)
value_states = torch.cat(value_states, dim=1)
seq_len = query_states.shape[1]
if seq_len < position_ids.shape[1]:
_position_ids = position_ids[:, :seq_len]
_attention_mask = attention_mask[:, :seq_len, :seq_len]
else:
_position_ids = position_ids
_attention_mask = attention_mask
attention_mask_ = _attention_mask
position_ids_ = _position_ids
query_states = apply_rope(query_states, position_ids_)
key_states = apply_rope(key_states, position_ids_)
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask_, batch_size, head_dim, query_states, key_states, value_states
)
return [att_output], past_key_values
def forward_cross_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
attention_interface = self.get_attention_interface()
att_outputs = []
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
)
if len(inputs_embeds) == 2 and not past_key_values:
# Prefix attention
seq_len = inputs_embeds[0].shape[1]
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
layer = model_layers[0][layer_idx]
hidden_states = layer.input_layernorm(inputs_embeds[0])
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
# B,L,H,D with L sequence length, H number of heads, D head dim
query_states = apply_rope(query_state, position_id)
key_states = apply_rope(key_state, position_id)
att_output = attention_interface(
prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_outputs.append(att_output)
else:
expert_position_id = position_ids
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = past_key_values[layer_idx]["key_states"]
value_states = past_key_values[layer_idx]["value_states"]
# Expert
expert_layer = model_layers[1][layer_idx]
if expert_layer is not None:
expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
expert_input_shape = expert_hidden_states.shape[:-1]
expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
*key_states.shape[:2], -1
)
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
) # k_proj should have same dim as kv
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
*value_states.shape[:2], -1
)
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
)
expert_position_id = (
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
) # start from 0
expert_attention_mask = attention_mask[
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
] # take into account kv
expert_query_states = apply_rope(expert_query_state, expert_position_id)
att_output = attention_interface(
expert_attention_mask,
batch_size,
head_dim,
expert_query_states,
expert_key_states,
expert_value_states,
)
att_outputs.append(att_output)
else:
att_outputs.append(None)
# att_output = att_output.to(dtype=models[i].dtype)
return att_outputs, past_key_values
def get_model_layers(self, models: list) -> list:
vlm_layers = []
expert_layers = []
multiple_of = self.num_vlm_layers // self.num_expert_layers
for i in range(self.num_vlm_layers):
if multiple_of > 0 and i > 0 and i % multiple_of != 0:
expert_layer = None
else:
expert_layer_index = i // multiple_of if multiple_of > 0 else i
expert_layer = models[1].layers[expert_layer_index]
vlm_layers.append(models[0].layers[i])
expert_layers.append(expert_layer)
return [vlm_layers, expert_layers]
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: List[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
fill_kv_cache: Optional[bool] = None,
):
models = [self.get_vlm_model().text_model, self.lm_expert]
model_layers = self.get_model_layers(models)
for hidden_states in inputs_embeds:
# TODO this is very inefficient
# dtype is always the same, batch size too (if > 1 len)
# device could be trickier in multi gpu edge cases but that's it
if hidden_states is None:
continue
batch_size = hidden_states.shape[0]
# RMSNorm
num_layers = self.num_vlm_layers
head_dim = self.vlm.config.text_config.head_dim
for layer_idx in range(num_layers):
if (
fill_kv_cache
or "cross" not in self.attention_mode
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
):
att_outputs, past_key_values = self.forward_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
else:
att_outputs, past_key_values = self.forward_cross_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
outputs_embeds = []
start = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = model_layers[i][layer_idx]
att_output = (
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
) # in case of self_attn
if hidden_states is not None:
if layer is None:
outputs_embeds.append(hidden_states)
continue
end = start + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
att_out = att_output[:, start:end]
out_emb = layer.self_attn.o_proj(att_out)
out_emb += hidden_states
after_first_residual = out_emb.clone()
out_emb = layer.post_attention_layernorm(out_emb)
out_emb = layer.mlp(out_emb)
out_emb += after_first_residual
outputs_embeds.append(out_emb)
start = end if len(att_outputs) == 1 else 0
else:
outputs_embeds.append(None)
inputs_embeds = outputs_embeds
# final norm
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is not None:
out_emb = models[i].norm(hidden_states)
outputs_embeds.append(out_emb)
else:
outputs_embeds.append(None)
return outputs_embeds, past_key_values
def get_attention_interface(self):
attention_interface = self.eager_attention_forward
return attention_interface
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
num_att_heads = self.num_attention_heads
num_key_value_heads = self.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
# Attention here is upcasted to float32 to match the original eager implementation.
query_states = query_states.to(dtype=torch.float32)
key_states = key_states.to(dtype=torch.float32)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
att_weights *= head_dim**-0.5
att_weights = att_weights.to(dtype=torch.float32)
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output

View File

@@ -0,0 +1,220 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
@PreTrainedConfig.register_subclass("tdmpc")
@dataclass
class TDMPCConfig(PreTrainedConfig):
"""Configuration class for TDMPCPolicy.
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
Args:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
action repeats in Q-learning or ask your favorite chatbot)
horizon: Horizon for model predictive control.
n_action_steps: Number of action steps to take from the plan given by model predictive control. This
is an alternative to using action repeats. If this is set to more than 1, then we require
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
approach of using multiple steps from the plan is not in the original implementation.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
match the original implementation.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
normalization mode here.
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
latent_dim: Observation's latent embedding dimension.
q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation.
mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy
(π), Q ensemble, and V.
discount: Discount factor (γ) to use for the reinforcement learning formalism.
use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model
(π) for each step.
cem_iterations: Number of iterations for the MPPI/CEM loop in MPC.
max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM.
min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π).
Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM.
n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must
be non-zero.
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
be zero.
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
trajectory values (this is the λ coefficient in eqn 4 of FOWM).
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
elites, when updating the gaussian parameters for CEM.
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
is applied. Note that the input images are assumed to be square for this augmentation.
reward_coeff: Loss weighting coefficient for the reward regression loss.
expectile_weight: Weighting (τ) used in expectile regression for the state value function (V).
v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to
be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do
because v_target is obtained by evaluating the learned state-action value functions (Q) with
in-sample actions that may not be always optimal.
value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
value (V) expectile regression loss.
consistency_coeff: Loss weighting coefficient for the consistency loss.
advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage
weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages
are clamped at 100.0.
pi_coeff: Loss weighting coefficient for the action regression loss.
temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
current time step.
target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated
as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the
model being trained.
"""
# Input / output structure.
n_obs_steps: int = 1
n_action_repeats: int = 2
horizon: int = 5
n_action_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ENV": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Architecture / modeling.
# Neural networks.
image_encoder_hidden_dim: int = 32
state_encoder_hidden_dim: int = 256
latent_dim: int = 50
q_ensemble_size: int = 5
mlp_dim: int = 512
# Reinforcement learning.
discount: float = 0.9
# Inference.
use_mpc: bool = True
cem_iterations: int = 6
max_std: float = 2.0
min_std: float = 0.05
n_gaussian_samples: int = 512
n_pi_samples: int = 51
uncertainty_regularizer_coeff: float = 1.0
n_elites: int = 50
elite_weighting_temperature: float = 0.5
gaussian_mean_momentum: float = 0.1
# Training and loss computation.
max_random_shift_ratio: float = 0.0476
# Loss coefficients.
reward_coeff: float = 0.5
expectile_weight: float = 0.9
value_coeff: float = 0.1
consistency_coeff: float = 20.0
advantage_scaling: float = 3.0
pi_coeff: float = 0.5
temporal_decay_coeff: float = 0.5
# Target model.
target_model_momentum: float = 0.995
# Training presets
optimizer_lr: float = 3e-4
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if self.n_gaussian_samples <= 0:
raise ValueError(
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
)
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
raise ValueError(
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
"information."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.n_action_steps > 1:
if self.n_action_repeats != 1:
raise ValueError(
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
)
if not self.use_mpc:
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(lr=self.optimizer_lr)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
# There should only be one image key.
if len(self.image_features) > 1:
raise ValueError(
f"{self.__class__.__name__} handles at most one image for now. Got image keys {self.image_features}."
)
if len(self.image_features) > 0:
image_ft = next(iter(self.image_features.values()))
if image_ft.shape[-2] != image_ft.shape[-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
@property
def observation_delta_indices(self) -> list:
return list(range(self.horizon + 1))
@property
def action_delta_indices(self) -> list:
return list(range(self.horizon))
@property
def reward_delta_indices(self) -> None:
return list(range(self.horizon))

View File

@@ -0,0 +1,834 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of Finetuning Offline World Models in the Real World.
The comments in this code may sometimes refer to these references:
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://huggingface.co/papers/2203.04955)
FOWM paper: Finetuning Offline World Models in the Real World (https://huggingface.co/papers/2310.16029)
"""
# ruff: noqa: N806
from collections import deque
from copy import deepcopy
from functools import partial
from typing import Callable
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
class TDMPCPolicy(PreTrainedPolicy):
"""Implementation of TD-MPC learning + inference.
Please note several warnings for this policy.
- Evaluation of pretrained weights created with the original FOWM code
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
process communication to use the xarm environment from FOWM. This is because our xarm
environment uses newer dependencies and does not match the environment in FOWM. See
https://github.com/huggingface/lerobot/pull/103 for implementation details.
- We have NOT checked that training on LeRobot reproduces the results from FOWM.
- Nevertheless, we have verified that we can train TD-MPC for PushT. See
`lerobot/configs/policy/tdmpc_pusht_keypoints.yaml`.
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
match our xarm environment.
"""
config_class = TDMPCConfig
name = "tdmpc"
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
for param in self.model_target.parameters():
param.requires_grad = False
self.reset()
def get_optim_params(self) -> dict:
return self.parameters()
def reset(self):
"""
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
called on `env.reset()`
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self.config.image_features:
self._queues["observation.image"] = deque(maxlen=1)
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=1)
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step.
self._prev_mean: torch.Tensor | None = None
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
# Remove the time dimensions as it is not handled yet.
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
# NOTE: Order of observations matters here.
encode_keys = []
if self.config.image_features:
encode_keys.append(OBS_IMAGE)
if self.config.env_state_feature:
encode_keys.append(OBS_ENV_STATE)
encode_keys.append(OBS_STATE)
z = self.model.encode({k: batch[k] for k in encode_keys})
if self.config.use_mpc: # noqa: SIM108
actions = self.plan(z) # (horizon, batch, action_dim)
else:
# Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
# sequence dimension like in the MPC branch.
actions = self.model.pi(z).unsqueeze(0)
actions = torch.clamp(actions, -1, +1)
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
self._queues = populate_queues(self._queues, batch)
# When the action queue is depleted, populate it again by querying the policy.
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
if self.config.n_action_repeats > 1:
for _ in range(self.config.n_action_repeats):
self._queues[ACTION].append(actions[0])
else:
# Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
self._queues[ACTION].extend(actions[: self.config.n_action_steps])
action = self._queues[ACTION].popleft()
return action
@torch.no_grad()
def plan(self, z: Tensor) -> Tensor:
"""Plan sequence of actions using TD-MPC inference.
Args:
z: (batch, latent_dim,) tensor for the initial state.
Returns:
(horizon, batch, action_dim,) tensor for the planned trajectory of actions.
"""
device = get_device_from_parameters(self)
batch_size = z.shape[0]
# Sample Nπ trajectories from the policy.
pi_actions = torch.empty(
self.config.horizon,
self.config.n_pi_samples,
batch_size,
self.config.action_feature.shape[0],
device=device,
)
if self.config.n_pi_samples > 0:
_z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples)
for t in range(self.config.horizon):
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
# helpful for CEM.
pi_actions[t] = self.model.pi(_z, self.config.min_std)
_z = self.model.latent_dynamics(_z, pi_actions[t])
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
mean[:-1] = self._prev_mean[1:]
std = self.config.max_std * torch.ones_like(mean)
for _ in range(self.config.cem_iterations):
# Randomly sample action trajectories for the gaussian distribution.
std_normal_noise = torch.randn(
self.config.horizon,
self.config.n_gaussian_samples,
batch_size,
self.config.action_feature.shape[0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0)
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
# (horizon, n_elites, batch, action_dim)
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
_std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
dim=1,
)
)
# Update mean with an exponential moving average, and std with a direct replacement.
mean = (
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
)
std = _std.clamp_(self.config.min_std, self.config.max_std)
# Keep track of the mean for warm-starting subsequent steps.
self._prev_mean = mean
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
return actions
@torch.no_grad()
def estimate_value(self, z: Tensor, actions: Tensor):
"""Estimates the value of a trajectory as per eqn 4 of the FOWM paper.
Args:
z: (batch, latent_dim) tensor of initial latent states.
actions: (horizon, batch, action_dim) tensor of action trajectories.
Returns:
(batch,) tensor of values.
"""
# Initialize return and running discount factor.
G, running_discount = 0, 1
# Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
# model. Keep track of return.
for t in range(actions.shape[0]):
# We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4
# of the FOWM paper.
if self.config.uncertainty_regularizer_coeff > 0:
regularization = -(
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
)
else:
regularization = 0
# Estimate the next state (latent) and reward.
z, reward = self.model.latent_dynamics_and_reward(z, actions[t])
# Update the return and running discount.
G += running_discount * (reward + regularization)
running_discount *= self.config.discount
# Add the estimated value of the final state (using the minimum for a conservative estimate).
# Do so by predicting the next action, then taking a minimum over the ensemble of state-action value
# estimators.
# Note: This small amount of added noise seems to help a bit at inference time as observed by success
# metrics over 50 episodes of xarm_lift_medium_replay.
next_action = self.model.pi(z, self.config.min_std) # (batch, action_dim)
terminal_values = self.model.Qs(z, next_action) # (ensemble, batch)
# Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper).
if self.config.q_ensemble_size > 2:
G += (
running_discount
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
0
]
)
else:
G += running_discount * torch.min(terminal_values, dim=0)[0]
# Finally, also regularize the terminal value.
if self.config.uncertainty_regularizer_coeff > 0:
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)
info = {}
# (b, t) -> (t, b)
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
action = batch[ACTION] # (t, b, action_dim)
reward = batch[REWARD] # (t, b)
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
# Apply random image augmentations.
if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations[OBS_IMAGE] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations[OBS_IMAGE],
)
# Get the current observation for predicting trajectories, and all future observations for use in
# the latent consistency loss and TD loss.
current_observation, next_observations = {}, {}
for k in observations:
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon, batch_size = next_observations[
OBS_IMAGE if self.config.image_features else OBS_ENV_STATE
].shape[:2]
# Run latent rollout using the latent dynamics model and policy model.
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
batch_size = batch["index"].shape[0]
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)
for t in range(horizon):
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
# Compute Q and V value predictions based on the latent rollout.
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
v_preds = self.model.V(z_preds[:-1])
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
# Compute various targets with stopgrad.
with torch.no_grad():
# Latent state consistency targets.
z_targets = self.model_target.encode(next_observations)
# State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the
# learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM
# uses a learned state value function: V(z). This means the TD targets only depend on in-sample
# actions (not actions estimated by π).
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
# and the FOWM paper.
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
# are using them to compute loss for V.
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
# future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch).
temporal_loss_coeffs = torch.pow(
self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
).unsqueeze(-1)
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
# predicted from the (target model's) observation encoder.
consistency_loss = (
(
temporal_loss_coeffs
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
# `z_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# `z_targets` depends on the next observation.
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
# rewards.
reward_loss = (
(
temporal_loss_coeffs
* F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
q_value_loss = (
(
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
# Compute state value loss as in eqn 3 of FOWM.
diff = v_targets - v_preds
# Expectile loss penalizes:
# - `v_preds < v_targets` with weighting `expectile_weight`
# - `v_preds >= v_targets` with weighting `1 - expectile_weight`
raw_v_value_loss = torch.where(
diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
) * (diff**2)
v_value_loss = (
(
temporal_loss_coeffs
* raw_v_value_loss
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
# We won't need these gradients again so detach.
z_preds = z_preds.detach()
# Use stopgrad for the advantage calculation.
with torch.no_grad():
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
z_preds[:-1]
)
info["advantage"] = advantage[0]
# (t, b)
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
# gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
# multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
# dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
# the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
# loss).
mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
# NOTE: The original implementation does not take the sum over the temporal dimension like with the
# other losses.
# TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works
# as well as expected.
pi_loss = (
exp_advantage
* mse
* temporal_loss_coeffs
# `action_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).mean()
loss = (
self.config.consistency_coeff * consistency_loss
+ self.config.reward_coeff * reward_loss
+ self.config.value_coeff * q_value_loss
+ self.config.value_coeff * v_value_loss
+ self.config.pi_coeff * pi_loss
)
info.update(
{
"consistency_loss": consistency_loss.item(),
"reward_loss": reward_loss.item(),
"Q_value_loss": q_value_loss.item(),
"V_value_loss": v_value_loss.item(),
"pi_loss": pi_loss.item(),
"sum_loss": loss.item() * self.config.horizon,
}
)
# Undo (b, t) -> (t, b).
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
return loss, info
def update(self):
"""Update the target model's parameters with an EMA step."""
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
class TDMPCTOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
def __init__(self, config: TDMPCConfig):
super().__init__()
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, 1),
)
self._pi = nn.Sequential(
nn.Linear(config.latent_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.action_feature.shape[0]),
)
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.ELU(),
nn.Linear(config.mlp_dim, 1),
)
for _ in range(config.q_ensemble_size)
]
)
self._V = nn.Sequential(
nn.Linear(config.latent_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.ELU(),
nn.Linear(config.mlp_dim, 1),
)
self._init_weights()
def _init_weights(self):
"""Initialize model weights.
Orthogonal initialization for all linear and convolutional layers' weights (apart from final layers
of reward network and Q networks which get zero initialization).
Zero initialization for all linear and convolutional layers' biases.
"""
def _apply_fn(m):
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
gain = nn.init.calculate_gain("relu")
nn.init.orthogonal_(m.weight.data, gain)
if m.bias is not None:
nn.init.zeros_(m.bias)
self.apply(_apply_fn)
for m in [self._reward, *self._Qs]:
assert isinstance(m[-1], nn.Linear), (
"Sanity check. The last linear layer needs 0 initialization on weights."
)
nn.init.zeros_(m[-1].weight)
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
return self._encoder(obs)
def latent_dynamics_and_reward(self, z: Tensor, a: Tensor) -> tuple[Tensor, Tensor]:
"""Predict the next state's latent representation and the reward given a current latent and action.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
Returns:
A tuple containing:
- (*, latent_dim) tensor for the next state's latent representation.
- (*,) tensor for the estimated reward.
"""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x).squeeze(-1)
def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor:
"""Predict the next state's latent representation given a current latent and action.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
Returns:
(*, latent_dim) tensor for the next state's latent representation.
"""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x)
def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
"""Samples an action from the learned policy.
The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
generating rollouts for online training.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
std: The standard deviation of the injected noise.
Returns:
(*, action_dim) tensor for the sampled action.
"""
action = torch.tanh(self._pi(z))
if std > 0:
std = torch.ones_like(action) * std
action += torch.randn_like(action) * std
return action
def V(self, z: Tensor) -> Tensor: # noqa: N802
"""Predict state value (V).
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
Returns:
(*,) tensor of estimated state values.
"""
return self._V(z).squeeze(-1)
def Qs(self, z: Tensor, a: Tensor, return_min: bool = False) -> Tensor: # noqa: N802
"""Predict state-action value for all of the learned Q functions.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
return_min: Set to true for implementing the detail in App. C of the FOWM paper: randomly select
2 of the Qs and return the minimum
Returns:
(q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble OR
(*,) tensor if return_min=True.
"""
x = torch.cat([z, a], dim=-1)
if not return_min:
return torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0)
else:
if len(self._Qs) > 2: # noqa: SIM108
Qs = [self._Qs[i] for i in np.random.choice(len(self._Qs), size=2)]
else:
Qs = self._Qs
return torch.stack([q(x).squeeze(-1) for q in Qs], dim=0).min(dim=0)[0]
class TDMPCObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: TDMPCConfig):
"""
Creates encoders for pixel and/or state modalities.
TODO(alexander-soare): The original work allows for multiple images by concatenating them along the
channel dimension. Re-implement this capability.
"""
super().__init__()
self.config = config
if config.image_features:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
next(iter(config.image_features.values())).shape[0],
config.image_encoder_hidden_dim,
7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
)
if config.robot_state_feature:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
if config.env_state_feature:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
# NOTE: Order of observations matters here.
if self.config.image_features:
feat.append(
flatten_forward_unflatten(
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
)
)
if self.config.env_state_feature:
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE]))
if self.config.robot_state_feature:
feat.append(self.state_enc_layers(obs_dict[OBS_STATE]))
return torch.stack(feat, dim=0).mean(0)
def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
"""Randomly shifts images horizontally and vertically.
Adapted from https://github.com/facebookresearch/drqv2
"""
b, _, h, w = x.size()
assert h == w, "non-square images not handled yet"
pad = int(round(max_random_shift_ratio * h))
x = F.pad(x, tuple([pad] * 4), "replicate")
eps = 1.0 / (h + 2 * pad)
arange = torch.linspace(
-1.0 + eps,
1.0 - eps,
h + 2 * pad,
device=x.device,
dtype=torch.float32,
)[:h]
arange = einops.repeat(arange, "w -> h w 1", h=h)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b)
# A random shift in units of pixels and within the boundaries of the padding.
shift = torch.randint(
0,
2 * pad + 1,
size=(b, 1, 1, 2),
device=x.device,
dtype=torch.float32,
)
shift *= 2.0 / (h + 2 * pad)
grid = base_grid + shift
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters directly.
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
with torch.no_grad():
p_ema.mul_(alpha)
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally
different from *.
Returns:
A return value from the callable reshaped to (**, *).
"""
if image_tensor.ndim == 4:
return fn(image_tensor)
start_dims = image_tensor.shape[:-3]
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))

View File

@@ -0,0 +1,73 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
import torch
from torch import nn
def populate_queues(
queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
):
if exclude_keys is None:
exclude_keys = []
for key in batch:
# Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
# queues have the keys they want).
if key not in queues or key in exclude_keys:
continue
if len(queues[key]) != queues[key].maxlen:
# initialize by copying the first observation several times until the queue is full
while len(queues[key]) != queues[key].maxlen:
queues[key].append(batch[key])
else:
# add latest observation to the queue
queues[key].append(batch[key])
return queues
def get_device_from_parameters(module: nn.Module) -> torch.device:
"""Get a module's device by checking one of its parameters.
Note: assumes that all parameters have the same device
"""
return next(iter(module.parameters())).device
def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
"""Get a module's parameter dtype by checking one of its parameters.
Note: assumes that all parameters have the same dtype.
"""
return next(iter(module.parameters())).dtype
def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
"""
Calculates the output shape of a PyTorch module given an input shape.
Args:
module (nn.Module): a PyTorch module
input_shape (tuple): A tuple representing the input shape, e.g., (batch_size, channels, height, width)
Returns:
tuple: The output shape of the module.
"""
dummy_input = torch.zeros(size=input_shape)
with torch.inference_mode():
output = module(dummy_input)
return tuple(output.shape)

View File

@@ -0,0 +1,200 @@
#!/usr/bin/env python
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import VQBeTSchedulerConfig
@PreTrainedConfig.register_subclass("vqbet")
@dataclass
class VQBeTConfig(PreTrainedConfig):
"""Configuration class for VQ-BeT.
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- At least one key starting with "observation.image is required as an input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views. Right now we only support all images having the same shape.
- "action" is required as an output key.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
action_chunk_size: Action chunk size of each action prediction token.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
n_vqvae_training_steps: Number of optimization steps for training Residual VQ.
vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (each layer).
vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary.
vqvae_enc_hidden_dim: Size of hidden dimensions of Encoder / Decoder part of Residaul VQ-VAE
gpt_block_size: Max block size of minGPT (should be larger than the number of input tokens)
gpt_input_dim: Size of output input of GPT. This is also used as the dimension of observation features.
gpt_output_dim: Size of output dimension of GPT. This is also used as a input dimension of offset / bin prediction headers.
gpt_n_layer: Number of layers of GPT
gpt_n_head: Number of headers of GPT
gpt_hidden_dim: Size of hidden dimensions of GPT
dropout: Dropout rate for GPT
mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT
offset_loss_weight: A constant that is multiplied to the offset loss
primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss
secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss
bet_softmax_temperature: Sampling temperature of code for rollout with VQ-BeT
sequentially_select: Whether select code of primary / secondary as sequentially (pick primary code,
and then select secodnary code), or at the same time.
"""
# Inputs / output structure.
n_obs_steps: int = 5
n_action_pred_token: int = 3
action_chunk_size: int = 5
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
# VQ-VAE
n_vqvae_training_steps: int = 20000
vqvae_n_embed: int = 16
vqvae_embedding_dim: int = 256
vqvae_enc_hidden_dim: int = 128
# VQ-BeT
gpt_block_size: int = 500
gpt_input_dim: int = 512
gpt_output_dim: int = 512
gpt_n_layer: int = 8
gpt_n_head: int = 8
gpt_hidden_dim: int = 512
dropout: float = 0.1
mlp_hidden_dim: int = 1024
offset_loss_weight: float = 10000.0
primary_code_loss_weight: float = 5.0
secondary_code_loss_weight: float = 0.5
bet_softmax_temperature: float = 0.1
sequentially_select: bool = False
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-6
optimizer_vqvae_lr: float = 1e-3
optimizer_vqvae_weight_decay: float = 1e-4
scheduler_warmup_steps: int = 500
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> VQBeTSchedulerConfig:
return VQBeTSchedulerConfig(
num_warmup_steps=self.scheduler_warmup_steps,
num_vqvae_training_steps=self.n_vqvae_training_steps,
)
def validate_features(self) -> None:
# Note: this check was previously performed inside VQBeTRgbEncoder in the form of
# assert len(image_keys) == 1
if not len(self.image_features) == 1:
raise ValueError("You must provide only one image among the inputs.")
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
)
# Check that all input images have the same shape.
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
)
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,915 @@
#!/usr/bin/env python
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections import deque
from typing import Callable, List
import einops
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from torch import Tensor, nn
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.vqbet_utils import GPT, ResidualVQ
# ruff: noqa: N806
class VQBeTPolicy(PreTrainedPolicy):
"""
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
"""
config_class = VQBeTConfig
name = "vqbet"
def __init__(
self,
config: VQBeTConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.vqbet = VQBeTModel(config)
self.reset()
def get_optim_params(self) -> dict:
vqvae_params = (
list(self.vqbet.action_head.vqvae_model.encoder.parameters())
+ list(self.vqbet.action_head.vqvae_model.decoder.parameters())
+ list(self.vqbet.action_head.vqvae_model.vq_layer.parameters())
)
decay_params, no_decay_params = self.vqbet.policy.configure_parameters()
decay_params = (
decay_params
+ list(self.vqbet.rgb_encoder.parameters())
+ list(self.vqbet.state_projector.parameters())
+ list(self.vqbet.rgb_feature_projector.parameters())
+ [self.vqbet.action_token]
+ list(self.vqbet.action_head.map_to_cbet_preds_offset.parameters())
)
if self.config.sequentially_select:
decay_params = (
decay_params
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
)
else:
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
return [
{
"params": decay_params,
},
{
"params": vqvae_params,
"weight_decay": self.config.optimizer_vqvae_weight_decay,
"lr": self.config.optimizer_vqvae_lr,
},
{
"params": no_decay_params,
"weight_decay": 0.0,
},
]
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
queues are populated during rollout of the policy, they contain the n latest observations and actions
"""
self._queues = {
OBS_IMAGES: deque(maxlen=self.config.n_obs_steps),
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
ACTION: deque(maxlen=self.config.action_chunk_size),
}
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if not self.vqbet.action_head.vqvae_model.discretized.item():
warnings.warn(
"To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ.",
stacklevel=1,
)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
# since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
# loss: total loss of training RVQ
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
loss, n_different_codes, n_different_combinations, recon_l1_error = (
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch[ACTION])
)
return loss, {
"n_different_codes": n_different_codes,
"n_different_combinations": n_different_combinations,
"recon_l1_error": recon_l1_error,
}
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
_, loss_dict = self.vqbet(batch, rollout=False)
loss = loss_dict.pop("loss")
return loss, loss_dict
class SpatialSoftmax(nn.Module):
"""
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
(https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation.
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
-----------------------------------------------------
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
| ... | ... | ... | ... |
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
-----------------------------------------------------
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
linear mapping (in_channels, H, W) -> (num_kp, H, W).
"""
def __init__(self, input_shape, num_kp=None):
"""
Args:
input_shape (list): (C, H, W) input feature map shape.
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
"""
super().__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._out_c = num_kp
else:
self.nets = None
self._out_c = self._in_c
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
def forward(self, features: Tensor) -> Tensor:
"""
Args:
features: (B, C, H, W) input feature maps.
Returns:
(B, K, 2) image-space coordinates of keypoints.
"""
if self.nets is not None:
features = self.nets(features)
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
features = features.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(features, dim=-1)
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
expected_xy = attention @ self.pos_grid
# reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
return feature_keypoints
class VQBeTModel(nn.Module):
"""VQ-BeT: The underlying neural network for VQ-BeT
Note: In this code we use the terms `rgb_encoder`, 'policy', `action_head`. The meanings are as follows.
- The `rgb_encoder` process rgb-style image observations to one-dimensional embedding vectors
- A `policy` is a minGPT architecture, that takes observation sequences and action query tokens to generate `features`.
- These `features` pass through the action head, which passes through the code prediction, offset prediction head,
and finally generates a prediction for the action chunks.
-------------------------------** legend **-------------------------------
│ n = n_obs_steps, p = n_action_pred_token, c = action_chunk_size) │
│ o_{t} : visual observation at timestep {t}
│ s_{t} : state observation at timestep {t}
│ a_{t} : action at timestep {t}
│ A_Q : action_query_token │
--------------------------------------------------------------------------
Training Phase 1. Discretize action using Residual VQ (for config.n_vqvae_training_steps steps)
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ │ │ │ │ │
│ RVQ encoder │ ─► │ Residual │ ─► │ RVQ Decoder │
│ (a_{t}~a_{t+p}) │ │ Code Quantizer │ │ │
│ │ │ │ │ │
└─────────────────┘ └─────────────────┘ └─────────────────┘
Training Phase 2.
timestep {t-n+1} timestep {t-n+2} timestep {t}
┌─────┴─────┐ ┌─────┴─────┐ ┌─────┴─────┐
o_{t-n+1} o_{t-n+2} ... o_{t}
│ │ │
│ s_{t-n+1} │ s_{t-n+2} ... │ s_{t} p
│ │ │ │ │ │ ┌───────┴───────┐
│ │ A_Q │ │ A_Q ... │ │ A_Q ... A_Q
│ │ │ │ │ │ │ │ │ │
┌───▼─────▼─────▼─────▼─────▼─────▼─────────────────▼─────▼─────▼───────────────▼───┐
│ │
│ GPT │ => policy
│ │
└───────────────▼─────────────────▼─────────────────────────────▼───────────────▼───┘
│ │ │ │
┌───┴───┐ ┌───┴───┐ ┌───┴───┐ ┌───┴───┐
code offset code offset code offset code offset
▼ │ ▼ │ ▼ │ ▼ │ => action_head
RVQ Decoder │ RVQ Decoder │ RVQ Decoder │ RVQ Decoder │
└── + ──┘ └── + ──┘ └── + ──┘ └── + ──┘
▼ ▼ ▼ ▼
action chunk action chunk action chunk action chunk
a_{t-n+1} ~ a_{t-n+2} ~ a_{t} ~ ... a_{t+p-1} ~
a_{t-n+c} a_{t-n+c+1} a_{t+c-1} a_{t+p+c-1}
ONLY this chunk is used in rollout!
"""
def __init__(self, config: VQBeTConfig):
super().__init__()
self.config = config
self.rgb_encoder = VQBeTRgbEncoder(config)
self.num_images = len(self.config.image_features)
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
self.state_projector = MLP(
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
)
self.rgb_feature_projector = MLP(
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
)
# GPT part of VQ-BeT
self.policy = GPT(config)
# bin prediction head / offset prediction head part of VQ-BeT
self.action_head = VQBeTHead(config)
# Action tokens for: each observation step, the current action token, and all future action tokens.
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
self.register_buffer(
"select_target_actions_indices",
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.images"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch and sequence dims.
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
)
# Arrange prior and current observation step tokens as shown in the class docstring.
# First project features to token dimension.
rgb_tokens = self.rgb_feature_projector(
img_features
) # (batch, obs_step, number of different cameras, projection dims)
input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
input_tokens.append(
self.state_projector(batch["observation.state"])
) # (batch, obs_step, projection dims)
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
# Interleave tokens by stacking and rearranging.
input_tokens = torch.stack(input_tokens, dim=2)
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
len_additional_action_token = self.config.n_action_pred_token - 1
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
# add additional action query tokens for predicting future action chunks
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
# get action features (pass through GPT)
features = self.policy(input_tokens)
# len(self.config.input_features) is the number of different observation modes.
# this line gets the index of action prompt tokens.
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
self.config.input_features
)
# only extract the output tokens at the position of action query:
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
# mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://huggingface.co/papers/2206.11251).
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
if len_additional_action_token > 0:
features = torch.cat(
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
)
else:
features = features[:, historical_act_pred_index]
# pass through action head
action_head_output = self.action_head(features)
# if rollout, VQ-BeT don't calculate loss
if rollout:
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
batch_size, self.config.action_chunk_size, -1
)
# else, it calculate overall loss (bin prediction loss, and offset loss)
else:
output = batch[ACTION][:, self.select_target_actions_indices]
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
return action_head_output, loss
class VQBeTHead(nn.Module):
def __init__(self, config: VQBeTConfig):
"""
VQBeTHead takes output of GPT layers, and pass the feature through bin prediction head (`self.map_to_cbet_preds_bin`), and offset prediction head (`self.map_to_cbet_preds_offset`)
self.map_to_cbet_preds_bin: outputs probability of each code (for each layer).
The input dimension of `self.map_to_cbet_preds_bin` is same with the output of GPT,
and the output dimension of `self.map_to_cbet_preds_bin` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed`.
if the agent select the code sequentially, we use self.map_to_cbet_preds_primary_bin and self.map_to_cbet_preds_secondary_bin instead of self._map_to_cbet_preds_bin.
self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT,
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.action_feature.shape[0]`.
"""
super().__init__()
self.config = config
# init vqvae
self.vqvae_model = VqVae(config)
if config.sequentially_select:
self.map_to_cbet_preds_primary_bin = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[self.config.vqvae_n_embed],
)
self.map_to_cbet_preds_secondary_bin = MLP(
in_channels=config.gpt_output_dim + self.config.vqvae_n_embed,
hidden_channels=[self.config.vqvae_n_embed],
)
else:
self.map_to_cbet_preds_bin = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
)
self.map_to_cbet_preds_offset = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[
self.vqvae_model.vqvae_num_layers
* self.config.vqvae_n_embed
* config.action_chunk_size
* config.action_feature.shape[0],
],
)
# loss
self._focal_loss_fn = FocalLoss(gamma=2.0)
def discretize(self, n_vqvae_training_steps, actions):
# Resize the action sequence data to fit the action chunk size using a sliding window approach.
actions = torch.cat(
[
actions[:, j : j + self.config.action_chunk_size, :]
for j in range(actions.shape[1] + 1 - self.config.action_chunk_size)
],
dim=0,
)
# `actions` is a tensor of shape (new_batch, action_chunk_size, action_dim) where new_batch is the number of possible chunks created from the original sequences using the sliding window.
loss, metric = self.vqvae_model.vqvae_forward(actions)
n_different_codes = sum(
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
)
n_different_combinations = len(torch.unique(metric[2], dim=0))
recon_l1_error = metric[0].detach().cpu().item()
self.vqvae_model.optimized_steps += 1
# if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
self.vqvae_model.discretized = torch.tensor(True)
self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
print("Finished discretizing action data!")
self.vqvae_model.eval()
for param in self.vqvae_model.vq_layer.parameters():
param.requires_grad = False
return loss, n_different_codes, n_different_combinations, recon_l1_error
def forward(self, x, **kwargs) -> dict:
# N is the batch size, and T is number of action query tokens, which are process through same GPT
N, T, _ = x.shape
# we calculate N and T side parallelly. Thus, the dimensions would be
# (batch size * number of action query tokens, action chunk size, action dimension)
x = einops.rearrange(x, "N T WA -> (N T) WA")
# sample offsets
cbet_offsets = self.map_to_cbet_preds_offset(x)
cbet_offsets = einops.rearrange(
cbet_offsets,
"(NT) (G C WA) -> (NT) G C WA",
G=self.vqvae_model.vqvae_num_layers,
C=self.config.vqvae_n_embed,
)
# if self.config.sequentially_select is True, bin prediction head first sample the primary code, and then sample secondary code
if self.config.sequentially_select:
cbet_primary_logits = self.map_to_cbet_preds_primary_bin(x)
# select primary bin first
cbet_primary_probs = torch.softmax(
cbet_primary_logits / self.config.bet_softmax_temperature, dim=-1
)
NT, choices = cbet_primary_probs.shape
sampled_primary_centers = einops.rearrange(
torch.multinomial(cbet_primary_probs.view(-1, choices), num_samples=1),
"(NT) 1 -> NT",
NT=NT,
)
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
torch.cat(
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
axis=1,
)
)
cbet_secondary_probs = torch.softmax(
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
)
sampled_secondary_centers = einops.rearrange(
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
"(NT) 1 -> NT",
NT=NT,
)
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
else:
cbet_logits = self.map_to_cbet_preds_bin(x)
cbet_logits = einops.rearrange(
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
)
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape
sampled_centers = einops.rearrange(
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
"(NT G) 1 -> NT G",
NT=NT,
)
device = get_device_from_parameters(self)
indices = (
torch.arange(NT, device=device).unsqueeze(1),
torch.arange(self.vqvae_model.vqvae_num_layers, device=device).unsqueeze(0),
sampled_centers,
)
# Use advanced indexing to sample the values (Extract the only offsets corresponding to the sampled codes.)
sampled_offsets = cbet_offsets[indices]
# Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
sampled_offsets = sampled_offsets.sum(dim=1)
with torch.no_grad():
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
# pass the centroids through decoder to get actions.
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
# reshaped extracted offset to match with decoded centroids
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
)
# add offset and decoded centroids
predicted_action = decoded_action + sampled_offsets
predicted_action = einops.rearrange(
predicted_action,
"(N T) W A -> N T (W A)",
N=N,
T=T,
W=self.config.action_chunk_size,
)
return {
"cbet_logits": cbet_logits,
"predicted_action": predicted_action,
"sampled_centers": sampled_centers,
"decoded_action": decoded_action,
}
def loss_fn(self, pred, target, **kwargs):
"""
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
predicted_action: predicted action chunk (offset + decoded centroids)
sampled_centers: sampled centroids (code of RVQ)
decoded_action: decoded action, which is produced by passing sampled_centers through RVQ decoder
NT: batch size * T
T: number of action query tokens, which are process through same GPT
cbet_logits: probability of all codes in each layer
"""
action_seq = target
predicted_action = pred["predicted_action"]
sampled_centers = pred["sampled_centers"]
decoded_action = pred["decoded_action"]
NT = predicted_action.shape[0] * predicted_action.shape[1]
cbet_logits = pred["cbet_logits"]
predicted_action = einops.rearrange(
predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size
)
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each ground truth action.
with torch.no_grad():
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
# Now we can compute the loss.
# offset loss is L1 distance between the predicted action and ground truth action
offset_loss = F.l1_loss(action_seq, predicted_action)
# calculate primary code prediction loss
cbet_loss1 = self._focal_loss_fn(
cbet_logits[:, 0, :],
action_bins[:, 0],
)
# calculate secondary code prediction loss
cbet_loss2 = self._focal_loss_fn(
cbet_logits[:, 1, :],
action_bins[:, 1],
)
# add all the prediction loss
cbet_loss = (
cbet_loss1 * self.config.primary_code_loss_weight
+ cbet_loss2 * self.config.secondary_code_loss_weight
)
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
offset_action_error = torch.mean(torch.abs(action_seq - predicted_action))
action_error_max = torch.max(torch.abs(action_seq - predicted_action))
loss = cbet_loss + self.config.offset_loss_weight * offset_loss
loss_dict = {
"loss": loss,
"classification_loss": cbet_loss.detach().cpu().item(),
"offset_loss": offset_loss.detach().cpu().item(),
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
"vq_action_error": vq_action_error.detach().cpu().item(),
"offset_action_error": offset_action_error.detach().cpu().item(),
"action_error_max": action_error_max.detach().cpu().item(),
"action_mse_error": action_mse_error.detach().cpu().item(),
}
return loss_dict
class VQBeTRgbEncoder(nn.Module):
"""Encode an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first.
Same with DiffusionRgbEncoder from modeling_diffusion.py
"""
def __init__(self, config: VQBeTConfig):
super().__init__()
# Set up optional preprocessing.
if config.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
# Set up backbone.
backbone_model = getattr(torchvision.models, config.vision_backbone)(
weights=config.pretrained_backbone_weights
)
# Note: This assumes that the layer4 feature map is children()[-3]
# TODO(alexander-soare): Use a safer alternative.
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if config.use_group_norm:
if config.pretrained_backbone_weights:
raise ValueError(
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
)
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.image_features`.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B, C, H, W) image tensor with pixel values in [0, 1].
Returns:
(B, D) image feature.
"""
# Preprocess: maybe crop (if it was set up in the __init__).
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
else:
# Always use center crop for eval.
x = self.center_crop(x)
# Extract backbone feature.
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
# Final linear layer with non-linearity.
x = self.relu(self.out(x))
return x
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
root_module: The module for which the submodules need to be replaced
predicate: Takes a module as an argument and must return True if the that module is to be replaced.
func: Takes a module as an argument and returns a new module to replace it with.
Returns:
The root module with its submodules replaced.
"""
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
parent_module = root_module.get_submodule(".".join(parents))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
class VqVae(nn.Module):
def __init__(
self,
config: VQBeTConfig,
):
"""
VQ-VAE is composed of three parts: encoder, vq_layer, and decoder.
Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
The vq_layer uses residual VQs.
This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
as well as functions to help BeT training part in training phase 2.
"""
super().__init__()
self.config = config
# 'discretized' indicates whether the Residual VQ part is trained or not. (After finishing the training, we set discretized=True)
self.register_buffer("discretized", torch.tensor(False))
self.optimized_steps = 0
# we use the fixed number of layers for Residual VQ across all environments.
self.vqvae_num_layers = 2
self.vq_layer = ResidualVQ(
dim=config.vqvae_embedding_dim,
num_quantizers=self.vqvae_num_layers,
codebook_size=config.vqvae_n_embed,
)
self.encoder = MLP(
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
config.vqvae_embedding_dim,
],
)
self.decoder = MLP(
in_channels=config.vqvae_embedding_dim,
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
self.config.action_feature.shape[0] * self.config.action_chunk_size,
],
)
def get_embeddings_from_code(self, encoding_indices):
# This function gets code indices as inputs, and outputs embedding vectors corresponding to the code indices.
with torch.no_grad():
z_embed = self.vq_layer.get_codebook_vector_from_indices(encoding_indices)
# since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination.
z_embed = z_embed.sum(dim=0)
return z_embed
def get_action_from_latent(self, latent):
# given latent vector, this function outputs the decoded action.
output = self.decoder(latent)
if self.config.action_chunk_size == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
def get_code(self, state):
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://huggingface.co/papers/2403.03181)
# this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://huggingface.co/papers/2403.03181)
state = einops.rearrange(state, "N T A -> N (T A)")
with torch.no_grad():
state_rep = self.encoder(state)
state_rep_shape = state_rep.shape[:-1]
state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
state_vq = state_rep_flat.view(*state_rep_shape, -1)
vq_code = vq_code.view(*state_rep_shape, -1)
vq_loss_state = torch.sum(vq_loss_state)
return state_vq, vq_code
def vqvae_forward(self, state):
# This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://huggingface.co/papers/2403.03181).
state = einops.rearrange(state, "N T A -> N (T A)")
# We start with passing action (or action chunk) at:t+n through the encoder ϕ.
state_rep = self.encoder(state)
state_rep_shape = state_rep.shape[:-1]
state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
# The resulting latent embedding vector x = ϕ(at:t+n) is then mapped to an embedding vector in the codebook of the RVQ layers by the nearest neighbor look-up.
state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
state_vq = state_rep_flat.view(*state_rep_shape, -1)
vq_code = vq_code.view(*state_rep_shape, -1)
# since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination.
vq_loss_state = torch.sum(vq_loss_state)
# Then, the discretized vector zq(x) is reconstructed as ψ(zq(x)) by passing through the decoder ψ.
dec_out = self.decoder(state_vq)
# Calculate L1 reconstruction loss
encoder_loss = (state - dec_out).abs().mean()
# add encoder reconstruction loss and commitment loss
rep_loss = encoder_loss + vq_loss_state * 5
metric = (
encoder_loss.clone().detach(),
vq_loss_state.clone().detach(),
vq_code,
rep_loss.item(),
)
return rep_loss, metric
class FocalLoss(nn.Module):
"""
From https://github.com/notmahi/miniBET/blob/main/behavior_transformer/bet.py
"""
def __init__(self, gamma: float = 0, size_average: bool = True):
super().__init__()
self.gamma = gamma
self.size_average = size_average
def forward(self, input, target):
if len(input.shape) == 3:
N, T, _ = input.shape
logpt = F.log_softmax(input, dim=-1)
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
elif len(input.shape) == 2:
logpt = F.log_softmax(input, dim=-1)
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
pt = logpt.exp()
loss = -1 * (1 - pt) ** self.gamma * logpt
if self.size_average:
return loss.mean()
else:
return loss.sum()
class MLP(torch.nn.Sequential):
def __init__(
self,
in_channels: int,
hidden_channels: List[int],
):
layers = []
in_dim = in_channels
for hidden_dim in hidden_channels[:-1]:
layers.append(torch.nn.Linear(in_dim, hidden_dim))
layers.append(torch.nn.ReLU())
in_dim = hidden_dim
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1]))
super().__init__(*layers)

File diff suppressed because it is too large Load Diff

347
src/lerobot/record.py Normal file
View File

@@ -0,0 +1,347 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Records a dataset. Actions for the robot can be either generated by teleoperation or by a policy.
Example:
```shell
python -m lerobot.record \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \
--robot.id=black \
--dataset.repo_id=aliberts/record-test \
--dataset.num_episodes=2 \
--dataset.single_task="Grab the cube" \
# <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
# --teleop.type=so100_leader \
# --teleop.port=/dev/tty.usbmodem58760431551 \
# --teleop.id=blue \
# <- Policy optional if you want to record with a policy \
# --policy.path=${HF_USER}/my_policy \
```
"""
import logging
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat
import numpy as np
import rerun as rr
from lerobot.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
)
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
make_robot_from_config,
so100_follower,
so101_follower,
)
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
koch_leader,
make_teleoperator_from_config,
so100_leader,
so101_leader,
)
from lerobot.utils.control_utils import (
init_keyboard_listener,
is_headless,
predict_action,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
get_safe_torch_device,
init_logging,
log_say,
)
from lerobot.utils.visualization_utils import _init_rerun
@dataclass
class DatasetRecordConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# Limit the frames per second.
fps: int = 30
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
def __post_init__(self):
if self.single_task is None:
raise ValueError("You need to provide a task as argument in `single_task`.")
@dataclass
class RecordConfig:
robot: RobotConfig
dataset: DatasetRecordConfig
# Whether to control the robot with a teleoperator
teleop: TeleoperatorConfig | None = None
# Whether to control the robot with a policy
policy: PreTrainedConfig | None = None
# Display all cameras on screen
display_data: bool = False
# Use vocal synthesis to read events.
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.teleop is None and self.policy is None:
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
@safe_stop_image_writer
def record_loop(
robot: Robot,
events: dict,
fps: int,
dataset: LeRobotDataset | None = None,
teleop: Teleoperator | None = None,
policy: PreTrainedPolicy | None = None,
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
):
if dataset is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
# if policy is given it needs cleaning up
if policy is not None:
policy.reset()
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
observation = robot.get_observation()
if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
if policy is not None:
action_values = predict_action(
observation_frame,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
elif policy is None and teleop is not None:
action = teleop.get_action()
else:
logging.info(
"No policy or teleoperator provided, skipping action generation."
"This is likely to happen when resetting the environment without a teleop device."
"The robot won't be at its rest position at the start of the next episode."
)
continue
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
sent_action = robot.send_action(action)
if dataset is not None:
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame}
dataset.add_frame(frame, task=single_task)
if display_data:
for obs, val in observation.items():
if isinstance(val, float):
rr.log(f"observation.{obs}", rr.Scalar(val))
elif isinstance(val, np.ndarray):
rr.log(f"observation.{obs}", rr.Image(val), static=True)
for act, val in action.items():
if isinstance(val, float):
rr.log(f"action.{act}", rr.Scalar(val))
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
timestamp = time.perf_counter() - start_episode_t
@parser.wrap()
def record(cfg: RecordConfig) -> LeRobotDataset:
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:
_init_rerun(session_name="recording")
robot = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
dataset_features = {**action_features, **obs_features}
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
)
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
else:
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy)
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
robot.connect()
if teleop is not None:
teleop.connect()
listener, events = init_keyboard_listener()
recorded_episodes = 0
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
fps=cfg.dataset.fps,
teleop=teleop,
policy=policy,
dataset=dataset,
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
# Execute a few seconds without recording to give time to manually reset the environment
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
fps=cfg.dataset.fps,
teleop=teleop,
control_time_s=cfg.dataset.reset_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
if events["rerecord_episode"]:
log_say("Re-record episode", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded_episodes += 1
log_say("Stop recording", cfg.play_sounds, blocking=True)
robot.disconnect()
if teleop is not None:
teleop.disconnect()
if not is_headless() and listener is not None:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
log_say("Exiting", cfg.play_sounds)
return dataset
if __name__ == "__main__":
record()

102
src/lerobot/replay.py Normal file
View File

@@ -0,0 +1,102 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Replays the actions of an episode from a dataset on a robot.
Example:
```shell
python -m lerobot.replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--dataset.repo_id=aliberts/record-test \
--dataset.episode=2
```
"""
import logging
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat
import draccus
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
make_robot_from_config,
so100_follower,
so101_follower,
)
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
log_say,
)
@dataclass
class DatasetReplayConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str
# Episode to replay.
episode: int
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int = 30
@dataclass
class ReplayConfig:
robot: RobotConfig
dataset: DatasetReplayConfig
# Use vocal synthesis to read events.
play_sounds: bool = True
@draccus.wrap()
def replay(cfg: ReplayConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
actions = dataset.hf_dataset.select_columns("action")
robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action_array = actions[idx]["action"]
action = {}
for i, name in enumerate(dataset.features["action"]["names"]):
action[name] = action_array[i]
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / dataset.fps - dt_s)
robot.disconnect()
if __name__ == "__main__":
replay()

View File

@@ -0,0 +1,3 @@
from .config import RobotConfig
from .robot import Robot
from .utils import make_robot_from_config

View File

@@ -0,0 +1,40 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import dataclass
from pathlib import Path
import draccus
@dataclass(kw_only=True)
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
# Allows to distinguish between different robots of the same type
id: str | None = None
# Directory to store calibration file
calibration_dir: Path | None = None
def __post_init__(self):
if hasattr(self, "cameras") and self.cameras:
for _, config in self.cameras.items():
for attr in ["width", "height", "fps"]:
if getattr(config, attr) is None:
raise ValueError(
f"Specifying '{attr}' is required for the camera to be used in a robot"
)
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)

View File

@@ -0,0 +1,2 @@
from .config_koch_follower import KochFollowerConfig
from .koch_follower import KochFollower

View File

@@ -0,0 +1,39 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("koch_follower")
@dataclass
class KochFollowerConfig(RobotConfig):
# Port to connect to the arm
port: str
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False

View File

@@ -0,0 +1,258 @@
# Koch v1.1
In the steps below, we explain how to assemble the Koch v1.1 robot.
## Order and assemble the parts
Follow the sourcing and assembling instructions provided in this [README](https://github.com/jess-moss/koch-v1-1). This will guide you through setting up both the follower and leader arms, as shown in the image below.
For a visual walkthrough of the assembly process, you can refer to [this video tutorial](https://youtu.be/8nQIg9BwwTk).
> [!WARNING]
> Since the production of this video, we simplified the configuration phase. Because of this, two things differ from the instructions in that video:
> - Don't plug in all the motor cables right away and wait to be instructed to do so in [Configure the motors](#configure-the-motors).
> - Don't screw in the controller board (PCB) to the base right away and wait for being instructed to do so in [Configure the motors](#configure-the-motors).
## Install LeRobot 🤗
To install LeRobot follow, our [Installation Guide](./installation)
In addition to these instructions, you need to install the Dynamixel SDK:
```bash
pip install -e ".[dynamixel]"
```
## Configure the motors
### 1. Find the USB ports associated with each arm
To find the port for each bus servo adapter, run this script:
```bash
python -m lerobot.find_port
```
<hfoptions id="example">
<hfoption id="Mac">
Example output:
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
Remove the USB cable from your MotorsBus and press Enter when done.
[...Disconnect corresponding leader or follower arm and press Enter...]
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the USB cable.
```
Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
</hfoption>
<hfoption id="Linux">
On Linux, you might need to give access to the USB ports by running:
```bash
sudo chmod 666 /dev/ttyACM0
sudo chmod 666 /dev/ttyACM1
```
Example output:
```
Finding all available ports for the MotorBus.
['/dev/ttyACM0', '/dev/ttyACM1']
Remove the usb cable from your MotorsBus and press Enter when done.
[...Disconnect corresponding leader or follower arm and press Enter...]
The port of this MotorsBus is /dev/ttyACM1
Reconnect the USB cable.
```
Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm.
</hfoption>
</hfoptions>
### 2. Set the motors ids and baudrates
Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate.
To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once.
If you are repurposing motors from another robot, you will probably also need to perform this step, as the ids and baudrate likely won't match.
#### Follower
Connect the usb cable from your computer and the 5V power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter.
For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm.
<hfoptions id="setup_motors">
<hfoption id="Command">
```bash
python -m lerobot.setup_motors \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
config = KochFollowerConfig(
port="/dev/tty.usbmodem575E0031751",
id="my_awesome_follower_arm",
)
follower = KochFollower(config)
follower.setup_motors()
```
</hfoption>
</hfoptions>
You should see the following instruction.
```
Connect the controller board to the 'gripper' motor only and press enter.
```
As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor.
<details>
<summary>Troubleshooting</summary>
If you get an error at that point, check your cables and make sure they are plugged in properly:
<ul>
<li>Power supply</li>
<li>USB cable between your computer and the controller board</li>
<li>The 3-pin cable from the controller board to the motor</li>
</ul>
If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB).
</details>
You should then see the following message:
```
'gripper' motor id set to 6
```
Followed by the next instruction:
```
Connect the controller board to the 'wrist_roll' motor only and press enter.
```
You can disconnect the 3-pin cable from the controller board but you can leave it connected to the gripper motor on the other end as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one.
Repeat the operation for each motor as instructed.
> [!TIP]
> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board.
When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm.
#### Leader
Do the same steps for the leader arm but modify the command or script accordingly.
<hfoptions id="setup_motors">
<hfoption id="Command">
```bash
python -m lerobot.setup_motors \
--teleop.type=koch_leader \
--teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
config = KochLeaderConfig(
port="/dev/tty.usbmodem575E0031751",
id="my_awesome_leader_arm",
)
leader = KochLeader(config)
leader.setup_motors()
```
</hfoption>
</hfoptions>
## Calibrate
Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
The calibration process is very important because it allows a neural network trained on one robot to work on another.
#### Follower
Run the following command or API example to calibrate the follower arm:
<hfoptions id="calibrate_follower">
<hfoption id="Command">
```bash
python -m lerobot.calibrate \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.robots.koch_follower import KochFollowerConfig, KochFollower
config = KochFollowerConfig(
port="/dev/tty.usbmodem585A0076891",
id="my_awesome_follower_arm",
)
follower = KochFollower(config)
follower.connect(calibrate=False)
follower.calibrate()
follower.disconnect()
```
</hfoption>
</hfoptions>
We unified the calibration method for most robots. Thus, the calibration steps for this Koch arm are the same as the steps for the SO100 and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video).
#### Leader
Do the same steps to calibrate the leader arm, run the following command or API example:
<hfoptions id="calibrate_leader">
<hfoption id="Command">
```bash
python -m lerobot.calibrate \
--teleop.type=koch_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.teleoperators.koch_leader import KochLeaderConfig, KochLeader
config = KochLeaderConfig(
port="/dev/tty.usbmodem575E0031751",
id="my_awesome_leader_arm",
)
leader = KochLeader(config)
leader.connect(calibrate=False)
leader.calibrate()
leader.disconnect()
```
</hfoption>
</hfoptions>
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
> [!TIP]
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).

View File

@@ -0,0 +1,230 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
OperatingMode,
)
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_koch_follower import KochFollowerConfig
logger = logging.getLogger(__name__)
class KochFollower(Robot):
"""
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow
expansion, developed by Alexander Koch from [Tau Robotics](https://tau-robotics.com)
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
"""
config_class = KochFollowerConfig
name = "koch_follower"
def __init__(self, config: KochFollowerConfig):
super().__init__(config)
self.config = config
norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100
self.bus = DynamixelMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(1, "xl430-w250", norm_mode_body),
"shoulder_lift": Motor(2, "xl430-w250", norm_mode_body),
"elbow_flex": Motor(3, "xl330-m288", norm_mode_body),
"wrist_flex": Motor(4, "xl330-m288", norm_mode_body),
"wrist_roll": Motor(5, "xl330-m288", norm_mode_body),
"gripper": Motor(6, "xl330-m288", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
def connect(self, calibrate: bool = True) -> None:
"""
We assume that at connection time, arm is in a rest position,
and torque can be safely disabled to run calibration.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
self.calibrate()
for cam in self.cameras.values():
cam.connect()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
def calibrate(self) -> None:
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
input(f"Move {self} to the middle of its range of motion and press ENTER....")
homing_offsets = self.bus.set_half_turn_homings()
full_turn_motors = ["shoulder_pan", "wrist_roll"]
unknown_range_motors = [motor for motor in self.bus.motors if motor not in full_turn_motors]
print(
f"Move all joints except {full_turn_motors} sequentially through their entire "
"ranges of motion.\nRecording positions. Press ENTER to stop..."
)
range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors)
for motor in full_turn_motors:
range_mins[motor] = 0
range_maxes[motor] = 4095
self.calibration = {}
for motor, m in self.bus.motors.items():
self.calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=homing_offsets[motor],
range_min=range_mins[motor],
range_max=range_maxes[motor],
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
with self.bus.torque_disabled():
self.bus.configure_motors()
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point
for motor in self.bus.motors:
if motor != "gripper":
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
# Use 'position control current based' for gripper to be limited by the limit of the current. For
# the follower gripper, it means it can grasp an object without forcing too much even tho, its
# goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
# For the leader gripper, it means we can use it as a physical trigger, since we can force with
# our finger to make it move, and it will move back to its original target position when we
# release the force.
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
# Set better PID values to close the gap between recorded states and actions
# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
self.bus.write("Position_P_Gain", "elbow_flex", 1500)
self.bus.write("Position_I_Gain", "elbow_flex", 0)
self.bus.write("Position_D_Gain", "elbow_flex", 600)
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position")
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def send_action(self, action: dict[str, float]) -> dict[str, float]:
"""Command arm to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
`max_relative_target`. In this case, the action sent differs from original action.
Thus, this function always returns the action actually sent.
Args:
action (dict[str, float]): The goal positions for the motors.
Returns:
dict[str, float]: The action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
# Cap goal position when too far away from present position.
# /!\ Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.bus.sync_read("Present_Position")
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
# Send goal position to the arm
self.bus.sync_write("Goal_Position", goal_pos)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -0,0 +1,3 @@
from .config_lekiwi import LeKiwiClientConfig, LeKiwiConfig
from .lekiwi import LeKiwi
from .lekiwi_client import LeKiwiClient

View File

@@ -0,0 +1,96 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras.configs import CameraConfig, Cv2Rotation
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from ..config import RobotConfig
def lekiwi_cameras_config() -> dict[str, CameraConfig]:
return {
"front": OpenCVCameraConfig(
index_or_path="/dev/video0", fps=30, width=640, height=480, rotation=Cv2Rotation.ROTATE_180
),
"wrist": OpenCVCameraConfig(
index_or_path="/dev/video2", fps=30, width=480, height=640, rotation=Cv2Rotation.ROTATE_90
),
}
@RobotConfig.register_subclass("lekiwi")
@dataclass
class LeKiwiConfig(RobotConfig):
port: str = "/dev/ttyACM0" # port to connect to the bus
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False
@dataclass
class LeKiwiHostConfig:
# Network Configuration
port_zmq_cmd: int = 5555
port_zmq_observations: int = 5556
# Duration of the application
connection_time_s: int = 30
# Watchdog: stop the robot if no command is received for over 0.5 seconds.
watchdog_timeout_ms: int = 500
# If robot jitters decrease the frequency and monitor cpu load with `top` in cmd
max_loop_freq_hz: int = 30
@RobotConfig.register_subclass("lekiwi_client")
@dataclass
class LeKiwiClientConfig(RobotConfig):
# Network Configuration
remote_ip: str
port_zmq_cmd: int = 5555
port_zmq_observations: int = 5556
teleop_keys: dict[str, str] = field(
default_factory=lambda: {
# Movement
"forward": "w",
"backward": "s",
"left": "a",
"right": "d",
"rotate_left": "z",
"rotate_right": "x",
# Speed control
"speed_up": "r",
"speed_down": "f",
# quit teleop
"quit": "q",
}
)
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
polling_timeout_ms: int = 15
connect_timeout_s: int = 5

View File

@@ -0,0 +1,300 @@
# LeKiwi
In the steps below, we explain how to assemble the LeKiwi mobile robot.
## Source the parts
Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts.
And advise if it's your first time printing or if you don't own a 3D printer.
### Wired version
If you have the **wired** LeKiwi version, you can skip the installation of the Raspberry Pi and setting up SSH. You can also run all commands directly on your PC for both the LeKiwi scripts and the leader arm scripts for teleoperating.
## Install software on Pi
Now we have to set up the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board.
### Install OS
For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu.
### Setup SSH
After setting up your Pi, you should enable and set up [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can log in to the Pi from your laptop without requiring a screen, keyboard, and mouse on the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or, if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension.
### Install LeRobot on Pi 🤗
On your Raspberry Pi install LeRobot using our [Installation Guide](./installation)
In addition to these instructions, you need to install the Feetech sdk on your Pi:
```bash
pip install -e ".[feetech]"
```
## Install LeRobot locally
If you already have installed LeRobot on your laptop/pc you can skip this step; otherwise, please follow along as we do the same steps we did on the Pi.
Follow our [Installation Guide](./installation)
Great :hugs:! You are now done installing LeRobot, and we can begin assembling the SO100/SO101 arms and the mobile base :robot:.
Every time you now want to use LeRobot, you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
# Step-by-Step Assembly Instructions
First, we will assemble the two SO100/SO101 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base. The instructions for assembling can be found on these two pages:
- [Assemble SO101](./so101#step-by-step-assembly-instructions)
- [Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi/blob/main/Assembly.md)
### Find the USB ports associated with motor board
To find the port for each bus servo adapter, run this script:
```bash
python -m lerobot.find_port
```
<hfoptions id="example">
<hfoption id="Mac">
Example output:
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081']
Remove the USB cable from your MotorsBus and press Enter when done.
[...Disconnect corresponding leader or follower arm and press Enter...]
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the USB cable.
```
Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your board.
</hfoption>
<hfoption id="Linux">
On Linux, you might need to give access to the USB ports by running:
```bash
sudo chmod 666 /dev/ttyACM0
sudo chmod 666 /dev/ttyACM1
```
Example output:
```
Finding all available ports for the MotorBus.
['/dev/ttyACM0']
Remove the usb cable from your MotorsBus and press Enter when done.
[...Disconnect corresponding leader or follower arm and press Enter...]
The port of this MotorsBus is /dev/ttyACM0
Reconnect the USB cable.
```
Where the found port is: `/dev/ttyACM0` corresponding to your board.
</hfoption>
</hfoptions>
### Configure motors
The instructions for configuring the motors can be found in the SO101 [docs](./so101#configure-the-motors). Besides the ids for the arm motors, we also need to set the motor ids for the mobile base. These need to be in a specific order to work. Below an image of the motor ids and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ids for the wheels are 7, 8 and 9.
You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7)
```bash
python -m lerobot.setup_motors \
--robot.type=lekiwi \
--robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/motor_ids.webp" alt="Motor ID's for mobile robot" title="Motor ID's for mobile robot" width="60%">
### Troubleshoot communication
If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue.
#### 1. Verify IP Address Configuration
Make sure that the correct IP for the Pi is used in the commands or in your code. To check the Raspberry Pi's IP address, run (on the Pi command line):
```bash
hostname -I
```
#### 2. Check if Pi is reachable from laptop/pc
Try pinging the Raspberry Pi from your laptop:
```bach
ping <your_pi_ip_address>
```
If the ping fails:
- Ensure the Pi is powered on and connected to the same network.
- Check if SSH is enabled on the Pi.
#### 3. Try SSH connection
If you can't SSH into the Pi, it might not be properly connected. Use:
```bash
ssh <your_pi_user_name>@<your_pi_ip_address>
```
If you get a connection error:
- Ensure SSH is enabled on the Pi by running:
```bash
sudo raspi-config
```
Then navigate to: **Interfacing Options -> SSH** and enable it.
### Calibration
Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated.
The calibration process is very important because it allows a neural network trained on one robot to work on another.
### Calibrate follower arm (on mobile base)
Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm:
```bash
python -m lerobot.calibrate \
--robot.type=lekiwi \
--robot.id=my_awesome_kiwi # <- Give the robot a unique name
```
We unified the calibration method for most robots, thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video).
### Wired version
If you have the **wired** LeKiwi version, please run all commands on your laptop.
### Calibrate leader arm
Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the following command of API example on your laptop:
<hfoptions id="calibrate_leader">
<hfoption id="Command">
```bash
python -m lerobot.calibrate \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader
config = SO100LeaderConfig(
port="/dev/tty.usbmodem58760431551",
id="my_awesome_leader_arm",
)
leader = SO100Leader(config)
leader.connect(calibrate=False)
leader.calibrate()
leader.disconnect()
```
</hfoption>
</hfoptions>
## Teleoperate LeKiwi
> [!TIP]
> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
To teleoperate, SSH into your Raspberry Pi, and run `conda activate lerobot` and this command:
```bash
python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi
```
Then on your laptop, also run `conda activate lerobot` and run the API example, make sure you set the correct `remote_ip` and `port` in `examples/lekiwi/teleoperate.py`.
```bash
python examples/lekiwi/teleoperate.py
```
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
| ---------- | ------------------ | ---------------------- |
| Fast | 0.4 | 90 |
| Medium | 0.25 | 60 |
| Slow | 0.1 | 30 |
| Key | Action |
| --- | -------------- |
| W | Move forward |
| A | Move left |
| S | Move backward |
| D | Move right |
| Z | Turn left |
| X | Turn right |
| R | Increase speed |
| F | Decrease speed |
> [!TIP]
> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiConfig`](../src/lerobot/robot_devices/robots/configs.py).
### Wired version
If you have the **wired** LeKiwi version, please run all commands on your laptop.
## Record a dataset
Once you're familiar with teleoperation, you can record your first dataset.
We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens).
Add your token to the CLI by running this command:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Now you can record a dataset. To record episodes and upload your dataset to the hub, execute this API example tailored for LeKiwi. Make sure to first adapt the `remote_ip`, `repo_id`, `port` and `task` in the script. If you would like to run the script for longer you can increase `NB_CYCLES_CLIENT_CONNECTION`.
```bash
python examples/lekiwi/record.py
```
#### Dataset upload
Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running:
```bash
echo https://huggingface.co/datasets/${HF_USER}/so101_test
```
Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot).
#### Tips for gathering data
Once you're comfortable with data recording, you can create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. Also make sure the object you are manipulating is visible on the camera's. A good rule of thumb is you should be able to do the task yourself by only looking at the camera images.
In the following sections, youll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions.
Avoid adding too much variation too quickly, as it may hinder your results.
If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset.
#### Troubleshooting:
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
## Replay an episode
To replay an episode run the API example below, make sure to change `remote_ip`, `port`, LeRobotDatasetId and episode index.
```bash
python examples/lekiwi/replay.py
```
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
## Evaluate your policy
To evaluate your policy run the `evaluate.py` API example, make sure to change `remote_ip`, `port`, model..
```bash
python examples/lekiwi/evaluate.py
```
> [!TIP]
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).

Some files were not shown because too many files have changed in this diff Show More