forked from tangger/lerobot
Port HIL SERL (#644)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Eugene Mironov <helper2424@gmail.com> Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com> Co-authored-by: Ke Wang <superwk1017@gmail.com> Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com> Co-authored-by: imstevenpmwork <steven.palma@huggingface.co> Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
This commit is contained in:
@@ -22,6 +22,7 @@ OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
REWARD = "next.reward"
|
||||
|
||||
ROBOTS = "robots"
|
||||
TELEOPERATORS = "teleoperators"
|
||||
|
||||
@@ -14,10 +14,13 @@
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.common.robots import RobotConfig
|
||||
from lerobot.common.teleoperators.config import TeleoperatorConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@@ -155,3 +158,116 @@ class XarmEnv(EnvConfig):
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecordConfig:
|
||||
"""Configuration for video recording in ManiSkill environments."""
|
||||
|
||||
enabled: bool = False
|
||||
record_dir: str = "videos"
|
||||
trajectory_name: str = "trajectory"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvTransformConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
# ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
||||
control_mode: str = "gamepad"
|
||||
display_cameras: bool = False
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None
|
||||
resize_size: Optional[tuple[int, int]] = None
|
||||
control_time_s: float = 20.0
|
||||
fixed_reset_joint_positions: Optional[Any] = None
|
||||
reset_time_s: float = 5.0
|
||||
use_gripper: bool = True
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@dataclass
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
"""Configuration for the HILSerlRobotEnv environment."""
|
||||
|
||||
robot: Optional[RobotConfig] = None
|
||||
teleop: Optional[TeleoperatorConfig] = None
|
||||
wrapper: Optional[EnvTransformConfig] = None
|
||||
fps: int = 10
|
||||
name: str = "real_robot"
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
task: str = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("hil")
|
||||
@dataclass
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
type: str = "hil"
|
||||
name: str = "PandaPickCube"
|
||||
task: str = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
state_dim: int = 18
|
||||
action_dim: int = 4
|
||||
fps: int = 100
|
||||
episode_length: int = 100
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
################# args from hilserlrobotenv
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
robot_config: Optional[RobotConfig] = None
|
||||
teleop_config: Optional[TeleoperatorConfig] = None
|
||||
wrapper: Optional[EnvTransformConfig] = None
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
############################
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"use_viewer": self.use_viewer,
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
elif env_type == "hil":
|
||||
return HILEnvConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
|
||||
@@ -47,6 +47,10 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
|
||||
# This is the case for human-in-the-loop RL where there is only one environment.
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
@@ -62,13 +66,18 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
return_observations[imgkey] = img
|
||||
|
||||
if "environment_state" in observations:
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
observations["environment_state"]
|
||||
).float()
|
||||
env_state = torch.from_numpy(observations["environment_state"]).float()
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
|
||||
return_observations["observation.environment_state"] = env_state
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations["observation.state"] = agent_pos
|
||||
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
483
lerobot/common/model/kinematics.py
Normal file
483
lerobot/common/model/kinematics.py
Normal file
@@ -0,0 +1,483 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
|
||||
def skew_symmetric(w: NDArray[np.float32]) -> NDArray[np.float32]:
|
||||
"""Creates the skew-symmetric matrix from a 3D vector."""
|
||||
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
|
||||
|
||||
|
||||
def rodrigues_rotation(w: NDArray[np.float32], theta: float) -> NDArray[np.float32]:
|
||||
"""Computes the rotation matrix using Rodrigues' formula."""
|
||||
w_hat = skew_symmetric(w)
|
||||
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||
|
||||
|
||||
def screw_axis_to_transform(s: NDArray[np.float32], theta: float) -> NDArray[np.float32]:
|
||||
"""Converts a screw axis to a 4x4 transformation matrix."""
|
||||
screw_axis_rot = s[:3]
|
||||
screw_axis_trans = s[3:]
|
||||
|
||||
# Pure translation
|
||||
if np.allclose(screw_axis_rot, 0) and np.linalg.norm(screw_axis_trans) == 1:
|
||||
transform = np.eye(4)
|
||||
transform[:3, 3] = screw_axis_trans * theta
|
||||
|
||||
# Rotation (and potentially translation)
|
||||
elif np.linalg.norm(screw_axis_rot) == 1:
|
||||
w_hat = skew_symmetric(screw_axis_rot)
|
||||
rot_mat = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||
t = (
|
||||
np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat
|
||||
) @ screw_axis_trans
|
||||
transform = np.eye(4)
|
||||
transform[:3, :3] = rot_mat
|
||||
transform[:3, 3] = t
|
||||
else:
|
||||
raise ValueError("Invalid screw axis parameters")
|
||||
return transform
|
||||
|
||||
|
||||
def pose_difference_se3(pose1: NDArray[np.float32], pose2: NDArray[np.float32]) -> NDArray[np.float32]:
|
||||
"""
|
||||
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
|
||||
SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space,
|
||||
combining rotation (SO(3)) and translation.
|
||||
|
||||
Each 4x4 matrix has the following structure:
|
||||
[R11 R12 R13 tx]
|
||||
[R21 R22 R23 ty]
|
||||
[R31 R32 R33 tz]
|
||||
[ 0 0 0 1]
|
||||
|
||||
where R is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector.
|
||||
|
||||
Args:
|
||||
pose1: A 4x4 numpy array representing the first pose.
|
||||
pose2: A 4x4 numpy array representing the second pose.
|
||||
|
||||
Returns:
|
||||
A 6D numpy array concatenating translation and rotation differences.
|
||||
First 3 elements are the translational difference (position).
|
||||
Last 3 elements are the rotational difference in axis-angle representation.
|
||||
"""
|
||||
rot1 = pose1[:3, :3]
|
||||
rot2 = pose2[:3, :3]
|
||||
|
||||
translation_diff = pose1[:3, 3] - pose2[:3, 3]
|
||||
|
||||
# Calculate rotational difference using scipy's Rotation library
|
||||
rot_diff = Rotation.from_matrix(rot1 @ rot2.T)
|
||||
rotation_diff = rot_diff.as_rotvec() # Axis-angle representation
|
||||
|
||||
return np.concatenate([translation_diff, rotation_diff])
|
||||
|
||||
|
||||
def se3_error(target_pose: NDArray[np.float32], current_pose: NDArray[np.float32]) -> NDArray[np.float32]:
|
||||
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
|
||||
|
||||
rot_target = target_pose[:3, :3]
|
||||
rot_current = current_pose[:3, :3]
|
||||
rot_error_mat = rot_target @ rot_current.T
|
||||
rot_error = Rotation.from_matrix(rot_error_mat).as_rotvec()
|
||||
|
||||
return np.concatenate([pos_error, rot_error])
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
"""Robot kinematics class supporting multiple robot models."""
|
||||
|
||||
# Robot measurements dictionary
|
||||
ROBOT_MEASUREMENTS = {
|
||||
"koch": {
|
||||
"gripper": [0.239, -0.001, 0.024],
|
||||
"wrist": [0.209, 0, 0.024],
|
||||
"forearm": [0.108, 0, 0.02],
|
||||
"humerus": [0, 0, 0.036],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"moss": {
|
||||
"gripper": [0.246, 0.013, 0.111],
|
||||
"wrist": [0.245, 0.002, 0.064],
|
||||
"forearm": [0.122, 0, 0.064],
|
||||
"humerus": [0.001, 0.001, 0.063],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"so_old_calibration": {
|
||||
"gripper": [0.320, 0, 0.050],
|
||||
"wrist": [0.278, 0, 0.050],
|
||||
"forearm": [0.143, 0, 0.044],
|
||||
"humerus": [0.031, 0, 0.072],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"so_new_calibration": {
|
||||
"gripper": [0.33, 0.0, 0.285],
|
||||
"wrist": [0.30, 0.0, 0.267],
|
||||
"forearm": [0.25, 0.0, 0.266],
|
||||
"humerus": [0.06, 0.0, 0.264],
|
||||
"shoulder": [0.0, 0.0, 0.238],
|
||||
"base": [0.0, 0.0, 0.12],
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, robot_type: str = "so100"):
|
||||
"""Initialize kinematics for the specified robot type.
|
||||
|
||||
Args:
|
||||
robot_type: String specifying the robot model ("koch", "so100", or "moss")
|
||||
"""
|
||||
if robot_type not in self.ROBOT_MEASUREMENTS:
|
||||
raise ValueError(
|
||||
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
|
||||
)
|
||||
|
||||
self.robot_type = robot_type
|
||||
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
|
||||
|
||||
# Initialize all transformation matrices and screw axes
|
||||
self._setup_transforms()
|
||||
|
||||
def _create_translation_matrix(
|
||||
self, x: float = 0.0, y: float = 0.0, z: float = 0.0
|
||||
) -> NDArray[np.float32]:
|
||||
"""Create a 4x4 translation matrix."""
|
||||
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
|
||||
|
||||
def _setup_transforms(self):
|
||||
"""Setup all transformation matrices and screw axes for the robot."""
|
||||
# Set up rotation matrices (constant across robot types)
|
||||
|
||||
# Gripper orientation
|
||||
self.gripper_X0 = np.array(
|
||||
[
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, -1, 0, 0],
|
||||
[0, 0, 0, 1],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Wrist orientation
|
||||
self.wrist_X0 = np.array(
|
||||
[
|
||||
[0, -1, 0, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Base orientation
|
||||
self.base_X0 = np.array(
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, 0],
|
||||
[0, 0, 0, 1],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Gripper
|
||||
# Screw axis of gripper frame wrt base frame
|
||||
self.S_BG = np.array(
|
||||
[
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
self.measurements["gripper"][2],
|
||||
-self.measurements["gripper"][1],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Gripper origin to centroid transform
|
||||
self.X_GoGc = self._create_translation_matrix(x=0.07)
|
||||
|
||||
# Gripper origin to tip transform
|
||||
self.X_GoGt = self._create_translation_matrix(x=0.12)
|
||||
|
||||
# 0-position gripper frame pose wrt base
|
||||
self.X_BoGo = self._create_translation_matrix(
|
||||
x=self.measurements["gripper"][0],
|
||||
y=self.measurements["gripper"][1],
|
||||
z=self.measurements["gripper"][2],
|
||||
)
|
||||
|
||||
# Wrist
|
||||
# Screw axis of wrist frame wrt base frame
|
||||
self.S_BR = np.array(
|
||||
[0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]], dtype=np.float32
|
||||
)
|
||||
|
||||
# 0-position origin to centroid transform
|
||||
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
|
||||
|
||||
# 0-position wrist frame pose wrt base
|
||||
self.X_BR = self._create_translation_matrix(
|
||||
x=self.measurements["wrist"][0],
|
||||
y=self.measurements["wrist"][1],
|
||||
z=self.measurements["wrist"][2],
|
||||
)
|
||||
|
||||
# Forearm
|
||||
# Screw axis of forearm frame wrt base frame
|
||||
self.S_BF = np.array(
|
||||
[
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
-self.measurements["forearm"][2],
|
||||
0,
|
||||
self.measurements["forearm"][0],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Forearm origin + centroid transform
|
||||
self.X_ForearmFc = self._create_translation_matrix(x=0.036)
|
||||
|
||||
# 0-position forearm frame pose wrt base
|
||||
self.X_BF = self._create_translation_matrix(
|
||||
x=self.measurements["forearm"][0],
|
||||
y=self.measurements["forearm"][1],
|
||||
z=self.measurements["forearm"][2],
|
||||
)
|
||||
|
||||
# Humerus
|
||||
# Screw axis of humerus frame wrt base frame
|
||||
self.S_BH = np.array(
|
||||
[
|
||||
0,
|
||||
-1,
|
||||
0,
|
||||
self.measurements["humerus"][2],
|
||||
0,
|
||||
-self.measurements["humerus"][0],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Humerus origin to centroid transform
|
||||
self.X_HoHc = self._create_translation_matrix(x=0.0475)
|
||||
|
||||
# 0-position humerus frame pose wrt base
|
||||
self.X_BH = self._create_translation_matrix(
|
||||
x=self.measurements["humerus"][0],
|
||||
y=self.measurements["humerus"][1],
|
||||
z=self.measurements["humerus"][2],
|
||||
)
|
||||
|
||||
# Shoulder
|
||||
# Screw axis of shoulder frame wrt Base frame
|
||||
self.S_BS = np.array([0, 0, -1, 0, 0, 0], dtype=np.float32)
|
||||
|
||||
# Shoulder origin to centroid transform
|
||||
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
|
||||
|
||||
# 0-position shoulder frame pose wrt base
|
||||
self.X_BS = self._create_translation_matrix(
|
||||
x=self.measurements["shoulder"][0],
|
||||
y=self.measurements["shoulder"][1],
|
||||
z=self.measurements["shoulder"][2],
|
||||
)
|
||||
|
||||
# Base
|
||||
# Base origin to centroid transform
|
||||
self.X_BoBc = self._create_translation_matrix(y=0.015)
|
||||
|
||||
# World to base transform
|
||||
self.X_WoBo = self._create_translation_matrix(
|
||||
x=self.measurements["base"][0],
|
||||
y=self.measurements["base"][1],
|
||||
z=self.measurements["base"][2],
|
||||
)
|
||||
|
||||
# Pre-compute gripper post-multiplication matrix
|
||||
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
|
||||
|
||||
def forward_kinematics(
|
||||
self,
|
||||
robot_pos_deg: NDArray[np.float32],
|
||||
frame: str = "gripper_tip",
|
||||
) -> NDArray[np.float32]:
|
||||
"""Generic forward kinematics.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Joint positions in degrees. Can be ``None`` when
|
||||
computing the *base* frame as it does not depend on joint
|
||||
angles.
|
||||
frame: Target frame. One of
|
||||
``{"base", "shoulder", "humerus", "forearm", "wrist", "gripper", "gripper_tip"}``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
NDArray[np.float32]
|
||||
4×4 homogeneous transformation matrix of the requested frame
|
||||
expressed in the world coordinate system.
|
||||
"""
|
||||
frame = frame.lower()
|
||||
if frame not in {
|
||||
"base",
|
||||
"shoulder",
|
||||
"humerus",
|
||||
"forearm",
|
||||
"wrist",
|
||||
"gripper",
|
||||
"gripper_tip",
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Unknown frame '{frame}'. Valid options are base, shoulder, humerus, forearm, wrist, gripper, gripper_tip."
|
||||
)
|
||||
|
||||
# Base frame does not rely on joint angles.
|
||||
if frame == "base":
|
||||
return self.X_WoBo @ self.X_BoBc @ self.base_X0
|
||||
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
|
||||
# Extract joint angles (note the sign convention for shoulder lift).
|
||||
theta_shoulder_pan = robot_pos_rad[0]
|
||||
theta_shoulder_lift = -robot_pos_rad[1]
|
||||
theta_elbow_flex = robot_pos_rad[2]
|
||||
theta_wrist_flex = robot_pos_rad[3]
|
||||
theta_wrist_roll = robot_pos_rad[4]
|
||||
|
||||
# Start with the world-to-base transform; incrementally add successive links.
|
||||
transformation_matrix = self.X_WoBo @ screw_axis_to_transform(self.S_BS, theta_shoulder_pan)
|
||||
if frame == "shoulder":
|
||||
return transformation_matrix @ self.X_SoSc @ self.X_BS
|
||||
|
||||
transformation_matrix = transformation_matrix @ screw_axis_to_transform(
|
||||
self.S_BH, theta_shoulder_lift
|
||||
)
|
||||
if frame == "humerus":
|
||||
return transformation_matrix @ self.X_HoHc @ self.X_BH
|
||||
|
||||
transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BF, theta_elbow_flex)
|
||||
if frame == "forearm":
|
||||
return transformation_matrix @ self.X_ForearmFc @ self.X_BF
|
||||
|
||||
transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BR, theta_wrist_flex)
|
||||
if frame == "wrist":
|
||||
return transformation_matrix @ self.X_RoRc @ self.X_BR @ self.wrist_X0
|
||||
|
||||
transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BG, theta_wrist_roll)
|
||||
if frame == "gripper":
|
||||
return transformation_matrix @ self._fk_gripper_post
|
||||
else: # frame == "gripper_tip"
|
||||
return transformation_matrix @ self.X_GoGt @ self.X_BoGo @ self.gripper_X0
|
||||
|
||||
def compute_jacobian(
|
||||
self, robot_pos_deg: NDArray[np.float32], frame: str = "gripper_tip"
|
||||
) -> NDArray[np.float32]:
|
||||
"""Finite differences to compute the Jacobian.
|
||||
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
|
||||
in the jth joint's velocity.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Current joint positions in degrees
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
"""
|
||||
|
||||
eps = 1e-8
|
||||
jac = np.zeros(shape=(6, 5))
|
||||
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||
delta *= 0
|
||||
delta[el_ix] = eps / 2
|
||||
sdot = (
|
||||
pose_difference_se3(
|
||||
self.forward_kinematics(robot_pos_deg[:-1] + delta, frame),
|
||||
self.forward_kinematics(robot_pos_deg[:-1] - delta, frame),
|
||||
)
|
||||
/ eps
|
||||
)
|
||||
jac[:, el_ix] = sdot
|
||||
return jac
|
||||
|
||||
def compute_positional_jacobian(
|
||||
self, robot_pos_deg: NDArray[np.float32], frame: str = "gripper_tip"
|
||||
) -> NDArray[np.float32]:
|
||||
"""Finite differences to compute the positional Jacobian.
|
||||
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
|
||||
in the jth joint's velocity.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Current joint positions in degrees
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
"""
|
||||
eps = 1e-8
|
||||
jac = np.zeros(shape=(3, 5))
|
||||
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||
delta *= 0
|
||||
delta[el_ix] = eps / 2
|
||||
sdot = (
|
||||
self.forward_kinematics(robot_pos_deg[:-1] + delta, frame)[:3, 3]
|
||||
- self.forward_kinematics(robot_pos_deg[:-1] - delta, frame)[:3, 3]
|
||||
) / eps
|
||||
jac[:, el_ix] = sdot
|
||||
return jac
|
||||
|
||||
def ik(
|
||||
self,
|
||||
current_joint_pos: NDArray[np.float32],
|
||||
desired_ee_pose: NDArray[np.float32],
|
||||
position_only: bool = True,
|
||||
frame: str = "gripper_tip",
|
||||
max_iterations: int = 5,
|
||||
learning_rate: float = 1,
|
||||
) -> NDArray[np.float32]:
|
||||
"""Inverse kinematics using gradient descent.
|
||||
|
||||
Args:
|
||||
current_joint_state: Initial joint positions in degrees
|
||||
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
|
||||
position_only: If True, only match end-effector position, not orientation
|
||||
frame: Target frame. One of
|
||||
``{"base", "shoulder", "humerus", "forearm", "wrist", "gripper", "gripper_tip"}``.
|
||||
max_iterations: Maximum number of iterations to run
|
||||
learning_rate: Learning rate for gradient descent
|
||||
|
||||
Returns:
|
||||
Joint positions in degrees that achieve the desired end-effector pose
|
||||
"""
|
||||
# Do gradient descent.
|
||||
current_joint_state = current_joint_pos.copy()
|
||||
for _ in range(max_iterations):
|
||||
current_ee_pose = self.forward_kinematics(current_joint_state, frame)
|
||||
if not position_only:
|
||||
error = se3_error(desired_ee_pose, current_ee_pose)
|
||||
jac = self.compute_jacobian(current_joint_state, frame)
|
||||
else:
|
||||
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
|
||||
jac = self.compute_positional_jacobian(current_joint_state, frame)
|
||||
delta_angles = np.linalg.pinv(jac) @ error
|
||||
current_joint_state[:-1] += learning_rate * delta_angles
|
||||
|
||||
if np.linalg.norm(error) < 5e-3:
|
||||
return current_joint_state
|
||||
return current_joint_state
|
||||
@@ -14,8 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
@@ -44,7 +45,16 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self) -> torch.optim.Optimizer:
|
||||
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""
|
||||
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
|
||||
NOTE: Multiple optimizers are useful when you have different models to optimize.
|
||||
For example, you can have one optimizer for the policy and another one for the value function
|
||||
in reinforcement learning settings.
|
||||
|
||||
Returns:
|
||||
The optimizer or a dictionary of optimizers.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -94,7 +104,76 @@ class SGDConfig(OptimizerConfig):
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
@OptimizerConfig.register_subclass("multi_adam")
|
||||
@dataclass
|
||||
class MultiAdamConfig(OptimizerConfig):
|
||||
"""Configuration for multiple Adam optimizers with different parameter groups.
|
||||
|
||||
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
|
||||
|
||||
Args:
|
||||
lr: Default learning rate (used if not specified for a group)
|
||||
weight_decay: Default weight decay (used if not specified for a group)
|
||||
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
|
||||
grad_clip_norm: Gradient clipping norm
|
||||
"""
|
||||
|
||||
lr: float = 1e-3
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Build multiple Adam optimizers.
|
||||
|
||||
Args:
|
||||
params_dict: Dictionary mapping parameter group names to lists of parameters
|
||||
The keys should match the keys in optimizer_groups
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter group names to their optimizers
|
||||
"""
|
||||
optimizers = {}
|
||||
|
||||
for name, params in params_dict.items():
|
||||
# Get group-specific hyperparameters or use defaults
|
||||
group_config = self.optimizer_groups.get(name, {})
|
||||
|
||||
# Create optimizer with merged parameters (defaults + group-specific)
|
||||
optimizer_kwargs = {
|
||||
"lr": group_config.get("lr", self.lr),
|
||||
"betas": group_config.get("betas", (0.9, 0.999)),
|
||||
"eps": group_config.get("eps", 1e-5),
|
||||
"weight_decay": group_config.get("weight_decay", self.weight_decay),
|
||||
}
|
||||
|
||||
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def save_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> None:
|
||||
"""Save optimizer state to disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to save the optimizer state.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
optimizer_dir.mkdir(exist_ok=True, parents=True)
|
||||
_save_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
# Handle single optimizer
|
||||
_save_single_optimizer_state(optimizer, save_dir)
|
||||
|
||||
|
||||
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
"""Save a single optimizer's state to disk."""
|
||||
state = optimizer.state_dict()
|
||||
param_groups = state.pop("param_groups")
|
||||
flat_state = flatten_dict(state)
|
||||
@@ -102,11 +181,44 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
def load_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""Load optimizer state from disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to load the optimizer state from.
|
||||
|
||||
Returns:
|
||||
The updated optimizer(s) with loaded state.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
loaded_optimizers = {}
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
if optimizer_dir.exists():
|
||||
loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
loaded_optimizers[name] = opt
|
||||
return loaded_optimizers
|
||||
else:
|
||||
# Handle single optimizer
|
||||
return _load_single_optimizer_state(optimizer, save_dir)
|
||||
|
||||
|
||||
def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
"""Load a single optimizer's state from disk."""
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
|
||||
# Handle case where 'state' key might not exist (for newly created optimizers)
|
||||
if "state" in state:
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
else:
|
||||
loaded_state_dict = {"state": {}}
|
||||
|
||||
if "param_groups" in current_state_dict:
|
||||
param_groups = deserialize_json_into_object(
|
||||
|
||||
@@ -27,6 +27,8 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
@@ -60,6 +62,14 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy
|
||||
elif name == "reward_classifier":
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
elif name == "smolvla":
|
||||
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
@@ -81,8 +91,12 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -151,6 +151,7 @@ class Normalize(nn.Module):
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# TODO: Remove this shallow copy
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
@@ -252,3 +253,168 @@ class Unnormalize(nn.Module):
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
|
||||
module: nn.Module,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> None:
|
||||
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
||||
|
||||
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
|
||||
but is factored out so it can be reused by both classes and stay in sync.
|
||||
"""
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
shape: tuple[int, ...] = tuple(ft.shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# reduce spatial dimensions, keep channel dimension only
|
||||
c, *_ = shape
|
||||
shape = (c, 1, 1)
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
std = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||
mean_data = stats[key]["mean"]
|
||||
std_data = stats[key]["std"]
|
||||
if isinstance(mean_data, torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
mean = mean_data.clone().to(dtype=torch.float32)
|
||||
std = std_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_mean", mean)
|
||||
module.register_buffer(f"{prefix}_std", std)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||
min_data = stats[key]["min"]
|
||||
max_data = stats[key]["max"]
|
||||
if isinstance(min_data, torch.Tensor):
|
||||
min_val = min_data.clone().to(dtype=torch.float32)
|
||||
max_val = max_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_min", min_val)
|
||||
module.register_buffer(f"{prefix}_max", max_val)
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
|
||||
class NormalizeBuffer(nn.Module):
|
||||
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class UnnormalizeBuffer(nn.Module):
|
||||
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max_val - min_val) + min_val
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
245
lerobot/common/policies/sac/configuration_sac.py
Normal file
245
lerobot/common/policies/sac/configuration_sac.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
def is_image_feature(key: str) -> bool:
|
||||
"""Check if a feature key represents an image feature.
|
||||
|
||||
Args:
|
||||
key: The feature key to check
|
||||
|
||||
Returns:
|
||||
True if the key represents an image feature, False otherwise
|
||||
"""
|
||||
return key.startswith(OBS_IMAGE)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConcurrencyConfig:
|
||||
"""Configuration for the concurrency of the actor and learner.
|
||||
Possible values are:
|
||||
- "threads": Use threads for the actor and learner.
|
||||
- "processes": Use processes for the actor and learner.
|
||||
"""
|
||||
|
||||
actor: str = "threads"
|
||||
learner: str = "threads"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorLearnerConfig:
|
||||
learner_host: str = "127.0.0.1"
|
||||
learner_port: int = 50051
|
||||
policy_parameters_push_frequency: int = 4
|
||||
queue_get_timeout: float = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class CriticNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
final_activation: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyConfig:
|
||||
use_tanh_squash: bool = True
|
||||
std_min: float = 1e-5
|
||||
std_max: float = 10.0
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
"""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# Statistics for normalizing different types of inputs
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
OBS_IMAGE: {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
OBS_STATE: {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
ACTION: {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture specifics
|
||||
# Device to run the model on (e.g., "cuda", "cpu")
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
# Hidden dimension size for the image encoder
|
||||
image_encoder_hidden_dim: int = 32
|
||||
# Whether to use a shared encoder for actor and critic
|
||||
shared_encoder: bool = True
|
||||
# Number of discrete actions, eg for gripper actions
|
||||
num_discrete_actions: int | None = None
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Seed for the online environment
|
||||
online_env_seed: int = 10000
|
||||
# Capacity of the online replay buffer
|
||||
online_buffer_capacity: int = 100000
|
||||
# Capacity of the offline replay buffer
|
||||
offline_buffer_capacity: int = 100000
|
||||
# Whether to use asynchronous prefetching for the buffers
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
# Discount factor for the SAC algorithm
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for actor-learner architecture
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
},
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
has_image = any(is_image_feature(key) for key in self.input_features)
|
||||
has_state = OBS_STATE in self.input_features
|
||||
|
||||
if not (has_state or has_image):
|
||||
raise ValueError(
|
||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def image_features(self) -> list[str]:
|
||||
return [key for key in self.input_features if is_image_feature(key)]
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return None # SAC typically predicts one action at a time
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
1111
lerobot/common/policies/sac/modeling_sac.py
Normal file
1111
lerobot/common/policies/sac/modeling_sac.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,76 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||
@dataclass
|
||||
class RewardClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Reward Classifier model."""
|
||||
|
||||
name: str = "reward_classifier"
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
learning_rate: float = 1e-4
|
||||
weight_decay: float = 0.01
|
||||
grad_clip_norm: float = 1.0
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.learning_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
grad_clip_norm=self.grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate feature configurations."""
|
||||
has_image = any(key.startswith("observation.image") for key in self.input_features)
|
||||
if not has_image:
|
||||
raise ValueError(
|
||||
"You must provide an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
316
lerobot/common/policies/sac/reward_model/modeling_classifier.py
Normal file
316
lerobot/common/policies/sac/reward_model/modeling_classifier.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_IMAGE, REWARD
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logits: Tensor,
|
||||
probabilities: Tensor | None = None,
|
||||
hidden_states: Tensor | None = None,
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class SpatialLearnedEmbeddings(nn.Module):
|
||||
def __init__(self, height, width, channel, num_features=8):
|
||||
"""
|
||||
PyTorch implementation of learned spatial embeddings
|
||||
|
||||
Args:
|
||||
height: Spatial height of input features
|
||||
width: Spatial width of input features
|
||||
channel: Number of input channels
|
||||
num_features: Number of output embedding dimensions
|
||||
"""
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channel = channel
|
||||
self.num_features = num_features
|
||||
|
||||
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
|
||||
|
||||
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
|
||||
|
||||
def forward(self, features):
|
||||
"""
|
||||
Forward pass for spatial embedding
|
||||
|
||||
Args:
|
||||
features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch
|
||||
Returns:
|
||||
Output tensor of shape [B, C*F] or [C*F] if no batch
|
||||
"""
|
||||
|
||||
features = features.last_hidden_state
|
||||
|
||||
original_shape = features.shape
|
||||
if features.dim() == 3:
|
||||
features = features.unsqueeze(0) # Add batch dim
|
||||
|
||||
features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1]
|
||||
kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F]
|
||||
|
||||
# Element-wise multiplication and spatial reduction
|
||||
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W
|
||||
|
||||
# Reshape to combine channel and feature dimensions
|
||||
output = output.view(output.size(0), -1) # [B, C*F]
|
||||
|
||||
# Remove batch dim
|
||||
if len(original_shape) == 3:
|
||||
output = output.squeeze(0)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Classifier(PreTrainedPolicy):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "reward_classifier"
|
||||
config_class = RewardClassifierConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize normalization (standardized with the policy framework)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Set up encoder
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
|
||||
# Extract image keys from input_features
|
||||
self.image_keys = [
|
||||
key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE)
|
||||
]
|
||||
|
||||
if self.is_cnn:
|
||||
self.encoders = nn.ModuleDict()
|
||||
for image_key in self.image_keys:
|
||||
encoder = self._create_single_encoder()
|
||||
self.encoders[image_key] = encoder
|
||||
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _create_single_encoder(self):
|
||||
encoder = nn.Sequential(
|
||||
self.encoder,
|
||||
SpatialLearnedEmbeddings(
|
||||
height=4,
|
||||
width=4,
|
||||
channel=self.feature_dim,
|
||||
num_features=self.config.image_embedding_pooling_dim,
|
||||
),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
return encoder
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.config.latent_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(
|
||||
self.config.hidden_dim,
|
||||
1 if self.config.num_classes == 2 else self.config.num_classes,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoders[image_key](x)
|
||||
return outputs
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(x)
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def extract_images_and_labels(self, batch: dict[str, Tensor]) -> tuple[list, Tensor]:
|
||||
"""Extract image tensors and label tensors from batch."""
|
||||
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
labels = batch[REWARD]
|
||||
|
||||
return images, labels
|
||||
|
||||
def predict(self, xs: list) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier for inference."""
|
||||
encoder_outputs = torch.hstack(
|
||||
[self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)]
|
||||
)
|
||||
logits = self.classifier_head(encoder_outputs)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Normalize inputs if needed
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images and labels
|
||||
images, labels = self.extract_images_and_labels(batch)
|
||||
|
||||
# Get predictions
|
||||
outputs = self.predict(images)
|
||||
|
||||
# Calculate loss
|
||||
if self.config.num_classes == 2:
|
||||
# Binary classification
|
||||
loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
# Multi-class classification
|
||||
loss = nn.functional.cross_entropy(outputs.logits, labels.long())
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
|
||||
# Calculate accuracy for logging
|
||||
correct = (predictions == labels).sum().item()
|
||||
total = labels.size(0)
|
||||
accuracy = 100 * correct / total
|
||||
|
||||
# Return loss and metrics for logging
|
||||
output_dict = {
|
||||
"accuracy": accuracy,
|
||||
"correct": correct,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def predict_reward(self, batch, threshold=0.5):
|
||||
"""Eval method. Returns predicted reward with the decision threshold as argument."""
|
||||
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images from batch dict
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
probs = self.predict(images).probabilities
|
||||
logging.debug(f"Predicted reward images: {probs}")
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.predict(images).probabilities, dim=1)
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Return optimizer parameters for the policy."""
|
||||
return self.parameters()
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
raise NotImplementedError("Reward classifiers do not select actions")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
pass
|
||||
@@ -1,2 +1,3 @@
|
||||
from .config_so100_follower import SO100FollowerConfig
|
||||
from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig
|
||||
from .so100_follower import SO100Follower
|
||||
from .so100_follower_end_effector import SO100FollowerEndEffector
|
||||
|
||||
@@ -37,3 +37,27 @@ class SO100FollowerConfig(RobotConfig):
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so100_follower_end_effector")
|
||||
@dataclass
|
||||
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
|
||||
"""Configuration for the SO100FollowerEndEffector robot."""
|
||||
|
||||
# Default bounds for the end-effector position (in meters)
|
||||
end_effector_bounds: dict[str, list[float]] = field(
|
||||
default_factory=lambda: {
|
||||
"min": [-1.0, -1.0, -1.0], # min x, y, z
|
||||
"max": [1.0, 1.0, 1.0], # max x, y, z
|
||||
}
|
||||
)
|
||||
|
||||
max_gripper_pos: float = 50
|
||||
|
||||
end_effector_step_sizes: dict[str, float] = field(
|
||||
default_factory=lambda: {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.cameras import make_cameras_from_configs
|
||||
from lerobot.common.errors import DeviceNotConnectedError
|
||||
from lerobot.common.model.kinematics import RobotKinematics
|
||||
from lerobot.common.motors import Motor, MotorNormMode
|
||||
from lerobot.common.motors.feetech import FeetechMotorsBus
|
||||
|
||||
from . import SO100Follower
|
||||
from .config_so100_follower import SO100FollowerEndEffectorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
EE_FRAME = "gripper_tip"
|
||||
|
||||
|
||||
class SO100FollowerEndEffector(SO100Follower):
|
||||
"""
|
||||
SO100Follower robot with end-effector space control.
|
||||
|
||||
This robot inherits from SO100Follower but transforms actions from
|
||||
end-effector space to joint space before sending them to the motors.
|
||||
"""
|
||||
|
||||
config_class = SO100FollowerEndEffectorConfig
|
||||
name = "so100_follower_end_effector"
|
||||
|
||||
def __init__(self, config: SO100FollowerEndEffectorConfig):
|
||||
super().__init__(config)
|
||||
self.bus = FeetechMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES),
|
||||
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES),
|
||||
"elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES),
|
||||
"wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES),
|
||||
"wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES),
|
||||
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
calibration=self.calibration,
|
||||
)
|
||||
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
self.config = config
|
||||
|
||||
# Initialize the kinematics module for the so100 robot
|
||||
self.kinematics = RobotKinematics(robot_type="so_new_calibration")
|
||||
|
||||
# Store the bounds for end-effector position
|
||||
self.end_effector_bounds = self.config.end_effector_bounds
|
||||
|
||||
self.current_ee_pos = None
|
||||
self.current_joint_pos = None
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, Any]:
|
||||
"""
|
||||
Define action features for end-effector control.
|
||||
Returns dictionary with dtype, shape, and names.
|
||||
"""
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
|
||||
}
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Transform action from end-effector space to joint space and send to motors.
|
||||
|
||||
Args:
|
||||
action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control
|
||||
or a numpy array with [delta_x, delta_y, delta_z]
|
||||
|
||||
Returns:
|
||||
The joint-space action that was sent to the motors
|
||||
"""
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Convert action to numpy array if not already
|
||||
if isinstance(action, dict):
|
||||
if all(k in action for k in ["delta_x", "delta_y", "delta_z"]):
|
||||
delta_ee = np.array(
|
||||
[
|
||||
action["delta_x"] * self.config.end_effector_step_sizes["x"],
|
||||
action["delta_y"] * self.config.end_effector_step_sizes["y"],
|
||||
action["delta_z"] * self.config.end_effector_step_sizes["z"],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
if "gripper" not in action:
|
||||
action["gripper"] = [1.0]
|
||||
action = np.append(delta_ee, action["gripper"])
|
||||
else:
|
||||
logger.warning(
|
||||
f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}"
|
||||
)
|
||||
action = np.zeros(4, dtype=np.float32)
|
||||
|
||||
if self.current_joint_pos is None:
|
||||
# Read current joint positions
|
||||
current_joint_pos = self.bus.sync_read("Present_Position")
|
||||
self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors])
|
||||
|
||||
# Calculate current end-effector position using forward kinematics
|
||||
if self.current_ee_pos is None:
|
||||
self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos, frame=EE_FRAME)
|
||||
|
||||
# Set desired end-effector position by adding delta
|
||||
desired_ee_pos = np.eye(4)
|
||||
desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation
|
||||
|
||||
# Add delta to position and clip to bounds
|
||||
desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3]
|
||||
if self.end_effector_bounds is not None:
|
||||
desired_ee_pos[:3, 3] = np.clip(
|
||||
desired_ee_pos[:3, 3],
|
||||
self.end_effector_bounds["min"],
|
||||
self.end_effector_bounds["max"],
|
||||
)
|
||||
|
||||
# Compute inverse kinematics to get joint positions
|
||||
target_joint_values_in_degrees = self.kinematics.ik(
|
||||
self.current_joint_pos, desired_ee_pos, position_only=True, frame=EE_FRAME
|
||||
)
|
||||
|
||||
target_joint_values_in_degrees = np.clip(target_joint_values_in_degrees, -180.0, 180.0)
|
||||
# Create joint space action dictionary
|
||||
joint_action = {
|
||||
f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
|
||||
}
|
||||
|
||||
# Handle gripper separately if included in action
|
||||
# Gripper delta action is in the range 0 - 2,
|
||||
# We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos
|
||||
joint_action["gripper.pos"] = np.clip(
|
||||
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
|
||||
5,
|
||||
self.config.max_gripper_pos,
|
||||
)
|
||||
|
||||
self.current_ee_pos = desired_ee_pos.copy()
|
||||
self.current_joint_pos = target_joint_values_in_degrees.copy()
|
||||
self.current_joint_pos[-1] = joint_action["gripper.pos"]
|
||||
|
||||
# Send joint space action to parent class
|
||||
return super().send_action(joint_action)
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def reset(self):
|
||||
self.current_ee_pos = None
|
||||
self.current_joint_pos = None
|
||||
@@ -29,6 +29,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .so100_follower import SO100Follower
|
||||
|
||||
return SO100Follower(config)
|
||||
elif config.type == "so100_follower_end_effector":
|
||||
from .so100_follower import SO100FollowerEndEffector
|
||||
|
||||
return SO100FollowerEndEffector(config)
|
||||
elif config.type == "so101_follower":
|
||||
from .so101_follower import SO101Follower
|
||||
|
||||
|
||||
18
lerobot/common/teleoperators/gamepad/__init__.py
Normal file
18
lerobot/common/teleoperators/gamepad/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_gamepad import GamepadTeleopConfig
|
||||
from .teleop_gamepad import GamepadTeleop
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("gamepad")
|
||||
@dataclass
|
||||
class GamepadTeleopConfig(TeleoperatorConfig):
|
||||
use_gripper: bool = True
|
||||
480
lerobot/common/teleoperators/gamepad/gamepad_utils.py
Normal file
480
lerobot/common/teleoperators/gamepad/gamepad_utils.py
Normal file
@@ -0,0 +1,480 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
|
||||
"""
|
||||
Initialize the controller.
|
||||
|
||||
Args:
|
||||
x_step_size: Base movement step size in meters
|
||||
y_step_size: Base movement step size in meters
|
||||
z_step_size: Base movement step size in meters
|
||||
"""
|
||||
self.x_step_size = x_step_size
|
||||
self.y_step_size = y_step_size
|
||||
self.z_step_size = z_step_size
|
||||
self.running = True
|
||||
self.episode_end_status = None # None, "success", or "failure"
|
||||
self.intervention_flag = False
|
||||
self.open_gripper_command = False
|
||||
self.close_gripper_command = False
|
||||
|
||||
def start(self):
|
||||
"""Start the controller and initialize resources."""
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""Stop the controller and release resources."""
|
||||
pass
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas (dx, dy, dz) in meters."""
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if the user has requested to quit."""
|
||||
return not self.running
|
||||
|
||||
def update(self):
|
||||
"""Update controller state - call this once per frame."""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for use in 'with' statements."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Ensure resources are released when exiting 'with' block."""
|
||||
self.stop()
|
||||
|
||||
def get_episode_end_status(self):
|
||||
"""
|
||||
Get the current episode end status.
|
||||
|
||||
Returns:
|
||||
None if episode should continue, "success" or "failure" otherwise
|
||||
"""
|
||||
status = self.episode_end_status
|
||||
self.episode_end_status = None # Reset after reading
|
||||
return status
|
||||
|
||||
def should_intervene(self):
|
||||
"""Return True if intervention flag was set."""
|
||||
return self.intervention_flag
|
||||
|
||||
def gripper_command(self):
|
||||
"""Return the current gripper command."""
|
||||
if self.open_gripper_command == self.close_gripper_command:
|
||||
return "stay"
|
||||
elif self.open_gripper_command:
|
||||
return "open"
|
||||
elif self.close_gripper_command:
|
||||
return "close"
|
||||
|
||||
|
||||
class KeyboardController(InputController):
|
||||
"""Generate motion deltas from keyboard input."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.key_states = {
|
||||
"forward_x": False,
|
||||
"backward_x": False,
|
||||
"forward_y": False,
|
||||
"backward_y": False,
|
||||
"forward_z": False,
|
||||
"backward_z": False,
|
||||
"quit": False,
|
||||
"success": False,
|
||||
"failure": False,
|
||||
}
|
||||
self.listener = None
|
||||
|
||||
def start(self):
|
||||
"""Start the keyboard listener."""
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = True
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = True
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = True
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["quit"] = True
|
||||
self.running = False
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = "success"
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = "failure"
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def on_release(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = False
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = False
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = False
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = False
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = False
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = False
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
|
||||
self.listener.start()
|
||||
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" Enter: End episode with SUCCESS")
|
||||
print(" Backspace: End episode with FAILURE")
|
||||
print(" ESC: Exit")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
if self.listener and self.listener.is_alive():
|
||||
self.listener.stop()
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from keyboard state."""
|
||||
delta_x = delta_y = delta_z = 0.0
|
||||
|
||||
if self.key_states["forward_x"]:
|
||||
delta_x += self.x_step_size
|
||||
if self.key_states["backward_x"]:
|
||||
delta_x -= self.x_step_size
|
||||
if self.key_states["forward_y"]:
|
||||
delta_y += self.y_step_size
|
||||
if self.key_states["backward_y"]:
|
||||
delta_y -= self.y_step_size
|
||||
if self.key_states["forward_z"]:
|
||||
delta_z += self.z_step_size
|
||||
if self.key_states["backward_z"]:
|
||||
delta_z -= self.z_step_size
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if ESC was pressed."""
|
||||
return self.key_states["quit"]
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if Enter was pressed (save episode)."""
|
||||
return self.key_states["success"] or self.key_states["failure"]
|
||||
|
||||
|
||||
class GamepadController(InputController):
|
||||
"""Generate motion deltas from gamepad input."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.joystick = None
|
||||
self.intervention_flag = False
|
||||
|
||||
def start(self):
|
||||
"""Initialize pygame and the gamepad."""
|
||||
import pygame
|
||||
|
||||
pygame.init()
|
||||
pygame.joystick.init()
|
||||
|
||||
if pygame.joystick.get_count() == 0:
|
||||
logging.error("No gamepad detected. Please connect a gamepad and try again.")
|
||||
self.running = False
|
||||
return
|
||||
|
||||
self.joystick = pygame.joystick.Joystick(0)
|
||||
self.joystick.init()
|
||||
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
||||
|
||||
print("Gamepad controls:")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick (vertical): Move in Z axis")
|
||||
print(" B/Circle button: Exit")
|
||||
print(" Y/Triangle button: End episode with SUCCESS")
|
||||
print(" A/Cross button: End episode with FAILURE")
|
||||
print(" X/Square button: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Clean up pygame resources."""
|
||||
import pygame
|
||||
|
||||
if pygame.joystick.get_init():
|
||||
if self.joystick:
|
||||
self.joystick.quit()
|
||||
pygame.joystick.quit()
|
||||
pygame.quit()
|
||||
|
||||
def update(self):
|
||||
"""Process pygame events to get fresh gamepad readings."""
|
||||
import pygame
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = "success"
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = "failure"
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = True
|
||||
|
||||
# LT button (7) for opening gripper
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = True
|
||||
|
||||
# Reset episode status on button release
|
||||
elif event.type == pygame.JOYBUTTONUP:
|
||||
if event.button in [0, 2, 3]:
|
||||
self.episode_end_status = None
|
||||
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = False
|
||||
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = False
|
||||
|
||||
# Check for RB button (typically button 5) for intervention flag
|
||||
if self.joystick.get_button(5):
|
||||
self.intervention_flag = True
|
||||
else:
|
||||
self.intervention_flag = False
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
import pygame
|
||||
|
||||
try:
|
||||
# Read joystick axes
|
||||
# Left stick X and Y (typically axes 0 and 1)
|
||||
x_input = self.joystick.get_axis(0) # Left/Right
|
||||
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
|
||||
|
||||
# Right stick Y (typically axis 3 or 4)
|
||||
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
||||
|
||||
# Apply deadzone to avoid drift
|
||||
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||
|
||||
# Calculate deltas (note: may need to invert axes depending on controller)
|
||||
delta_x = -y_input * self.y_step_size # Forward/backward
|
||||
delta_y = -x_input * self.x_step_size # Left/right
|
||||
delta_z = -z_input * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
except pygame.error:
|
||||
logging.error("Error reading gamepad. Is it still connected?")
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
|
||||
class GamepadControllerHID(InputController):
|
||||
"""Generate motion deltas from gamepad input using HIDAPI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_step_size=1.0,
|
||||
y_step_size=1.0,
|
||||
z_step_size=1.0,
|
||||
deadzone=0.1,
|
||||
):
|
||||
"""
|
||||
Initialize the HID gamepad controller.
|
||||
|
||||
Args:
|
||||
step_size: Base movement step size in meters
|
||||
z_scale: Scaling factor for Z-axis movement
|
||||
deadzone: Joystick deadzone to prevent drift
|
||||
"""
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.device = None
|
||||
self.device_info = None
|
||||
|
||||
# Movement values (normalized from -1.0 to 1.0)
|
||||
self.left_x = 0.0
|
||||
self.left_y = 0.0
|
||||
self.right_x = 0.0
|
||||
self.right_y = 0.0
|
||||
|
||||
# Button states
|
||||
self.buttons = {}
|
||||
self.quit_requested = False
|
||||
self.save_requested = False
|
||||
|
||||
def find_device(self):
|
||||
"""Look for the gamepad device by vendor and product ID."""
|
||||
import hid
|
||||
|
||||
devices = hid.enumerate()
|
||||
for device in devices:
|
||||
device_name = device["product_string"]
|
||||
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5"]):
|
||||
return device
|
||||
|
||||
logging.error(
|
||||
"No gamepad found, check the connection and the product string in HID to add your gamepad"
|
||||
)
|
||||
return None
|
||||
|
||||
def start(self):
|
||||
"""Connect to the gamepad using HIDAPI."""
|
||||
import hid
|
||||
|
||||
self.device_info = self.find_device()
|
||||
if not self.device_info:
|
||||
self.running = False
|
||||
return
|
||||
|
||||
try:
|
||||
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
|
||||
self.device = hid.device()
|
||||
self.device.open_path(self.device_info["path"])
|
||||
self.device.set_nonblocking(1)
|
||||
|
||||
manufacturer = self.device.get_manufacturer_string()
|
||||
product = self.device.get_product_string()
|
||||
logging.info(f"Connected to {manufacturer} {product}")
|
||||
|
||||
logging.info("Gamepad controls (HID mode):")
|
||||
logging.info(" Left analog stick: Move in X-Y plane")
|
||||
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
||||
logging.info(" Button 1/B/Circle: Exit")
|
||||
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
||||
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error opening gamepad: {e}")
|
||||
logging.error("You might need to run this with sudo/admin privileges on some systems")
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
"""Close the HID device connection."""
|
||||
if self.device:
|
||||
self.device.close()
|
||||
self.device = None
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
Read and process the latest gamepad data.
|
||||
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
|
||||
"""
|
||||
for _ in range(10):
|
||||
self._update()
|
||||
|
||||
def _update(self):
|
||||
"""Read and process the latest gamepad data."""
|
||||
if not self.device or not self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Read data from the gamepad
|
||||
data = self.device.read(64)
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if data and len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
# Apply deadzone
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
|
||||
# Check if RT is pressed
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
|
||||
# Check if LT is pressed
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error reading from gamepad: {e}")
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
# Calculate deltas - invert as needed based on controller orientation
|
||||
delta_x = -self.left_y * self.x_step_size # Forward/backward
|
||||
delta_y = -self.left_x * self.y_step_size # Left/right
|
||||
delta_z = -self.right_y * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if quit button was pressed."""
|
||||
return self.quit_requested
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if save button was pressed."""
|
||||
return self.save_requested
|
||||
138
lerobot/common/teleoperators/gamepad/teleop_gamepad.py
Normal file
138
lerobot/common/teleoperators/gamepad/teleop_gamepad.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .configuration_gamepad import GamepadTeleopConfig
|
||||
|
||||
|
||||
class GripperAction(IntEnum):
|
||||
CLOSE = 0
|
||||
STAY = 1
|
||||
OPEN = 2
|
||||
|
||||
|
||||
gripper_action_map = {
|
||||
"close": GripperAction.CLOSE.value,
|
||||
"open": GripperAction.OPEN.value,
|
||||
"stay": GripperAction.STAY.value,
|
||||
}
|
||||
|
||||
|
||||
class GamepadTeleop(Teleoperator):
|
||||
"""
|
||||
Teleop class to use gamepad inputs for control.
|
||||
"""
|
||||
|
||||
config_class = GamepadTeleopConfig
|
||||
name = "gamepad"
|
||||
|
||||
def __init__(self, config: GamepadTeleopConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.robot_type = config.type
|
||||
|
||||
self.gamepad = None
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
if self.config.use_gripper:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (3,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
|
||||
}
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict:
|
||||
return {}
|
||||
|
||||
def connect(self) -> None:
|
||||
# use HidApi for macos
|
||||
if sys.platform == "darwin":
|
||||
# NOTE: On macOS, pygame doesn’t reliably detect input from some controllers so we fall back to hidapi
|
||||
from .gamepad_utils import GamepadControllerHID as Gamepad
|
||||
else:
|
||||
from .gamepad_utils import GamepadController as Gamepad
|
||||
|
||||
self.gamepad = Gamepad()
|
||||
self.gamepad.start()
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
# Update the controller to get fresh inputs
|
||||
self.gamepad.update()
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = self.gamepad.get_deltas()
|
||||
|
||||
# Create action from gamepad input
|
||||
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
|
||||
|
||||
action_dict = {
|
||||
"delta_x": gamepad_action[0],
|
||||
"delta_y": gamepad_action[1],
|
||||
"delta_z": gamepad_action[2],
|
||||
}
|
||||
|
||||
# Default gripper action is to stay
|
||||
gripper_action = GripperAction.STAY.value
|
||||
if self.config.use_gripper:
|
||||
gripper_command = self.gamepad.gripper_command()
|
||||
gripper_action = gripper_action_map[gripper_command]
|
||||
action_dict["gripper"] = gripper_action
|
||||
|
||||
return action_dict
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the gamepad."""
|
||||
if self.gamepad is not None:
|
||||
self.gamepad.stop()
|
||||
self.gamepad = None
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if gamepad is connected."""
|
||||
return self.gamepad is not None
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""Calibrate the gamepad."""
|
||||
# No calibration needed for gamepad
|
||||
pass
|
||||
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if gamepad is calibrated."""
|
||||
# Gamepad doesn't require calibration
|
||||
return True
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Configure the gamepad."""
|
||||
# No additional configuration needed
|
||||
pass
|
||||
|
||||
def send_feedback(self, feedback: dict) -> None:
|
||||
"""Send feedback to the gamepad."""
|
||||
# Gamepad doesn't support feedback
|
||||
pass
|
||||
@@ -24,3 +24,5 @@ from ..config import TeleoperatorConfig
|
||||
class SO101LeaderConfig(TeleoperatorConfig):
|
||||
# Port to connect to the arm
|
||||
port: str
|
||||
|
||||
use_degrees: bool = False
|
||||
|
||||
@@ -41,14 +41,15 @@ class SO101Leader(Teleoperator):
|
||||
def __init__(self, config: SO101LeaderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100
|
||||
self.bus = FeetechMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"elbow_flex": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_flex": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder_pan": Motor(1, "sts3215", norm_mode_body),
|
||||
"shoulder_lift": Motor(2, "sts3215", norm_mode_body),
|
||||
"elbow_flex": Motor(3, "sts3215", norm_mode_body),
|
||||
"wrist_flex": Motor(4, "sts3215", norm_mode_body),
|
||||
"wrist_roll": Motor(5, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
calibration=self.calibration,
|
||||
|
||||
@@ -45,5 +45,9 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
from tests.mocks.mock_teleop import MockTeleop
|
||||
|
||||
return MockTeleop(config)
|
||||
elif config.type == "gamepad":
|
||||
from .gamepad.teleop_gamepad import GamepadTeleop
|
||||
|
||||
return GamepadTeleop(config)
|
||||
else:
|
||||
raise ValueError(config.type)
|
||||
|
||||
59
lerobot/common/transport/services.proto
Normal file
59
lerobot/common/transport/services.proto
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright 2024 The HuggingFace Inc. team.
|
||||
// All rights reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command:
|
||||
//
|
||||
// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. lerobot/common/transport/services.proto
|
||||
//
|
||||
// The command should be launched from the root of the project.
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package transport;
|
||||
|
||||
// LearnerService: the Actor calls this to push transitions.
|
||||
// The Learner implements this service.
|
||||
service LearnerService {
|
||||
// Actor -> Learner to store transitions
|
||||
rpc StreamParameters(Empty) returns (stream Parameters);
|
||||
rpc SendTransitions(stream Transition) returns (Empty);
|
||||
rpc SendInteractions(stream InteractionMessage) returns (Empty);
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
enum TransferState {
|
||||
TRANSFER_UNKNOWN = 0;
|
||||
TRANSFER_BEGIN = 1;
|
||||
TRANSFER_MIDDLE = 2;
|
||||
TRANSFER_END = 3;
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Transition {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Parameters {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message InteractionMessage {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
45
lerobot/common/transport/services_pb2.py
Normal file
45
lerobot/common/transport/services_pb2.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: lerobot/common/transport/services.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'lerobot/common/transport/services.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'lerobot/common/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.common.transport.services_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=305
|
||||
_globals['_TRANSFERSTATE']._serialized_end=401
|
||||
_globals['_TRANSITION']._serialized_start=54
|
||||
_globals['_TRANSITION']._serialized_end=130
|
||||
_globals['_PARAMETERS']._serialized_start=132
|
||||
_globals['_PARAMETERS']._serialized_end=208
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=210
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=294
|
||||
_globals['_EMPTY']._serialized_start=296
|
||||
_globals['_EMPTY']._serialized_end=303
|
||||
_globals['_LEARNERSERVICE']._serialized_start=404
|
||||
_globals['_LEARNERSERVICE']._serialized_end=661
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
233
lerobot/common/transport/services_pb2_grpc.py
Normal file
233
lerobot/common/transport/services_pb2_grpc.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from lerobot.common.transport import services_pb2 as lerobot_dot_common_dot_transport_dot_services__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.71.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in lerobot/common/transport/services_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class LearnerServiceStub:
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.StreamParameters = channel.unary_stream(
|
||||
'/transport.LearnerService/StreamParameters',
|
||||
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString,
|
||||
_registered_method=True)
|
||||
self.SendTransitions = channel.stream_unary(
|
||||
'/transport.LearnerService/SendTransitions',
|
||||
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString,
|
||||
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendInteractions = channel.stream_unary(
|
||||
'/transport.LearnerService/SendInteractions',
|
||||
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/transport.LearnerService/Ready',
|
||||
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class LearnerServiceServicer:
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def StreamParameters(self, request, context):
|
||||
"""Actor -> Learner to store transitions
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendTransitions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendInteractions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_LearnerServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'StreamParameters': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamParameters,
|
||||
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.SerializeToString,
|
||||
),
|
||||
'SendTransitions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendTransitions,
|
||||
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.FromString,
|
||||
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendInteractions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendInteractions,
|
||||
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString,
|
||||
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'transport.LearnerService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('transport.LearnerService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class LearnerService:
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def StreamParameters(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(
|
||||
request,
|
||||
target,
|
||||
'/transport.LearnerService/StreamParameters',
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendTransitions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/transport.LearnerService/SendTransitions',
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString,
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendInteractions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/transport.LearnerService/SendInteractions',
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/transport.LearnerService/Ready',
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
141
lerobot/common/transport/utils.py
Normal file
141
lerobot/common/transport/utils.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import pickle # nosec B403: Safe usage for internal serialization only
|
||||
from multiprocessing import Event, Queue
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.transport import services_pb2
|
||||
from lerobot.common.utils.transition import Transition
|
||||
|
||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||
|
||||
|
||||
def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
||||
buffer.seek(0, io.SEEK_END)
|
||||
result = buffer.tell()
|
||||
buffer.seek(0)
|
||||
return result
|
||||
|
||||
|
||||
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
|
||||
buffer = io.BytesIO(buffer)
|
||||
size_in_bytes = bytes_buffer_size(buffer)
|
||||
|
||||
sent_bytes = 0
|
||||
|
||||
logging_method = logging.info if not silent else logging.debug
|
||||
|
||||
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
||||
|
||||
while sent_bytes < size_in_bytes:
|
||||
transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE
|
||||
|
||||
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
||||
transfer_state = services_pb2.TransferState.TRANSFER_END
|
||||
elif sent_bytes == 0:
|
||||
transfer_state = services_pb2.TransferState.TRANSFER_BEGIN
|
||||
|
||||
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
||||
chunk = buffer.read(size_to_read)
|
||||
|
||||
yield message_class(transfer_state=transfer_state, data=chunk)
|
||||
sent_bytes += size_to_read
|
||||
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
|
||||
|
||||
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
||||
|
||||
|
||||
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
|
||||
bytes_buffer = io.BytesIO()
|
||||
step = 0
|
||||
|
||||
logging.info(f"{log_prefix} Starting receiver")
|
||||
for item in iterator:
|
||||
logging.debug(f"{log_prefix} Received item")
|
||||
if shutdown_event.is_set():
|
||||
logging.info(f"{log_prefix} Shutting down receiver")
|
||||
return
|
||||
|
||||
if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN:
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step 0")
|
||||
step = 0
|
||||
elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE:
|
||||
bytes_buffer.write(item.data)
|
||||
step += 1
|
||||
logging.debug(f"{log_prefix} Received data at step {step}")
|
||||
elif item.transfer_state == services_pb2.TransferState.TRANSFER_END:
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
||||
|
||||
queue.put(bytes_buffer.getvalue())
|
||||
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
step = 0
|
||||
|
||||
logging.debug(f"{log_prefix} Queue updated")
|
||||
else:
|
||||
logging.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
|
||||
raise ValueError(f"Received unknown transfer state {item.transfer_state}")
|
||||
|
||||
|
||||
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
||||
"""Convert model state dict to flat array for transmission"""
|
||||
buffer = io.BytesIO()
|
||||
|
||||
torch.save(state_dict, buffer)
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
return torch.load(buffer, weights_only=True)
|
||||
|
||||
|
||||
def python_object_to_bytes(python_object: Any) -> bytes:
|
||||
return pickle.dumps(python_object)
|
||||
|
||||
|
||||
def bytes_to_python_object(buffer: bytes) -> Any:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
|
||||
# Add validation checks here
|
||||
return obj
|
||||
|
||||
|
||||
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
transitions = torch.load(buffer, weights_only=True)
|
||||
return transitions
|
||||
|
||||
|
||||
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
||||
buffer = io.BytesIO()
|
||||
torch.save(transitions, buffer)
|
||||
return buffer.getvalue()
|
||||
841
lerobot/common/utils/buffer.py
Normal file
841
lerobot/common/utils/buffer.py
Normal file
@@ -0,0 +1,841 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from contextlib import suppress
|
||||
from typing import Callable, Sequence, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.utils.transition import Transition
|
||||
|
||||
|
||||
class BatchTransition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: torch.Tensor
|
||||
truncated: torch.Tensor
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||
|
||||
|
||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||
"""
|
||||
Perform a per-image random crop over a batch of images in a vectorized way.
|
||||
(Same as shown previously.)
|
||||
"""
|
||||
B, C, H, W = images.shape # noqa: N806
|
||||
crop_h, crop_w = output_size
|
||||
|
||||
if crop_h > H or crop_w > W:
|
||||
raise ValueError(
|
||||
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
|
||||
)
|
||||
|
||||
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
|
||||
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
|
||||
|
||||
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
|
||||
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
|
||||
|
||||
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
|
||||
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
|
||||
|
||||
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
|
||||
|
||||
# Gather pixels
|
||||
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
|
||||
# cropped_hwcn => (B, crop_h, crop_w, C)
|
||||
|
||||
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
|
||||
return cropped
|
||||
|
||||
|
||||
def random_shift(images: torch.Tensor, pad: int = 4):
|
||||
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
|
||||
_, _, h, w = images.shape
|
||||
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
|
||||
return random_crop_vectorized(images=images, output_size=(h, w))
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Sequence[str] | None = None,
|
||||
image_augmentation_function: Callable | None = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
optimize_memory: bool = False,
|
||||
):
|
||||
"""
|
||||
Replay buffer for storing transitions.
|
||||
It will allocate tensors on the specified device, when the first transition is added.
|
||||
NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or
|
||||
and use the `storage_device` flag to store the buffer on a different device.
|
||||
Args:
|
||||
capacity (int): Maximum number of transitions to store in the buffer.
|
||||
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
|
||||
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
|
||||
Using "cpu" can help save GPU memory.
|
||||
optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
|
||||
they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
|
||||
"""
|
||||
if capacity <= 0:
|
||||
raise ValueError("Capacity must be greater than 0.")
|
||||
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.storage_device = storage_device
|
||||
self.position = 0
|
||||
self.size = 0
|
||||
self.initialized = False
|
||||
self.optimize_memory = optimize_memory
|
||||
|
||||
# Track episode boundaries for memory optimization
|
||||
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
|
||||
|
||||
# If no state_keys provided, default to an empty list
|
||||
self.state_keys = state_keys if state_keys is not None else []
|
||||
|
||||
self.image_augmentation_function = image_augmentation_function
|
||||
|
||||
if image_augmentation_function is None:
|
||||
base_function = functools.partial(random_shift, pad=4)
|
||||
self.image_augmentation_function = torch.compile(base_function)
|
||||
self.use_drq = use_drq
|
||||
|
||||
def _initialize_storage(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
||||
):
|
||||
"""Initialize the storage tensors based on the first transition."""
|
||||
# Determine shapes from the first transition
|
||||
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
|
||||
action_shape = action.squeeze(0).shape
|
||||
|
||||
# Pre-allocate tensors for storage
|
||||
self.states = {
|
||||
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||
for key, shape in state_shapes.items()
|
||||
}
|
||||
self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
|
||||
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Standard approach: store states and next_states separately
|
||||
self.next_states = {
|
||||
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||
for key, shape in state_shapes.items()
|
||||
}
|
||||
else:
|
||||
# Memory-optimized approach: don't allocate next_states buffer
|
||||
# Just create a reference to states for consistent API
|
||||
self.next_states = self.states # Just a reference for API consistency
|
||||
|
||||
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
|
||||
# Initialize storage for complementary_info
|
||||
self.has_complementary_info = complementary_info is not None
|
||||
self.complementary_info_keys = []
|
||||
self.complementary_info = {}
|
||||
|
||||
if self.has_complementary_info:
|
||||
self.complementary_info_keys = list(complementary_info.keys())
|
||||
# Pre-allocate tensors for each key in complementary_info
|
||||
for key, value in complementary_info.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
value_shape = value.squeeze(0).shape
|
||||
self.complementary_info[key] = torch.empty(
|
||||
(self.capacity, *value_shape), device=self.storage_device
|
||||
)
|
||||
elif isinstance(value, (int, float)):
|
||||
# Handle scalar values similar to reward
|
||||
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def add(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
reward: float,
|
||||
next_state: dict[str, torch.Tensor],
|
||||
done: bool,
|
||||
truncated: bool,
|
||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
||||
):
|
||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||
# Initialize storage if this is the first transition
|
||||
if not self.initialized:
|
||||
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
|
||||
|
||||
# Store the transition in pre-allocated tensors
|
||||
for key in self.states:
|
||||
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Only store next_states if not optimizing memory
|
||||
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
|
||||
|
||||
self.actions[self.position].copy_(action.squeeze(dim=0))
|
||||
self.rewards[self.position] = reward
|
||||
self.dones[self.position] = done
|
||||
self.truncateds[self.position] = truncated
|
||||
|
||||
# Handle complementary_info if provided and storage is initialized
|
||||
if complementary_info is not None and self.has_complementary_info:
|
||||
# Store the complementary_info
|
||||
for key in self.complementary_info_keys:
|
||||
if key in complementary_info:
|
||||
value = complementary_info[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||
elif isinstance(value, (int, float)):
|
||||
self.complementary_info[key][self.position] = value
|
||||
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
if not self.initialized:
|
||||
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
||||
|
||||
batch_size = min(batch_size, self.size)
|
||||
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
||||
|
||||
# Random indices for sampling - create on the same device as storage
|
||||
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||
|
||||
# Identify image keys that need augmentation
|
||||
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
|
||||
|
||||
# Create batched state and next_state
|
||||
batch_state = {}
|
||||
batch_next_state = {}
|
||||
|
||||
# First pass: load all state tensors to target device
|
||||
for key in self.states:
|
||||
batch_state[key] = self.states[key][idx].to(self.device)
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Standard approach - load next_states directly
|
||||
batch_next_state[key] = self.next_states[key][idx].to(self.device)
|
||||
else:
|
||||
# Memory-optimized approach - get next_state from the next index
|
||||
next_idx = (idx + 1) % self.capacity
|
||||
batch_next_state[key] = self.states[key][next_idx].to(self.device)
|
||||
|
||||
# Apply image augmentation in a batched way if needed
|
||||
if self.use_drq and image_keys:
|
||||
# Concatenate all images from state and next_state
|
||||
all_images = []
|
||||
for key in image_keys:
|
||||
all_images.append(batch_state[key])
|
||||
all_images.append(batch_next_state[key])
|
||||
|
||||
# Optimization: Batch all images and apply augmentation once
|
||||
all_images_tensor = torch.cat(all_images, dim=0)
|
||||
augmented_images = self.image_augmentation_function(all_images_tensor)
|
||||
|
||||
# Split the augmented images back to their sources
|
||||
for i, key in enumerate(image_keys):
|
||||
# Calculate offsets for the current image key:
|
||||
# For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
|
||||
# States start at index i*2*batch_size and take up batch_size slots
|
||||
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
|
||||
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
|
||||
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
||||
|
||||
# Sample other tensors
|
||||
batch_actions = self.actions[idx].to(self.device)
|
||||
batch_rewards = self.rewards[idx].to(self.device)
|
||||
batch_dones = self.dones[idx].to(self.device).float()
|
||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||
|
||||
# Sample complementary_info if available
|
||||
batch_complementary_info = None
|
||||
if self.has_complementary_info:
|
||||
batch_complementary_info = {}
|
||||
for key in self.complementary_info_keys:
|
||||
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
|
||||
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
action=batch_actions,
|
||||
reward=batch_rewards,
|
||||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
truncated=batch_truncateds,
|
||||
complementary_info=batch_complementary_info,
|
||||
)
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""
|
||||
Creates an infinite iterator that yields batches of transitions.
|
||||
Will automatically restart when internal iterator is exhausted.
|
||||
|
||||
Args:
|
||||
batch_size (int): Size of batches to sample
|
||||
async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
|
||||
queue_size (int): Number of batches to prefetch (default: 2)
|
||||
|
||||
Yields:
|
||||
BatchTransition: Batched transitions
|
||||
"""
|
||||
while True: # Create an infinite loop
|
||||
if async_prefetch:
|
||||
# Get the standard iterator
|
||||
iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
|
||||
else:
|
||||
iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
|
||||
|
||||
# Yield all items from the iterator
|
||||
with suppress(StopIteration):
|
||||
yield from iterator
|
||||
|
||||
def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
|
||||
"""
|
||||
Create an iterator that continuously yields prefetched batches in a
|
||||
background thread. The design is intentionally simple and avoids busy
|
||||
waiting / complex state management.
|
||||
|
||||
Args:
|
||||
batch_size (int): Size of batches to sample.
|
||||
queue_size (int): Maximum number of prefetched batches to keep in
|
||||
memory.
|
||||
|
||||
Yields:
|
||||
BatchTransition: A batch sampled from the replay buffer.
|
||||
"""
|
||||
import queue
|
||||
import threading
|
||||
|
||||
data_queue: queue.Queue = queue.Queue(maxsize=queue_size)
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
def producer() -> None:
|
||||
"""Continuously put sampled batches into the queue until shutdown."""
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
batch = self.sample(batch_size)
|
||||
# The timeout ensures the thread unblocks if the queue is full
|
||||
# and the shutdown event gets set meanwhile.
|
||||
data_queue.put(batch, block=True, timeout=0.5)
|
||||
except queue.Full:
|
||||
# Queue is full – loop again (will re-check shutdown_event)
|
||||
continue
|
||||
except Exception:
|
||||
# Surface any unexpected error and terminate the producer.
|
||||
shutdown_event.set()
|
||||
|
||||
producer_thread = threading.Thread(target=producer, daemon=True)
|
||||
producer_thread.start()
|
||||
|
||||
try:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
yield data_queue.get(block=True)
|
||||
except Exception:
|
||||
# If the producer already set the shutdown flag we exit.
|
||||
if shutdown_event.is_set():
|
||||
break
|
||||
finally:
|
||||
shutdown_event.set()
|
||||
# Drain the queue quickly to help the thread exit if it's blocked on `put`.
|
||||
while not data_queue.empty():
|
||||
_ = data_queue.get_nowait()
|
||||
# Give the producer thread a bit of time to finish.
|
||||
producer_thread.join(timeout=1.0)
|
||||
|
||||
def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
|
||||
"""
|
||||
Creates a simple non-threaded iterator that yields batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): Size of batches to sample
|
||||
queue_size (int): Number of initial batches to prefetch
|
||||
|
||||
Yields:
|
||||
BatchTransition: Batch transitions
|
||||
"""
|
||||
import collections
|
||||
|
||||
queue = collections.deque()
|
||||
|
||||
def enqueue(n):
|
||||
for _ in range(n):
|
||||
data = self.sample(batch_size)
|
||||
queue.append(data)
|
||||
|
||||
enqueue(queue_size)
|
||||
while queue:
|
||||
yield queue.popleft()
|
||||
enqueue(1)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
lerobot_dataset: LeRobotDataset,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Sequence[str] | None = None,
|
||||
capacity: int | None = None,
|
||||
image_augmentation_function: Callable | None = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
optimize_memory: bool = False,
|
||||
) -> "ReplayBuffer":
|
||||
"""
|
||||
Convert a LeRobotDataset into a ReplayBuffer.
|
||||
|
||||
Args:
|
||||
lerobot_dataset (LeRobotDataset): The dataset to convert.
|
||||
device (str): The device for sampling tensors. Defaults to "cuda:0".
|
||||
state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`.
|
||||
capacity (int | None): Buffer capacity. If None, uses dataset length.
|
||||
action_mask (Sequence[int] | None): Indices of action dimensions to keep.
|
||||
image_augmentation_function (Callable | None): Function for image augmentation.
|
||||
If None, uses default random shift with pad=4.
|
||||
use_drq (bool): Whether to use DrQ image augmentation when sampling.
|
||||
storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
|
||||
optimize_memory (bool): If True, reduces memory usage by not duplicating state data.
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: The replay buffer with dataset transitions.
|
||||
"""
|
||||
if capacity is None:
|
||||
capacity = len(lerobot_dataset)
|
||||
|
||||
if capacity < len(lerobot_dataset):
|
||||
raise ValueError(
|
||||
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
|
||||
)
|
||||
|
||||
# Create replay buffer with image augmentation and DrQ settings
|
||||
replay_buffer = cls(
|
||||
capacity=capacity,
|
||||
device=device,
|
||||
state_keys=state_keys,
|
||||
image_augmentation_function=image_augmentation_function,
|
||||
use_drq=use_drq,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=optimize_memory,
|
||||
)
|
||||
|
||||
# Convert dataset to transitions
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
|
||||
# Initialize the buffer with the first transition to set up storage tensors
|
||||
if list_transition:
|
||||
first_transition = list_transition[0]
|
||||
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
||||
first_action = first_transition["action"].to(device)
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
if (
|
||||
"complementary_info" in first_transition
|
||||
and first_transition["complementary_info"] is not None
|
||||
):
|
||||
first_complementary_info = {
|
||||
k: v.to(device) for k, v in first_transition["complementary_info"].items()
|
||||
}
|
||||
|
||||
replay_buffer._initialize_storage(
|
||||
state=first_state, action=first_action, complementary_info=first_complementary_info
|
||||
)
|
||||
|
||||
# Fill the buffer with all transitions
|
||||
for data in list_transition:
|
||||
for k, v in data.items():
|
||||
if isinstance(v, dict):
|
||||
for key, tensor in v.items():
|
||||
v[key] = tensor.to(storage_device)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(storage_device)
|
||||
|
||||
action = data["action"]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=action,
|
||||
reward=data["reward"],
|
||||
next_state=data["next_state"],
|
||||
done=data["done"],
|
||||
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
|
||||
complementary_info=data.get("complementary_info", None),
|
||||
)
|
||||
|
||||
return replay_buffer
|
||||
|
||||
def to_lerobot_dataset(
|
||||
self,
|
||||
repo_id: str,
|
||||
fps=1,
|
||||
root=None,
|
||||
task_name="from_replay_buffer",
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts all transitions in this ReplayBuffer into a single LeRobotDataset object.
|
||||
"""
|
||||
if self.size == 0:
|
||||
raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
|
||||
|
||||
# Create features dictionary for the dataset
|
||||
features = {
|
||||
"index": {"dtype": "int64", "shape": [1]}, # global index across episodes
|
||||
"episode_index": {"dtype": "int64", "shape": [1]}, # which episode
|
||||
"frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode
|
||||
"timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy
|
||||
"task_index": {"dtype": "int64", "shape": [1]},
|
||||
}
|
||||
|
||||
# Add "action"
|
||||
sample_action = self.actions[0]
|
||||
act_info = guess_feature_info(t=sample_action, name="action")
|
||||
features["action"] = act_info
|
||||
|
||||
# Add "reward" and "done"
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,)}
|
||||
|
||||
# Add state keys
|
||||
for key in self.states:
|
||||
sample_val = self.states[key][0]
|
||||
f_info = guess_feature_info(t=sample_val, name=key)
|
||||
features[key] = f_info
|
||||
|
||||
# Add complementary_info keys if available
|
||||
if self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
sample_val = self.complementary_info[key][0]
|
||||
if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
|
||||
sample_val = sample_val.unsqueeze(0)
|
||||
f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
|
||||
features[f"complementary_info.{key}"] = f_info
|
||||
|
||||
# Create an empty LeRobotDataset
|
||||
lerobot_dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
root=root,
|
||||
robot_type=None,
|
||||
features=features,
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
# Start writing images if needed
|
||||
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
|
||||
|
||||
# Convert transitions into episodes and frames
|
||||
episode_index = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
|
||||
|
||||
frame_idx_in_episode = 0
|
||||
for idx in range(self.size):
|
||||
actual_idx = (self.position - self.size + idx) % self.capacity
|
||||
|
||||
frame_dict = {}
|
||||
|
||||
# Fill the data for state keys
|
||||
for key in self.states:
|
||||
frame_dict[key] = self.states[key][actual_idx].cpu()
|
||||
|
||||
# Fill action, reward, done
|
||||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
|
||||
# Add complementary_info if available
|
||||
if self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
val = self.complementary_info[key][actual_idx]
|
||||
# Convert tensors to CPU
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.ndim == 0:
|
||||
val = val.unsqueeze(0)
|
||||
frame_dict[f"complementary_info.{key}"] = val.cpu()
|
||||
# Non-tensor values can be used directly
|
||||
else:
|
||||
frame_dict[f"complementary_info.{key}"] = val
|
||||
|
||||
# Add to the dataset's buffer
|
||||
lerobot_dataset.add_frame(frame_dict, task=task_name)
|
||||
|
||||
# Move to next frame
|
||||
frame_idx_in_episode += 1
|
||||
|
||||
# If we reached an episode boundary, call save_episode, reset counters
|
||||
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
||||
lerobot_dataset.save_episode()
|
||||
episode_index += 1
|
||||
frame_idx_in_episode = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
|
||||
episode_index=episode_index
|
||||
)
|
||||
|
||||
# Save any remaining frames in the buffer
|
||||
if lerobot_dataset.episode_buffer["size"] > 0:
|
||||
lerobot_dataset.save_episode()
|
||||
|
||||
lerobot_dataset.stop_image_writer()
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
@staticmethod
|
||||
def _lerobotdataset_to_transitions(
|
||||
dataset: LeRobotDataset,
|
||||
state_keys: Sequence[str] | None = None,
|
||||
) -> list[Transition]:
|
||||
"""
|
||||
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
|
||||
|
||||
Args:
|
||||
dataset (LeRobotDataset):
|
||||
The dataset to convert. Each item in the dataset is expected to have
|
||||
at least the following keys:
|
||||
{
|
||||
"action": ...
|
||||
"next.reward": ...
|
||||
"next.done": ...
|
||||
"episode_index": ...
|
||||
}
|
||||
plus whatever your 'state_keys' specify.
|
||||
|
||||
state_keys (Sequence[str] | None):
|
||||
The dataset keys to include in 'state' and 'next_state'. Their names
|
||||
will be kept as-is in the output transitions. E.g.
|
||||
["observation.state", "observation.environment_state"].
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
"""
|
||||
if state_keys is None:
|
||||
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
|
||||
|
||||
transitions = []
|
||||
num_frames = len(dataset)
|
||||
|
||||
# Check if the dataset has "next.done" key
|
||||
sample = dataset[0]
|
||||
has_done_key = "next.done" in sample
|
||||
|
||||
# Check for complementary_info keys
|
||||
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
|
||||
has_complementary_info = len(complementary_info_keys) > 0
|
||||
|
||||
# If not, we need to infer it from episode boundaries
|
||||
if not has_done_key:
|
||||
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
|
||||
|
||||
for i in tqdm(range(num_frames)):
|
||||
current_sample = dataset[i]
|
||||
|
||||
# ----- 1) Current state -----
|
||||
current_state: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = current_sample[key]
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
|
||||
# Determine done flag - use next.done if available, otherwise infer from episode boundaries
|
||||
if has_done_key:
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
else:
|
||||
# If this is the last frame or if next frame is in a different episode, mark as done
|
||||
done = False
|
||||
if i == num_frames - 1:
|
||||
done = True
|
||||
elif i < num_frames - 1:
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] != current_sample["episode_index"]:
|
||||
done = True
|
||||
|
||||
# TODO: (azouitine) Handle truncation (using the same value as done for now)
|
||||
truncated = done
|
||||
|
||||
# ----- 4) Next state -----
|
||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
|
||||
next_state = current_state # default
|
||||
if not done and (i < num_frames - 1):
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] == current_sample["episode_index"]:
|
||||
# Build next_state from the same keys
|
||||
next_state_data: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = next_sample[key]
|
||||
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
|
||||
next_state = next_state_data
|
||||
|
||||
# ----- 5) Complementary info (if available) -----
|
||||
complementary_info = None
|
||||
if has_complementary_info:
|
||||
complementary_info = {}
|
||||
for key in complementary_info_keys:
|
||||
# Strip the "complementary_info." prefix to get the actual key
|
||||
clean_key = key[len("complementary_info.") :]
|
||||
val = current_sample[key]
|
||||
# Handle tensor and non-tensor values differently
|
||||
if isinstance(val, torch.Tensor):
|
||||
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
|
||||
else:
|
||||
# TODO: (azouitine) Check if it's necessary to convert to tensor
|
||||
# For non-tensor values, use directly
|
||||
complementary_info[clean_key] = val
|
||||
|
||||
# ----- Construct the Transition -----
|
||||
transition = Transition(
|
||||
state=current_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
return transitions
|
||||
|
||||
|
||||
# Utility function to guess shapes/dtypes from a tensor
|
||||
def guess_feature_info(t, name: str):
|
||||
"""
|
||||
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
|
||||
If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
|
||||
Otherwise default to appropriate dtype for numeric.
|
||||
"""
|
||||
|
||||
shape = tuple(t.shape)
|
||||
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
|
||||
if len(shape) == 3 and shape[0] in [1, 3]:
|
||||
return {
|
||||
"dtype": "image",
|
||||
"shape": shape,
|
||||
}
|
||||
else:
|
||||
# Otherwise treat as numeric
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": shape,
|
||||
}
|
||||
|
||||
|
||||
def concatenate_batch_transitions(
|
||||
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
|
||||
) -> BatchTransition:
|
||||
"""
|
||||
Concatenates two BatchTransition objects into one.
|
||||
|
||||
This function merges the right BatchTransition into the left one by concatenating
|
||||
all corresponding tensors along dimension 0. The operation modifies the left_batch_transitions
|
||||
in place and also returns it.
|
||||
|
||||
Args:
|
||||
left_batch_transitions (BatchTransition): The first batch to concatenate and the one
|
||||
that will be modified in place.
|
||||
right_batch_transition (BatchTransition): The second batch to append to the first one.
|
||||
|
||||
Returns:
|
||||
BatchTransition: The concatenated batch (same object as left_batch_transitions).
|
||||
|
||||
Warning:
|
||||
This function modifies the left_batch_transitions object in place.
|
||||
"""
|
||||
# Concatenate state fields
|
||||
left_batch_transitions["state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["state"][key], right_batch_transition["state"][key]],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["state"]
|
||||
}
|
||||
|
||||
# Concatenate basic fields
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
|
||||
)
|
||||
left_batch_transitions["reward"] = torch.cat(
|
||||
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
|
||||
)
|
||||
|
||||
# Concatenate next_state fields
|
||||
left_batch_transitions["next_state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["next_state"]
|
||||
}
|
||||
|
||||
# Concatenate done and truncated fields
|
||||
left_batch_transitions["done"] = torch.cat(
|
||||
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
|
||||
)
|
||||
left_batch_transitions["truncated"] = torch.cat(
|
||||
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Handle complementary_info
|
||||
left_info = left_batch_transitions.get("complementary_info")
|
||||
right_info = right_batch_transition.get("complementary_info")
|
||||
|
||||
# Only process if right_info exists
|
||||
if right_info is not None:
|
||||
# Initialize left complementary_info if needed
|
||||
if left_info is None:
|
||||
left_batch_transitions["complementary_info"] = right_info
|
||||
else:
|
||||
# Concatenate each field
|
||||
for key in right_info:
|
||||
if key in left_info:
|
||||
left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
|
||||
else:
|
||||
left_info[key] = right_info[key]
|
||||
|
||||
return left_batch_transitions
|
||||
@@ -28,6 +28,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
try:
|
||||
# Primary method to get the package version
|
||||
package_version = importlib.metadata.version(pkg_name)
|
||||
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
# Fallback method: Only for "torch" and versions containing "dev"
|
||||
if pkg_name == "torch":
|
||||
@@ -43,6 +44,9 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
except ImportError:
|
||||
# If the package can't be imported, it's not available
|
||||
package_exists = False
|
||||
elif pkg_name == "grpc":
|
||||
package = importlib.import_module(pkg_name)
|
||||
package_version = getattr(package, "__version__", "N/A")
|
||||
else:
|
||||
# For packages other than "torch", don't attempt the fallback and set as not available
|
||||
package_exists = False
|
||||
|
||||
83
lerobot/common/utils/process.py
Normal file
83
lerobot/common/utils/process.py
Normal file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
||||
|
||||
class ProcessSignalHandler:
|
||||
"""Utility class to attach graceful shutdown signal handlers.
|
||||
|
||||
The class exposes a shutdown_event attribute that is set when a shutdown
|
||||
signal is received. A counter tracks how many shutdown signals have been
|
||||
caught. On the second signal the process exits with status 1.
|
||||
"""
|
||||
|
||||
_SUPPORTED_SIGNALS = ("SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT")
|
||||
|
||||
def __init__(self, use_threads: bool, display_pid: bool = False):
|
||||
# TODO: Check if we can use Event from threading since Event from
|
||||
# multiprocessing is the a clone of threading.Event.
|
||||
# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Event
|
||||
if use_threads:
|
||||
from threading import Event
|
||||
else:
|
||||
from multiprocessing import Event
|
||||
|
||||
self.shutdown_event = Event()
|
||||
self._counter: int = 0
|
||||
self._display_pid = display_pid
|
||||
|
||||
self._register_handlers()
|
||||
|
||||
@property
|
||||
def counter(self) -> int: # pragma: no cover – simple accessor
|
||||
"""Number of shutdown signals that have been intercepted."""
|
||||
return self._counter
|
||||
|
||||
def _register_handlers(self):
|
||||
"""Attach the internal _signal_handler to a subset of POSIX signals."""
|
||||
|
||||
def _signal_handler(signum, frame):
|
||||
pid_str = ""
|
||||
if self._display_pid:
|
||||
pid_str = f"[PID: {os.getpid()}]"
|
||||
logging.info(f"{pid_str} Shutdown signal {signum} received. Cleaning up…")
|
||||
self.shutdown_event.set()
|
||||
self._counter += 1
|
||||
|
||||
# On a second Ctrl-C (or any supported signal) force the exit to
|
||||
# mimic the previous behaviour while giving the caller one chance to
|
||||
# shutdown gracefully.
|
||||
# TODO: Investigate if we need it later
|
||||
if self._counter > 1:
|
||||
logging.info("Force shutdown")
|
||||
sys.exit(1)
|
||||
|
||||
for sig_name in self._SUPPORTED_SIGNALS:
|
||||
sig = getattr(signal, sig_name, None)
|
||||
if sig is None:
|
||||
# The signal is not available on this platform (Windows for
|
||||
# instance does not provide SIGHUP, SIGQUIT…). Skip it.
|
||||
continue
|
||||
try:
|
||||
signal.signal(sig, _signal_handler)
|
||||
except (ValueError, OSError): # pragma: no cover – unlikely but safe
|
||||
# Signal not supported or we are in a non-main thread.
|
||||
continue
|
||||
39
lerobot/common/utils/queue.py
Normal file
39
lerobot/common/utils/queue.py
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from torch.multiprocessing import Queue
|
||||
|
||||
|
||||
def get_last_item_from_queue(queue: Queue, block=True, timeout: float = 0.1) -> Any:
|
||||
if block:
|
||||
try:
|
||||
item = queue.get(timeout=timeout)
|
||||
except Empty:
|
||||
return None
|
||||
else:
|
||||
item = None
|
||||
|
||||
# Drain queue and keep only the most recent parameters
|
||||
try:
|
||||
while True:
|
||||
item = queue.get_nowait()
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
return item
|
||||
85
lerobot/common/utils/transition.py
Normal file
85
lerobot/common/utils/transition.py
Normal file
@@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Transition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: float
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: bool
|
||||
truncated: bool
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||
|
||||
|
||||
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
|
||||
device = torch.device(device)
|
||||
non_blocking = device.type == "cuda"
|
||||
|
||||
# Move state tensors to device
|
||||
transition["state"] = {
|
||||
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
|
||||
}
|
||||
|
||||
# Move action to device
|
||||
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
|
||||
|
||||
# Move reward and done if they are tensors
|
||||
if isinstance(transition["reward"], torch.Tensor):
|
||||
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
|
||||
|
||||
if isinstance(transition["done"], torch.Tensor):
|
||||
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
|
||||
|
||||
if isinstance(transition["truncated"], torch.Tensor):
|
||||
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
|
||||
|
||||
# Move next_state tensors to device
|
||||
transition["next_state"] = {
|
||||
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
|
||||
}
|
||||
|
||||
# Move complementary_info tensors if present
|
||||
if transition.get("complementary_info") is not None:
|
||||
for key, val in transition["complementary_info"].items():
|
||||
if isinstance(val, torch.Tensor):
|
||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||
elif isinstance(val, (int, float, bool)):
|
||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
return transition
|
||||
|
||||
|
||||
def move_state_dict_to_device(state_dict, device="cpu"):
|
||||
"""
|
||||
Recursively move all tensors in a (potentially) nested
|
||||
dict/list/tuple structure to the CPU.
|
||||
"""
|
||||
if isinstance(state_dict, torch.Tensor):
|
||||
return state_dict.to(device)
|
||||
elif isinstance(state_dict, dict):
|
||||
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
|
||||
elif isinstance(state_dict, list):
|
||||
return [move_state_dict_to_device(v, device=device) for v in state_dict]
|
||||
elif isinstance(state_dict, tuple):
|
||||
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
|
||||
else:
|
||||
return state_dict
|
||||
@@ -20,9 +20,11 @@ import platform
|
||||
import select
|
||||
import subprocess
|
||||
import sys
|
||||
from copy import copy
|
||||
import time
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -109,11 +111,17 @@ def is_amp_available(device: str):
|
||||
raise ValueError(f"Unknown device '{device}.")
|
||||
|
||||
|
||||
def init_logging():
|
||||
def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
|
||||
# NOTE: Display PID is useful for multi-process logging.
|
||||
if display_pid:
|
||||
pid_str = f"[PID: {os.getpid()}]"
|
||||
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
else:
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
return message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -127,6 +135,12 @@ def init_logging():
|
||||
console_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
if log_file is not None:
|
||||
# Additionally write logs to file
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(file_handler)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
@@ -247,3 +261,114 @@ def enter_pressed() -> bool:
|
||||
def move_cursor_up(lines):
|
||||
"""Move the cursor up by a specified number of lines."""
|
||||
print(f"\033[{lines}A", end="")
|
||||
|
||||
|
||||
class TimerManager:
|
||||
"""
|
||||
Lightweight utility to measure elapsed time.
|
||||
|
||||
Examples
|
||||
--------
|
||||
```python
|
||||
# Example 1: Using context manager
|
||||
timer = TimerManager("Policy", log=False)
|
||||
for _ in range(3):
|
||||
with timer:
|
||||
time.sleep(0.01)
|
||||
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||
```
|
||||
|
||||
```python
|
||||
# Example 2: Using start/stop methods
|
||||
timer = TimerManager("Policy", log=False)
|
||||
timer.start()
|
||||
time.sleep(0.01)
|
||||
timer.stop()
|
||||
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: str = "Elapsed-time",
|
||||
log: bool = True,
|
||||
logger: logging.Logger | None = None,
|
||||
):
|
||||
self.label = label
|
||||
self.log = log
|
||||
self.logger = logger
|
||||
self._start: float | None = None
|
||||
self._history: list[float] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self.start()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
self._start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def stop(self) -> float:
|
||||
if self._start is None:
|
||||
raise RuntimeError("Timer was never started.")
|
||||
elapsed = time.perf_counter() - self._start
|
||||
self._history.append(elapsed)
|
||||
self._start = None
|
||||
if self.log:
|
||||
if self.logger is not None:
|
||||
self.logger.info(f"{self.label}: {elapsed:.6f} s")
|
||||
else:
|
||||
logging.info(f"{self.label}: {elapsed:.6f} s")
|
||||
return elapsed
|
||||
|
||||
def reset(self):
|
||||
self._history.clear()
|
||||
|
||||
@property
|
||||
def last(self) -> float:
|
||||
return self._history[-1] if self._history else 0.0
|
||||
|
||||
@property
|
||||
def avg(self) -> float:
|
||||
return mean(self._history) if self._history else 0.0
|
||||
|
||||
@property
|
||||
def total(self) -> float:
|
||||
return sum(self._history)
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self._history)
|
||||
|
||||
@property
|
||||
def history(self) -> list[float]:
|
||||
return deepcopy(self._history)
|
||||
|
||||
@property
|
||||
def fps_history(self) -> list[float]:
|
||||
return [1.0 / t for t in self._history]
|
||||
|
||||
@property
|
||||
def fps_last(self) -> float:
|
||||
return 0.0 if self.last == 0 else 1.0 / self.last
|
||||
|
||||
@property
|
||||
def fps_avg(self) -> float:
|
||||
return 0.0 if self.avg == 0 else 1.0 / self.avg
|
||||
|
||||
def percentile(self, p: float) -> float:
|
||||
"""
|
||||
Return the p-th percentile of recorded times.
|
||||
"""
|
||||
if not self._history:
|
||||
return 0.0
|
||||
return float(np.percentile(self._history, p))
|
||||
|
||||
def fps_percentile(self, p: float) -> float:
|
||||
"""
|
||||
FPS corresponding to the p-th percentile time.
|
||||
"""
|
||||
val = self.percentile(p)
|
||||
return 0.0 if val == 0 else 1.0 / val
|
||||
|
||||
@@ -30,9 +30,10 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"dataset:{cfg.dataset.repo_id}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.dataset is not None:
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
return lst if return_list else "-".join(lst)
|
||||
@@ -92,6 +93,12 @@ class WandBLogger:
|
||||
resume="must" if cfg.resume else None,
|
||||
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
||||
)
|
||||
run_id = wandb.run.id
|
||||
# NOTE: We will override the cfg.wandb.run_id with the wandb run id.
|
||||
# This is because we want to be able to resume the run from the wandb run id.
|
||||
cfg.wandb.run_id = run_id
|
||||
# Handle custom step key for rl asynchronous training.
|
||||
self._wandb_custom_step_key: set[str] | None = None
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
@@ -108,9 +115,26 @@ class WandBLogger:
|
||||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
def log_dict(
|
||||
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
|
||||
):
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
if step is None and custom_step_key is None:
|
||||
raise ValueError("Either step or custom_step_key must be provided.")
|
||||
|
||||
# NOTE: This is not simple. Wandb step must always monotonically increase and it
|
||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||
# multiple time steps is possible. For example, the interaction step with the environment,
|
||||
# the training step, the evaluation step, etc. So we need to define a custom step key
|
||||
# to log the correct step for each metric.
|
||||
if custom_step_key is not None:
|
||||
if self._wandb_custom_step_key is None:
|
||||
self._wandb_custom_step_key = set()
|
||||
new_custom_key = f"{mode}/{custom_step_key}"
|
||||
if new_custom_key not in self._wandb_custom_step_key:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
@@ -118,7 +142,18 @@ class WandBLogger:
|
||||
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
continue
|
||||
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||
self._wandb.log(data)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.robots import RobotConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlConfig(draccus.ChoiceRegistry):
|
||||
pass
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("calibrate")
|
||||
@dataclass
|
||||
class CalibrateControlConfig(ControlConfig):
|
||||
# List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`)
|
||||
arms: list[str] | None = None
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("teleoperate")
|
||||
@dataclass
|
||||
class TeleoperateControlConfig(ControlConfig):
|
||||
# Limit the maximum frames per second. By default, no limit.
|
||||
fps: int | None = None
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("record")
|
||||
@dataclass
|
||||
class RecordControlConfig(ControlConfig):
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | Path | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int | None = None
|
||||
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
||||
warmup_time_s: int | float = 10
|
||||
# Number of seconds for data recording for each episode.
|
||||
episode_time_s: int | float = 60
|
||||
# Number of seconds for resetting the environment after each episode.
|
||||
reset_time_s: int | float = 60
|
||||
# Number of episodes to record.
|
||||
num_episodes: int = 50
|
||||
# Encode frames in the dataset into video
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||
num_image_writer_processes: int = 0
|
||||
# Number of threads writing the frames as png images on disk, per camera.
|
||||
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("control.policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("control.policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("replay")
|
||||
@dataclass
|
||||
class ReplayControlConfig(ControlConfig):
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str
|
||||
# Index of the episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the dataset fps.
|
||||
fps: int | None = None
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("remote_robot")
|
||||
@dataclass
|
||||
class RemoteRobotConfig(ControlConfig):
|
||||
log_interval: int = 100
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp)
|
||||
viewer_ip: str | None = None
|
||||
viewer_port: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlPipelineConfig:
|
||||
robot: RobotConfig
|
||||
control: ControlConfig
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["control.policy"]
|
||||
@@ -172,3 +172,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
||||
|
||||
@@ -23,6 +23,7 @@ class FeatureType(str, Enum):
|
||||
VISUAL = "VISUAL"
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
|
||||
|
||||
class NormalizationMode(str, Enum):
|
||||
|
||||
118
lerobot/scripts/find_joint_limits.py
Normal file
118
lerobot/scripts/find_joint_limits.py
Normal file
@@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Simple script to control a robot from teleoperation.
|
||||
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.scripts.server.find_joint_limits \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue
|
||||
```
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.model.kinematics import RobotKinematics
|
||||
from lerobot.common.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
)
|
||||
from lerobot.common.teleoperators import ( # noqa: F401
|
||||
TeleoperatorConfig,
|
||||
gamepad,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
so100_leader,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FindJointLimitsConfig:
|
||||
teleop: TeleoperatorConfig
|
||||
robot: RobotConfig
|
||||
# Limit the maximum frames per second. By default, no limit.
|
||||
teleop_time_s: float = 30
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig):
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
|
||||
teleop.connect()
|
||||
robot.connect()
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
robot_type = getattr(robot.config, "robot_type", "so101")
|
||||
if "so100" in robot_type or "so101" in robot_type:
|
||||
# Note to be compatible with the rest of the codebase,
|
||||
# we are using the new calibration method for so101 and so100
|
||||
robot_type = "so_new_calibration"
|
||||
kinematics = RobotKinematics(robot_type=robot_type)
|
||||
|
||||
# Initialize min/max values
|
||||
observation = robot.get_observation()
|
||||
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
|
||||
ee_pos = kinematics.forward_kinematics(joint_positions, frame="gripper_tip")[:3, 3]
|
||||
|
||||
max_pos = joint_positions.copy()
|
||||
min_pos = joint_positions.copy()
|
||||
max_ee = ee_pos.copy()
|
||||
min_ee = ee_pos.copy()
|
||||
|
||||
while True:
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
|
||||
observation = robot.get_observation()
|
||||
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
|
||||
ee_pos = kinematics.forward_kinematics(joint_positions, frame="gripper_tip")[:3, 3]
|
||||
|
||||
# Skip initial warmup period
|
||||
if (time.perf_counter() - start_episode_t) < 5:
|
||||
continue
|
||||
|
||||
# Update min/max values
|
||||
max_ee = np.maximum(max_ee, ee_pos)
|
||||
min_ee = np.minimum(min_ee, ee_pos)
|
||||
max_pos = np.maximum(max_pos, joint_positions)
|
||||
min_pos = np.minimum(min_pos, joint_positions)
|
||||
|
||||
if time.perf_counter() - start_episode_t > cfg.teleop_time_s:
|
||||
print(f"Max ee position {np.round(max_ee, 4).tolist()}")
|
||||
print(f"Min ee position {np.round(min_ee, 4).tolist()}")
|
||||
print(f"Max joint pos position {np.round(max_pos, 4).tolist()}")
|
||||
print(f"Min joint pos position {np.round(min_pos, 4).tolist()}")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
find_joint_and_ee_bounds()
|
||||
709
lerobot/scripts/rl/actor.py
Normal file
709
lerobot/scripts/rl/actor.py
Normal file
@@ -0,0 +1,709 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Actor server runner for distributed HILSerl robot policy training.
|
||||
|
||||
This script implements the actor component of the distributed HILSerl architecture.
|
||||
It executes the policy in the robot environment, collects experience,
|
||||
and sends transitions to the learner server for policy updates.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Start an actor server for real robot training with human-in-the-loop intervention:
|
||||
```bash
|
||||
python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
|
||||
server is started before launching the actor.
|
||||
|
||||
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
|
||||
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
|
||||
reduce interventions as the policy improves.
|
||||
|
||||
**WORKFLOW**:
|
||||
1. Determine robot workspace bounds using `find_joint_limits.py`
|
||||
2. Record demonstrations with `gym_manipulator.py` in record mode
|
||||
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
|
||||
4. Start the learner server with the training configuration
|
||||
5. Start this actor server with the same configuration
|
||||
6. Use human interventions to guide policy learning
|
||||
|
||||
For more details on the complete HILSerl training workflow, see:
|
||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.cameras import opencv # noqa: F401
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robots import so100_follower # noqa: F401
|
||||
from lerobot.common.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.common.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.common.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.common.utils.process import ProcessSignalHandler
|
||||
from lerobot.common.utils.queue import get_last_item_from_queue
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.robot_utils import busy_wait
|
||||
from lerobot.common.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
from lerobot.common.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.scripts.rl import learner_service
|
||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
cfg.validate()
|
||||
display_pid = False
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
display_pid = True
|
||||
|
||||
# Create logs directory to ensure it exists
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=display_pid)
|
||||
logging.info(f"Actor logging initialized, writing to {log_file}")
|
||||
|
||||
is_threaded = use_threads(cfg)
|
||||
shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event
|
||||
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
logging.info("[ACTOR] Establishing connection with Learner")
|
||||
if not establish_learner_connection(learner_client, shutdown_event):
|
||||
logging.error("[ACTOR] Failed to establish connection with Learner")
|
||||
return
|
||||
|
||||
if not use_threads(cfg):
|
||||
# If we use multithreading, we can reuse the channel
|
||||
grpc_channel.close()
|
||||
grpc_channel = None
|
||||
|
||||
logging.info("[ACTOR] Connection with Learner established")
|
||||
|
||||
parameters_queue = Queue()
|
||||
transitions_queue = Queue()
|
||||
interactions_queue = Queue()
|
||||
|
||||
concurrency_entity = None
|
||||
if use_threads(cfg):
|
||||
from threading import Thread
|
||||
|
||||
concurrency_entity = Thread
|
||||
else:
|
||||
from multiprocessing import Process
|
||||
|
||||
concurrency_entity = Process
|
||||
|
||||
receive_policy_process = concurrency_entity(
|
||||
target=receive_policy,
|
||||
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process = concurrency_entity(
|
||||
target=send_transitions,
|
||||
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
interactions_process = concurrency_entity(
|
||||
target=send_interactions,
|
||||
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process.start()
|
||||
interactions_process.start()
|
||||
receive_policy_process.start()
|
||||
|
||||
act_with_policy(
|
||||
cfg=cfg,
|
||||
shutdown_event=shutdown_event,
|
||||
parameters_queue=parameters_queue,
|
||||
transitions_queue=transitions_queue,
|
||||
interactions_queue=interactions_queue,
|
||||
)
|
||||
logging.info("[ACTOR] Policy process joined")
|
||||
|
||||
logging.info("[ACTOR] Closing queues")
|
||||
transitions_queue.close()
|
||||
interactions_queue.close()
|
||||
parameters_queue.close()
|
||||
|
||||
transitions_process.join()
|
||||
logging.info("[ACTOR] Transitions process joined")
|
||||
interactions_process.join()
|
||||
logging.info("[ACTOR] Interactions process joined")
|
||||
receive_policy_process.join()
|
||||
logging.info("[ACTOR] Receive policy process joined")
|
||||
|
||||
logging.info("[ACTOR] join queues")
|
||||
transitions_queue.cancel_join_thread()
|
||||
interactions_queue.cancel_join_thread()
|
||||
parameters_queue.cancel_join_thread()
|
||||
|
||||
logging.info("[ACTOR] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def act_with_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
shutdown_event: any, # Event,
|
||||
parameters_queue: Queue,
|
||||
transitions_queue: Queue,
|
||||
interactions_queue: Queue,
|
||||
):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
|
||||
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
|
||||
|
||||
Args:
|
||||
cfg: Configuration settings for the interaction process.
|
||||
shutdown_event: Event to check if the process should shutdown.
|
||||
parameters_queue: Queue to receive updated network parameters from the learner.
|
||||
transitions_queue: Queue to send transitions to the learner.
|
||||
interactions_queue: Queue to send interactions to the learner.
|
||||
"""
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor policy process logging initialized")
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
list_transition_to_send_to_learner = []
|
||||
episode_intervention = False
|
||||
# Add counters for intervention rate calculation
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
policy_timer = TimerManager("Policy inference", log=False)
|
||||
|
||||
for interaction_step in range(cfg.policy.online_steps):
|
||||
start_time = time.perf_counter()
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
push_transitions_to_transport_queue(
|
||||
transitions=list_transition_to_send_to_learner,
|
||||
transitions_queue=transitions_queue,
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
stats = get_frequency_stats(policy_timer)
|
||||
policy_timer.reset()
|
||||
|
||||
# Calculate intervention rate
|
||||
intervention_rate = 0.0
|
||||
if episode_total_steps > 0:
|
||||
intervention_rate = episode_intervention_steps / episode_total_steps
|
||||
|
||||
# Send episodic reward to the learner
|
||||
interactions_queue.put(
|
||||
python_object_to_bytes(
|
||||
{
|
||||
"Episodic reward": sum_reward_episode,
|
||||
"Interaction step": interaction_step,
|
||||
"Episode intervention": int(episode_intervention),
|
||||
"Intervention rate": intervention_rate,
|
||||
**stats,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
obs, info = online_env.reset()
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
stub: services_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Event, # type: ignore
|
||||
attempts: int = 30,
|
||||
):
|
||||
"""Establish a connection with the learner.
|
||||
|
||||
Args:
|
||||
stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
|
||||
shutdown_event (Event): The event to check if the connection should be established.
|
||||
attempts (int): The number of attempts to establish the connection.
|
||||
Returns:
|
||||
bool: True if the connection is established, False otherwise.
|
||||
"""
|
||||
for _ in range(attempts):
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down establish_learner_connection")
|
||||
return False
|
||||
|
||||
# Force a connection attempt and check state
|
||||
try:
|
||||
logging.info("[ACTOR] Send ready message to Learner")
|
||||
if stub.Ready(services_pb2.Empty()) == services_pb2.Empty():
|
||||
return True
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
||||
time.sleep(2)
|
||||
return False
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def learner_service_client(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 50051,
|
||||
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
import json
|
||||
|
||||
"""
|
||||
Returns a client for the learner service.
|
||||
|
||||
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||
So we need to create only one client and reuse it.
|
||||
"""
|
||||
|
||||
service_config = {
|
||||
"methodConfig": [
|
||||
{
|
||||
"name": [{}], # Applies to ALL methods in ALL services
|
||||
"retryPolicy": {
|
||||
"maxAttempts": 5, # Max retries (total attempts = 5)
|
||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
||||
"maxBackoff": "2s", # Max wait time between retries
|
||||
"backoffMultiplier": 2, # Exponential backoff factor
|
||||
"retryableStatusCodes": [
|
||||
"UNAVAILABLE",
|
||||
"DEADLINE_EXCEEDED",
|
||||
], # Retries on network failures
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
service_config_json = json.dumps(service_config)
|
||||
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[
|
||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.enable_retries", 1),
|
||||
("grpc.service_config", service_config_json),
|
||||
],
|
||||
)
|
||||
stub = services_pb2_grpc.LearnerServiceStub(channel)
|
||||
logging.info("[ACTOR] Learner service client created")
|
||||
return stub, channel
|
||||
|
||||
|
||||
def receive_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
"""Receive parameters from the learner.
|
||||
|
||||
Args:
|
||||
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
|
||||
parameters_queue (Queue): The queue to receive the parameters.
|
||||
shutdown_event (Event): The event to check if the process should shutdown.
|
||||
"""
|
||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor receive policy process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
_ = ProcessSignalHandler(use_threads=False, display_pid=True)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
iterator = learner_client.StreamParameters(services_pb2.Empty())
|
||||
receive_bytes_in_chunks(
|
||||
iterator,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
log_prefix="[ACTOR] parameters",
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Received policy loop stopped")
|
||||
|
||||
|
||||
def send_transitions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
"""
|
||||
Sends transitions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Transition Data:
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor transitions process logging initialized")
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendTransitions(
|
||||
transitions_stream(
|
||||
shutdown_event, transitions_queue, cfg.policy.actor_learner_config.queue_get_timeout
|
||||
)
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming transitions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Transitions process stopped")
|
||||
|
||||
|
||||
def send_interactions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
"""
|
||||
Sends interactions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Interaction Messages:
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor interactions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
_ = ProcessSignalHandler(use_threads=False, display_pid=True)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendInteractions(
|
||||
interactions_stream(
|
||||
shutdown_event, interactions_queue, cfg.policy.actor_learner_config.queue_get_timeout
|
||||
)
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming interactions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=timeout)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Transition queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
||||
)
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: Event,
|
||||
interactions_queue: Queue,
|
||||
timeout: float, # type: ignore
|
||||
) -> services_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = interactions_queue.get(block=True, timeout=timeout)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Interaction queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message,
|
||||
services_pb2.InteractionMessage,
|
||||
log_prefix="[ACTOR] Send interactions",
|
||||
)
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
#################################################
|
||||
# Policy functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dict = bytes_to_state_dict(bytes_state_dict)
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.load_state_dict(state_dict)
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
"""Send transitions to learner in smaller chunks to avoid network issues.
|
||||
|
||||
Args:
|
||||
transitions: List of transitions to send
|
||||
message_queue: Queue to send messages to learner
|
||||
chunk_size: Size of each chunk to send
|
||||
"""
|
||||
transition_to_send_to_learner = []
|
||||
for transition in transitions:
|
||||
tr = move_transition_to_device(transition=transition, device="cpu")
|
||||
for key, value in tr["state"].items():
|
||||
if torch.isnan(value).any():
|
||||
logging.warning(f"Found NaN values in transition {key}")
|
||||
|
||||
transition_to_send_to_learner.append(tr)
|
||||
|
||||
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
||||
|
||||
|
||||
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
|
||||
"""Get the frequency statistics of the policy.
|
||||
|
||||
Args:
|
||||
timer (TimerManager): The timer with collected metrics.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The frequency statistics of the policy.
|
||||
"""
|
||||
stats = {}
|
||||
if timer.count > 1:
|
||||
avg_fps = timer.fps_avg
|
||||
p90_fps = timer.fps_percentile(90)
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
|
||||
stats = {
|
||||
"Policy frequency [Hz]": avg_fps,
|
||||
"Policy frequency 90th-p [Hz]": p90_fps,
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: TrainRLServerPipelineConfig, interaction_step: int):
|
||||
if policy_fps < cfg.env.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
|
||||
)
|
||||
|
||||
|
||||
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.actor == "threads"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
actor_cli()
|
||||
314
lerobot/scripts/rl/crop_dataset_roi.py
Normal file
314
lerobot/scripts/rl/crop_dataset_roi.py
Normal file
@@ -0,0 +1,314 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import cv2
|
||||
|
||||
# import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def select_rect_roi(img):
|
||||
"""
|
||||
Allows the user to draw a rectangular ROI on the image.
|
||||
|
||||
The user must click and drag to draw the rectangle.
|
||||
- While dragging, the rectangle is dynamically drawn.
|
||||
- On mouse button release, the rectangle is fixed.
|
||||
- Press 'c' to confirm the selection.
|
||||
- Press 'r' to reset the selection.
|
||||
- Press ESC to cancel.
|
||||
|
||||
Returns:
|
||||
A tuple (top, left, height, width) representing the rectangular ROI,
|
||||
or None if no valid ROI is selected.
|
||||
"""
|
||||
# Create a working copy of the image
|
||||
clone = img.copy()
|
||||
working_img = clone.copy()
|
||||
|
||||
roi = None # Will store the final ROI as (top, left, height, width)
|
||||
drawing = False
|
||||
index_x, index_y = -1, -1 # Initial click coordinates
|
||||
|
||||
def mouse_callback(event, x, y, flags, param):
|
||||
nonlocal index_x, index_y, drawing, roi, working_img
|
||||
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
# Start drawing: record starting coordinates
|
||||
drawing = True
|
||||
index_x, index_y = x, y
|
||||
|
||||
elif event == cv2.EVENT_MOUSEMOVE:
|
||||
if drawing:
|
||||
# Compute the top-left and bottom-right corners regardless of drag direction
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
# Show a temporary image with the current rectangle drawn
|
||||
temp = working_img.copy()
|
||||
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", temp)
|
||||
|
||||
elif event == cv2.EVENT_LBUTTONUP:
|
||||
# Finish drawing
|
||||
drawing = False
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
height = bottom - top
|
||||
width = right - left
|
||||
roi = (top, left, height, width) # (top, left, height, width)
|
||||
# Draw the final rectangle on the working image and display it
|
||||
working_img = clone.copy()
|
||||
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
# Create the window and set the callback
|
||||
cv2.namedWindow("Select ROI")
|
||||
cv2.setMouseCallback("Select ROI", mouse_callback)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
print("Instructions for ROI selection:")
|
||||
print(" - Click and drag to draw a rectangular ROI.")
|
||||
print(" - Press 'c' to confirm the selection.")
|
||||
print(" - Press 'r' to reset and draw again.")
|
||||
print(" - Press ESC to cancel the selection.")
|
||||
|
||||
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
|
||||
while True:
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
# Confirm ROI if one has been drawn
|
||||
if key == ord("c") and roi is not None:
|
||||
break
|
||||
# Reset: clear the ROI and restore the original image
|
||||
elif key == ord("r"):
|
||||
working_img = clone.copy()
|
||||
roi = None
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
# Cancel selection for this image
|
||||
elif key == 27: # ESC key
|
||||
roi = None
|
||||
break
|
||||
|
||||
cv2.destroyWindow("Select ROI")
|
||||
return roi
|
||||
|
||||
|
||||
def select_square_roi_for_images(images: dict) -> dict:
|
||||
"""
|
||||
For each image in the provided dictionary, open a window to allow the user
|
||||
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
|
||||
(top, left, height, width) representing the ROI.
|
||||
|
||||
Parameters:
|
||||
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
|
||||
|
||||
Returns:
|
||||
dict: Mapping of image keys to the selected rectangular ROI.
|
||||
"""
|
||||
selected_rois = {}
|
||||
|
||||
for key, img in images.items():
|
||||
if img is None:
|
||||
print(f"Image for key '{key}' is None, skipping.")
|
||||
continue
|
||||
|
||||
print(f"\nSelect rectangular ROI for image with key: '{key}'")
|
||||
roi = select_rect_roi(img)
|
||||
|
||||
if roi is None:
|
||||
print(f"No valid ROI selected for '{key}'.")
|
||||
else:
|
||||
selected_rois[key] = roi
|
||||
print(f"ROI for '{key}': {roi}")
|
||||
|
||||
return selected_rois
|
||||
|
||||
|
||||
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
|
||||
"""
|
||||
Find the first row in the dataset and extract the image in order to be used for the crop.
|
||||
"""
|
||||
row = dataset[0]
|
||||
image_dict = {}
|
||||
for k in row:
|
||||
if "image" in k:
|
||||
image_dict[k] = deepcopy(row[k])
|
||||
return image_dict
|
||||
|
||||
|
||||
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset: LeRobotDataset,
|
||||
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
|
||||
new_repo_id: str,
|
||||
new_dataset_root: str,
|
||||
resize_size: Tuple[int, int] = (128, 128),
|
||||
push_to_hub: bool = False,
|
||||
task: str = "",
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts an existing LeRobotDataset by iterating over its episodes and frames,
|
||||
applying cropping and resizing to image observations, and saving a new dataset
|
||||
with the transformed data.
|
||||
|
||||
Args:
|
||||
original_dataset (LeRobotDataset): The source dataset.
|
||||
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
|
||||
A dictionary mapping observation keys to crop parameters (top, left, height, width).
|
||||
new_repo_id (str): Repository id for the new dataset.
|
||||
new_dataset_root (str): The root directory where the new dataset will be written.
|
||||
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
Defaults to (128, 128).
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
|
||||
and resized.
|
||||
"""
|
||||
# 1. Create a new (empty) LeRobotDataset for writing.
|
||||
new_dataset = LeRobotDataset.create(
|
||||
repo_id=new_repo_id,
|
||||
fps=original_dataset.fps,
|
||||
root=new_dataset_root,
|
||||
robot_type=original_dataset.meta.robot_type,
|
||||
features=original_dataset.meta.info["features"],
|
||||
use_videos=len(original_dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Update the metadata for every image key that will be cropped:
|
||||
# (Here we simply set the shape to be the final resize_size.)
|
||||
for key in crop_params_dict:
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||
|
||||
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
|
||||
prev_episode_index = 0
|
||||
for frame_idx in tqdm(range(len(original_dataset))):
|
||||
frame = original_dataset[frame_idx]
|
||||
|
||||
# Create a copy of the frame to add to the new dataset
|
||||
new_frame = {}
|
||||
for key, value in frame.items():
|
||||
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
|
||||
continue
|
||||
if key in ("next.done", "next.reward"):
|
||||
# if not isinstance(value, str) and len(value.shape) == 0:
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
if key in crop_params_dict:
|
||||
top, left, height, width = crop_params_dict[key]
|
||||
# Apply crop then resize.
|
||||
cropped = F.crop(value, top, left, height, width)
|
||||
value = F.resize(cropped, resize_size)
|
||||
value = value.clamp(0, 1)
|
||||
|
||||
new_frame[key] = value
|
||||
|
||||
new_dataset.add_frame(new_frame, task=task)
|
||||
|
||||
if frame["episode_index"].item() != prev_episode_index:
|
||||
# Save the episode
|
||||
new_dataset.save_episode()
|
||||
prev_episode_index = frame["episode_index"].item()
|
||||
|
||||
# Save the last episode
|
||||
new_dataset.save_episode()
|
||||
|
||||
if push_to_hub:
|
||||
new_dataset.push_to_hub()
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot",
|
||||
help="The repository id of the LeRobot dataset to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The root directory of the LeRobot dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-params-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the JSON file containing the ROIs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to push the new dataset to the hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="",
|
||||
help="The natural language task to describe the dataset.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
|
||||
|
||||
if args.crop_params_path is None:
|
||||
rois = select_square_roi_for_images(images)
|
||||
else:
|
||||
with open(args.crop_params_path) as f:
|
||||
rois = json.load(f)
|
||||
|
||||
# Print the selected rectangular ROIs
|
||||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||
for key, roi in rois.items():
|
||||
print(f"{key}: {roi}")
|
||||
|
||||
new_repo_id = args.repo_id + "_cropped_resized"
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
|
||||
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset=dataset,
|
||||
crop_params_dict=rois,
|
||||
new_repo_id=new_repo_id,
|
||||
new_dataset_root=new_dataset_root,
|
||||
resize_size=(128, 128),
|
||||
push_to_hub=args.push_to_hub,
|
||||
task=args.task,
|
||||
)
|
||||
|
||||
meta_dir = new_dataset_root / "meta"
|
||||
meta_dir.mkdir(exist_ok=True)
|
||||
|
||||
with open(meta_dir / "crop_params.json", "w") as f:
|
||||
json.dump(rois, f, indent=4)
|
||||
2171
lerobot/scripts/rl/gym_manipulator.py
Normal file
2171
lerobot/scripts/rl/gym_manipulator.py
Normal file
File diff suppressed because it is too large
Load Diff
1206
lerobot/scripts/rl/learner.py
Normal file
1206
lerobot/scripts/rl/learner.py
Normal file
File diff suppressed because it is too large
Load Diff
118
lerobot/scripts/rl/learner_service.py
Normal file
118
lerobot/scripts/rl/learner_service.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||
from lerobot.common.utils.queue import get_last_item_from_queue
|
||||
|
||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
SHUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
"""
|
||||
Implementation of the LearnerService gRPC service
|
||||
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
||||
check transport.proto for the gRPC service definition
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shutdown_event: Event, # type: ignore
|
||||
parameters_queue: Queue,
|
||||
seconds_between_pushes: float,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
queue_get_timeout: float = 0.001,
|
||||
):
|
||||
self.shutdown_event = shutdown_event
|
||||
self.parameters_queue = parameters_queue
|
||||
self.seconds_between_pushes = seconds_between_pushes
|
||||
self.transition_queue = transition_queue
|
||||
self.interaction_message_queue = interaction_message_queue
|
||||
self.queue_get_timeout = queue_get_timeout
|
||||
|
||||
def StreamParameters(self, request, context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||
|
||||
last_push_time = 0
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
time_since_last_push = time.time() - last_push_time
|
||||
if time_since_last_push < self.seconds_between_pushes:
|
||||
self.shutdown_event.wait(self.seconds_between_pushes - time_since_last_push)
|
||||
# Continue, because we could receive a shutdown event,
|
||||
# and it's checked in the while loop
|
||||
continue
|
||||
|
||||
logging.info("[LEARNER] Push parameters to the Actor")
|
||||
buffer = get_last_item_from_queue(
|
||||
self.parameters_queue, block=True, timeout=self.queue_get_timeout
|
||||
)
|
||||
|
||||
if buffer is None:
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
buffer,
|
||||
services_pb2.Parameters,
|
||||
log_prefix="[LEARNER] Sending parameters",
|
||||
silent=True,
|
||||
)
|
||||
|
||||
last_push_time = time.time()
|
||||
logging.info("[LEARNER] Parameters sent")
|
||||
|
||||
logging.info("[LEARNER] Stream parameters finished")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.transition_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] transitions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving transitions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.interaction_message_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] interactions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving interactions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
return services_pb2.Empty()
|
||||
@@ -58,7 +58,7 @@ from lerobot.common.utils.robot_utils import busy_wait
|
||||
from lerobot.common.utils.utils import init_logging, move_cursor_up
|
||||
from lerobot.common.utils.visualization_utils import _init_rerun
|
||||
|
||||
from .common.teleoperators import koch_leader, so100_leader, so101_leader # noqa: F401
|
||||
from .common.teleoperators import gamepad, koch_leader, so100_leader, so101_leader # noqa: F401
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user