diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index d83df15a..103cc7ec 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -182,15 +182,15 @@ class EEActionSpaceConfig: y_step_size: float z_step_size: float bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds - use_gamepad: bool = False + control_mode: str = "gamepad" @dataclass class EnvWrapperConfig: """Configuration for environment wrappers.""" + ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig) display_cameras: bool = False - use_relative_joint_positions: bool = True add_joint_velocity_to_observation: bool = False add_current_to_observation: bool = False add_ee_pose_to_observation: bool = False @@ -199,13 +199,10 @@ class EnvWrapperConfig: control_time_s: float = 20.0 fixed_reset_joint_positions: Optional[Any] = None reset_time_s: float = 5.0 - joint_masking_action_space: Optional[Any] = None - ee_action_space_params: Optional[EEActionSpaceConfig] = None use_gripper: bool = False gripper_quantization_threshold: float | None = 0.8 gripper_penalty: float = 0.0 gripper_penalty_in_reward: bool = False - open_gripper_on_reset: bool = False @EnvConfig.register_subclass(name="gym_manipulator") diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index bef13143..480356c7 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -308,13 +308,13 @@ def reset_environment(robot, events, reset_time_s, fps): ) -def reset_follower_position(robot: Robot, target_position): - current_position = robot.follower_arms["main"].read("Present_Position") +def reset_follower_position(robot_arm, target_position): + current_position = robot_arm.read("Present_Position") trajectory = torch.from_numpy( np.linspace(current_position, target_position, 50) ) # NOTE: 30 is just an arbitrary number for pose in trajectory: - robot.send_action(pose) + robot_arm.write("Goal_Position", pose) busy_wait(0.015) diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index 8d66dae2..e940b442 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig): leader_arms: dict[str, MotorsBusConfig] = field( default_factory=lambda: { "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem58760433331", + port="/dev/tty.usbmodem58760431091", motors={ # name: (index, model) "shoulder_pan": [1, "sts3215"], @@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig): follower_arms: dict[str, MotorsBusConfig] = field( default_factory=lambda: { "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem58760431631", + port="/dev/tty.usbmodem585A0076891", motors={ # name: (index, model) "shoulder_pan": [1, "sts3215"], diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 658aefd5..3daea98d 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -167,7 +167,7 @@ from lerobot.common.robot_devices.control_utils import ( warmup_record, ) from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config -from lerobot.common.robot_devices.utils import safe_disconnect +from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect from lerobot.common.utils.utils import has_method, init_logging, log_say 
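Reviewer note: `reset_follower_position` now takes a motor bus arm directly and writes `Goal_Position` itself instead of going through `robot.send_action`. A minimal usage sketch is below; the rest pose values are placeholders and constructing `So100RobotConfig()` with its default ports is only an assumption for illustration.

```python
import numpy as np

from lerobot.common.robot_devices.control_utils import reset_follower_position
from lerobot.common.robot_devices.robots.configs import So100RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config

robot = make_robot_from_config(So100RobotConfig())
if not robot.is_connected:
    robot.connect()

# Placeholder rest pose (degrees) -- record a pose on your own arm instead.
rest_pose = np.array([0.0, 90.0, 90.0, 0.0, 0.0, 10.0], dtype=np.float32)

# The helper interpolates 50 intermediate poses and writes "Goal_Position" on the
# given bus, so the same call works for either the follower or the leader arm.
reset_follower_position(robot.follower_arms["main"], rest_pose)
```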
from lerobot.configs import parser @@ -276,6 +276,7 @@ def record( if not robot.is_connected: robot.connect() + listener, events = init_keyboard_listener() # Execute a few seconds without recording to: @@ -284,14 +285,7 @@ def record( # 3. place the cameras windows on screen enable_teleoperation = policy is None log_say("Warmup record", cfg.play_sounds) - warmup_record( - robot, - events, - enable_teleoperation, - cfg.warmup_time_s, - cfg.display_data, - cfg.fps, - ) + warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps) if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() @@ -356,6 +350,7 @@ def replay( dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) actions = dataset.hf_dataset.select_columns("action") + if not robot.is_connected: robot.connect() @@ -366,6 +361,9 @@ def replay( action = actions[idx]["action"] robot.send_action(action) + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / cfg.fps - dt_s) + dt_s = time.perf_counter() - start_episode_t log_control_info(robot, dt_s, fps=cfg.fps) diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index 3bf30b26..109bb2c4 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -20,13 +20,11 @@ from functools import lru_cache from queue import Empty from statistics import mean, quantiles -# from lerobot.scripts.eval import eval_policy import grpc import torch from torch import nn from torch.multiprocessing import Event, Queue -# TODO: Remove the import of maniskill from lerobot.common.policies.factory import make_policy from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.robot_devices.utils import busy_wait @@ -39,20 +37,21 @@ from lerobot.common.utils.utils import ( from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service -from lerobot.scripts.server.buffer import ( - Transition, - bytes_to_state_dict, - move_state_dict_to_device, - move_transition_to_device, - python_object_to_bytes, - transitions_to_bytes, -) +from lerobot.scripts.server.buffer import Transition from lerobot.scripts.server.gym_manipulator import make_robot_env from lerobot.scripts.server.network_utils import ( + bytes_to_state_dict, + python_object_to_bytes, receive_bytes_in_chunks, send_bytes_in_chunks, + transitions_to_bytes, +) +from lerobot.scripts.server.utils import ( + get_last_item_from_queue, + move_state_dict_to_device, + move_transition_to_device, + setup_process_handlers, ) -from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers ACTOR_SHUTDOWN_TIMEOUT = 30 @@ -134,21 +133,8 @@ def actor_cli(cfg: TrainPipelineConfig): interactions_process.start() receive_policy_process.start() - # HACK: FOR MANISKILL we do not have a reward classifier - # TODO: Remove this once we merge into main - reward_classifier = None - # if ( - # cfg.env.reward_classifier["pretrained_path"] is not None - # and cfg.env.reward_classifier["config_path"] is not None - # ): - # reward_classifier = get_classifier( - # pretrained_path=cfg.env.reward_classifier["pretrained_path"], - # config_path=cfg.env.reward_classifier["config_path"], - # ) - act_with_policy( cfg=cfg, - reward_classifier=reward_classifier, shutdown_event=shutdown_event, parameters_queue=parameters_queue, transitions_queue=transitions_queue, @@ -183,7 +169,6 @@ def 
actor_cli(cfg: TrainPipelineConfig): def act_with_policy( cfg: TrainPipelineConfig, - reward_classifier: nn.Module, shutdown_event: any, # Event, parameters_queue: Queue, transitions_queue: Queue, @@ -197,7 +182,6 @@ def act_with_policy( Args: cfg: Configuration settings for the interaction process. - reward_classifier: Reward classifier to use for the interaction process. shutdown_event: Event to check if the process should shutdown. parameters_queue: Queue to receive updated network parameters from the learner. transitions_queue: Queue to send transitions to the learner. @@ -262,16 +246,10 @@ def act_with_policy( log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) - next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy()) else: - # TODO (azouitine): Make a custom space for torch tensor action = online_env.action_space.sample() - next_obs, reward, done, truncated, info = online_env.step(action) - # HACK: We have only one env but we want to batch it, it will be resolved with the torch box - action = ( - torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0) - ) + next_obs, reward, done, truncated, info = online_env.step(action) sum_reward_episode += float(reward) # Increment total steps counter for intervention rate @@ -286,11 +264,6 @@ def act_with_policy( # Increment intervention steps counter episode_intervention_steps += 1 - # Check for NaN values in observations - for key, tensor in obs.items(): - if torch.isnan(tensor).any(): - logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}") - list_transition_to_send_to_learner.append( Transition( state=obs, diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index a54016eb..8da5cecf 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -15,26 +15,15 @@ # limitations under the License. 
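Reviewer note: after this refactor the serialization helpers no longer live in `buffer.py`; the actor now pulls them from `network_utils` and the device-moving helpers from `scripts.server.utils`, while `Transition` stays importable through `buffer`. A small sketch of the new import split, assuming the moved helpers keep the signatures they had in `buffer.py`; shapes and keys are placeholders.

```python
import torch

from lerobot.scripts.server.buffer import Transition
from lerobot.scripts.server.network_utils import transitions_to_bytes
from lerobot.scripts.server.utils import move_transition_to_device

transition = Transition(
    state={"observation.state": torch.zeros(6)},
    action=torch.zeros(4),
    reward=0.0,
    next_state={"observation.state": torch.zeros(6)},
    done=False,
    truncated=False,
    complementary_info=None,
)

# Move tensors to the target device, then serialize for send_bytes_in_chunks.
transition = move_transition_to_device(transition, device="cpu")
payload = transitions_to_bytes([transition])
```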
import functools -import io -import pickle # nosec B403: Safe usage of pickle from contextlib import suppress -from typing import Any, Callable, Optional, Sequence, TypedDict +from typing import Callable, Optional, Sequence, TypedDict import torch import torch.nn.functional as F # noqa: N812 from tqdm import tqdm from lerobot.common.datasets.lerobot_dataset import LeRobotDataset - - -class Transition(TypedDict): - state: dict[str, torch.Tensor] - action: torch.Tensor - reward: float - next_state: dict[str, torch.Tensor] - done: bool - truncated: bool - complementary_info: dict[str, torch.Tensor | float | int] | None = None +from lerobot.scripts.server.utils import Transition class BatchTransition(TypedDict): @@ -47,103 +36,6 @@ class BatchTransition(TypedDict): complementary_info: dict[str, torch.Tensor | float | int] | None = None -def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition: - device = torch.device(device) - non_blocking = device.type == "cuda" - - # Move state tensors to device - transition["state"] = { - key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items() - } - - # Move action to device - transition["action"] = transition["action"].to(device, non_blocking=non_blocking) - - # Move reward and done if they are tensors - if isinstance(transition["reward"], torch.Tensor): - transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking) - - if isinstance(transition["done"], torch.Tensor): - transition["done"] = transition["done"].to(device, non_blocking=non_blocking) - - if isinstance(transition["truncated"], torch.Tensor): - transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking) - - # Move next_state tensors to device - transition["next_state"] = { - key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items() - } - - # Move complementary_info tensors if present - if transition.get("complementary_info") is not None: - for key, val in transition["complementary_info"].items(): - if isinstance(val, torch.Tensor): - transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) - elif isinstance(val, (int, float, bool)): - transition["complementary_info"][key] = torch.tensor(val, device=device) - else: - raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") - return transition - - -def move_state_dict_to_device(state_dict, device="cpu"): - """ - Recursively move all tensors in a (potentially) nested - dict/list/tuple structure to the CPU. 
- """ - if isinstance(state_dict, torch.Tensor): - return state_dict.to(device) - elif isinstance(state_dict, dict): - return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()} - elif isinstance(state_dict, list): - return [move_state_dict_to_device(v, device=device) for v in state_dict] - elif isinstance(state_dict, tuple): - return tuple(move_state_dict_to_device(v, device=device) for v in state_dict) - else: - return state_dict - - -def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: - """Convert model state dict to flat array for transmission""" - buffer = io.BytesIO() - - torch.save(state_dict, buffer) - - return buffer.getvalue() - - -def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: - buffer = io.BytesIO(buffer) - buffer.seek(0) - return torch.load(buffer) # nosec B614: Safe usage of torch.load - - -def python_object_to_bytes(python_object: Any) -> bytes: - return pickle.dumps(python_object) - - -def bytes_to_python_object(buffer: bytes) -> Any: - buffer = io.BytesIO(buffer) - buffer.seek(0) - obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load - # Add validation checks here - return obj - - -def bytes_to_transitions(buffer: bytes) -> list[Transition]: - buffer = io.BytesIO(buffer) - buffer.seek(0) - transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load - # Add validation checks here - return transitions - - -def transitions_to_bytes(transitions: list[Transition]) -> bytes: - buffer = io.BytesIO() - torch.save(transitions, buffer) - return buffer.getvalue() - - def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: """ Perform a per-image random crop over a batch of images in a vectorized way. @@ -514,7 +406,6 @@ class ReplayBuffer: device: str = "cuda:0", state_keys: Optional[Sequence[str]] = None, capacity: Optional[int] = None, - action_mask: Optional[Sequence[int]] = None, image_augmentation_function: Optional[Callable] = None, use_drq: bool = True, storage_device: str = "cpu", @@ -566,13 +457,6 @@ class ReplayBuffer: first_state = {k: v.to(device) for k, v in first_transition["state"].items()} first_action = first_transition["action"].to(device) - # Apply action mask/delta if needed - if action_mask is not None: - if first_action.dim() == 1: - first_action = first_action[action_mask] - else: - first_action = first_action[:, action_mask] - # Get complementary info if available first_complementary_info = None if ( @@ -597,8 +481,6 @@ class ReplayBuffer: data[k] = v.to(storage_device) action = data["action"] - if action_mask is not None: - action = action[action_mask] if action.dim() == 1 else action[:, action_mask] replay_buffer.add( state=data["state"], diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index 12fb7459..394466c5 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -524,7 +524,9 @@ def teleoperate_inverse_kinematics_with_leader(robot, fps=10): leader_ee = kinematics.fk_gripper_tip(leader_joint_positions) desired_ee_pos = leader_ee - target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True) + target_joint_state = kinematics.ik( + joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip + ) robot.send_action(torch.from_numpy(target_joint_state)) logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}") busy_wait(1 / fps - 
(time.perf_counter() - loop_start_time)) @@ -544,6 +546,8 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10): initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions) desired_ee_pos = np.diag(np.ones(4)) + joint_positions = robot.follower_arms["main"].read("Present_Position") + fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions) while time.perf_counter() - timestep < 60.0: loop_start_time = time.perf_counter() @@ -561,25 +565,26 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10): # Calculate delta between leader and follower end-effectors # Scaling factor can be adjusted for sensitivity scaling_factor = 1.0 - ee_delta = (leader_ee - initial_leader_ee) * scaling_factor + ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05) # Apply delta to current position - desired_ee_pos[0, 3] = current_ee_pos[0, 3] + ee_delta[0, 3] - desired_ee_pos[1, 3] = current_ee_pos[1, 3] + ee_delta[1, 3] - desired_ee_pos[2, 3] = current_ee_pos[2, 3] + ee_delta[2, 3] + desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0 + desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0 + desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3] - if np.any(np.abs(ee_delta[:3, 3]) > 0.01): - # Compute joint targets via inverse kinematics - target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True) + # Compute joint targets via inverse kinematics + target_joint_state = kinematics.ik( + joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip + ) - initial_leader_ee = leader_ee.copy() + initial_leader_ee = leader_ee.copy() - # Send command to robot - robot.send_action(torch.from_numpy(target_joint_state)) + # Send command to robot + robot.send_action(torch.from_numpy(target_joint_state)) - # Logging - logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}") - logging.info(f"Delta EE: {ee_delta[:3, 3]}") + # Logging + logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}") + logging.info(f"Delta EE: {ee_delta[:3, 3]}") busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) @@ -715,8 +720,8 @@ if __name__ == "__main__": "gamepad", "keyboard_gym", "gamepad_gym", + "leader_delta", "leader", - "leader_abs", ], help="Control mode to use", ) @@ -768,11 +773,11 @@ if __name__ == "__main__": env = make_robot_env(cfg, robot) teleoperate_gym_env(env, controller, fps=cfg.fps) - elif args.mode == "leader": + elif args.mode == "leader_delta": # Leader-follower modes don't use controllers teleoperate_delta_inverse_kinematics_with_leader(robot) - elif args.mode == "leader_abs": + elif args.mode == "leader": teleoperate_inverse_kinematics_with_leader(robot) finally: diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/server/find_joint_limits.py index d40caf5a..39976857 100644 --- a/lerobot/scripts/server/find_joint_limits.py +++ b/lerobot/scripts/server/find_joint_limits.py @@ -11,7 +11,7 @@ from lerobot.configs import parser from lerobot.scripts.server.kinematics import RobotKinematics follower_port = "/dev/tty.usbmodem58760431631" -leader_port = "/dev/tty.usbmodem58760433331" +leader_port = "/dev/tty.usbmodem585A0077921" def find_joint_bounds( diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index b6dcb07f..9bf215ab 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ 
b/lerobot/scripts/server/gym_manipulator.py @@ -1,8 +1,9 @@ import logging import sys import time +from collections import deque from threading import Lock -from typing import Annotated, Any, Dict, Tuple +from typing import Annotated, Any, Dict, Sequence, Tuple import gymnasium as gym import numpy as np @@ -10,7 +11,6 @@ import torch import torchvision.transforms.functional as F # noqa: N812 from lerobot.common.envs.configs import EnvConfig -from lerobot.common.envs.utils import preprocess_observation from lerobot.common.robot_devices.control_utils import ( busy_wait, is_headless, @@ -25,35 +25,28 @@ logging.basicConfig(level=logging.INFO) MAX_GRIPPER_COMMAND = 40 -class HILSerlRobotEnv(gym.Env): +class RobotEnv(gym.Env): """ Gym-compatible environment for evaluating robotic control policies with integrated human intervention. This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta) and absolute joint position commands and automatically configures its observation and action spaces based on the robot's sensors and configuration. - - The environment can switch between executing actions from a policy or using teleoperated actions (human intervention) during - each step. When teleoperation is used, the override action is captured and returned in the `info` dict along with a flag - `is_intervention`. """ def __init__( self, robot, - use_delta_action_space: bool = True, display_cameras: bool = False, ): """ - Initialize the HILSerlRobotEnv environment. + Initialize the RobotEnv environment. The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup supports both relative (delta) adjustments and absolute joint positions for controlling the robot. cfg. robot: The robot interface object used to connect and interact with the physical robot. - use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute - joint positions are used. display_cameras (bool): If True, the robot's camera feeds will be displayed during execution. """ super().__init__() @@ -65,33 +58,12 @@ class HILSerlRobotEnv(gym.Env): if not self.robot.is_connected: self.robot.connect() - self.initial_follower_position = robot.follower_arms["main"].read("Present_Position") - # Episode tracking. self.current_step = 0 self.episode_data = None - self.use_delta_action_space = use_delta_action_space self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") - # Retrieve the size of the joint position interval bound. - - self.relative_bounds_size = None - # ( - # ( - # self.robot.config.joint_position_relative_bounds["max"] - # - self.robot.config.joint_position_relative_bounds["min"] - # ) - # if self.robot.config.joint_position_relative_bounds is not None - # else None - # ) - self.robot.config.joint_position_relative_bounds = None - - self.robot.config.max_relative_target = ( - self.relative_bounds_size.float() if self.relative_bounds_size is not None else None - ) - - # Dynamically configure the observation and action spaces. self._setup_spaces() def _setup_spaces(self): @@ -103,10 +75,8 @@ class HILSerlRobotEnv(gym.Env): - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range. Action Space: - - The action space is defined as a Tuple where: - • The first element is a Box space representing joint position commands. 
It is defined as relative (delta) - or absolute, based on the configuration. - • The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation). + - The action space is defined as a Box space representing joint position commands. It is defined as relative (delta) + or absolute, based on the configuration. """ example_obs = self.robot.capture_observation() @@ -127,41 +97,15 @@ class HILSerlRobotEnv(gym.Env): # Define the action space for joint positions along with setting an intervention flag. action_dim = len(self.robot.follower_arms["main"].read("Present_Position")) - if self.use_delta_action_space: - bounds = ( - self.relative_bounds_size - if self.relative_bounds_size is not None - else np.ones(action_dim) * 1000 - ) - action_space_robot = gym.spaces.Box( - low=-bounds, - high=bounds, - shape=(action_dim,), - dtype=np.float32, - ) - else: - bounds_min = ( - self.robot.config.joint_position_relative_bounds["min"].cpu().numpy() - if self.robot.config.joint_position_relative_bounds is not None - else np.ones(action_dim) * -1000 - ) - bounds_max = ( - self.robot.config.joint_position_relative_bounds["max"].cpu().numpy() - if self.robot.config.joint_position_relative_bounds is not None - else np.ones(action_dim) * 1000 - ) - action_space_robot = gym.spaces.Box( - low=bounds_min, - high=bounds_max, - shape=(action_dim,), - dtype=np.float32, - ) + bounds = {} + bounds["min"] = np.ones(action_dim) * -1000 + bounds["max"] = np.ones(action_dim) * 1000 - self.action_space = gym.spaces.Tuple( - ( - action_space_robot, - gym.spaces.Discrete(2), - ), + self.action_space = gym.spaces.Box( + low=bounds["min"], + high=bounds["max"], + shape=(action_dim,), + dtype=np.float32, ) def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: @@ -187,27 +131,17 @@ class HILSerlRobotEnv(gym.Env): self.current_step = 0 self.episode_data = None - return observation, {} + return observation, {"is_intervention": False} - def step( - self, action: Tuple[np.ndarray, bool] - ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: + def step(self, action) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: """ Execute a single step within the environment using the specified action. - The provided action is a tuple comprised of: - • A policy action (joint position commands) that may be either in absolute values or as a delta. - • A boolean flag indicating whether teleoperation (human intervention) should be used for this step. - - Behavior: - - When the intervention flag is False, the environment processes and sends the policy action to the robot. - - When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted - to relative change based on the current joint positions. + The provided action is processed and sent to the robot as joint position commands + that may be either absolute values or deltas based on the environment configuration. cfg. - action (tuple): A tuple with two elements: - - policy_action (np.ndarray or torch.Tensor): The commanded joint positions. - - intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input. + action (np.ndarray or torch.Tensor): The commanded joint positions. Returns: tuple: A tuple containing: @@ -216,48 +150,11 @@ class HILSerlRobotEnv(gym.Env): - terminated (bool): True if the episode has reached a terminal state. 
- truncated (bool): True if the episode was truncated (e.g., time constraints). - info (dict): Additional debugging information including: - ◦ "action_intervention": The teleop action if intervention was used. - ◦ "is_intervention": Flag indicating whether teleoperation was employed. """ - policy_action, intervention_bool = action - teleop_action = None self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") - if isinstance(policy_action, torch.Tensor): - policy_action = policy_action.cpu().numpy() - policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high) - if not intervention_bool: - if self.use_delta_action_space: - target_joint_positions = self.current_joint_positions + self.delta * policy_action - else: - target_joint_positions = policy_action - self.robot.send_action(torch.from_numpy(target_joint_positions)) - observation = self.robot.capture_observation() - else: - observation, teleop_action = self.robot.teleop_step(record_data=True) - teleop_action = teleop_action["action"] # Convert tensor to appropriate format - - # When applying the delta action space, convert teleop absolute values to relative differences. - if self.use_delta_action_space: - teleop_action = (teleop_action - self.current_joint_positions) / self.delta - if self.relative_bounds_size is not None and ( - torch.any(teleop_action < -self.relative_bounds_size) - and torch.any(teleop_action > self.relative_bounds_size) - ): - logging.debug( - f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n" - f"lower bounds condition {teleop_action < -self.relative_bounds_size}\n" - f"upper bounds condition {teleop_action > self.relative_bounds_size}" - ) - - teleop_action = torch.clamp( - teleop_action, - -self.relative_bounds_size, - self.relative_bounds_size, - ) - # NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action. - if teleop_action.dim() == 1: - teleop_action = teleop_action.unsqueeze(0) + self.robot.send_action(torch.from_numpy(action)) + observation = self.robot.capture_observation() if self.display_cameras: self.render() @@ -273,10 +170,7 @@ class HILSerlRobotEnv(gym.Env): reward, terminated, truncated, - { - "action_intervention": teleop_action, - "is_intervention": teleop_action is not None, - }, + {"is_intervention": False}, ) def render(self): @@ -368,19 +262,6 @@ class AddCurrentToObservation(gym.ObservationWrapper): return observation -class ActionRepeatWrapper(gym.Wrapper): - def __init__(self, env, nb_repeat: int = 1): - super().__init__(env) - self.nb_repeat = nb_repeat - - def step(self, action): - for _ in range(self.nb_repeat): - obs, reward, done, truncated, info = self.env.step(action) - if done or truncated: - break - return obs, reward, done, truncated, info - - class RewardWrapper(gym.Wrapper): def __init__(self, env, reward_classifier, device: torch.device = "cuda"): """ @@ -393,8 +274,6 @@ class RewardWrapper(gym.Wrapper): """ self.env = env - if isinstance(device, str): - device = torch.device(device) self.device = device self.reward_classifier = torch.compile(reward_classifier) @@ -426,77 +305,6 @@ class RewardWrapper(gym.Wrapper): return self.env.reset(seed=seed, options=options) -class JointMaskingActionSpace(gym.Wrapper): - def __init__(self, env, mask): - """ - Wrapper to mask out dimensions of the action space. - - cfg. 
- env: The environment to wrap - mask: Binary mask array where 0 indicates dimensions to remove - """ - super().__init__(env) - - # Validate mask matches action space - - # Keep only dimensions where mask is 1 - self.active_dims = np.where(mask)[0] - - if isinstance(env.action_space, gym.spaces.Box): - if len(mask) != env.action_space.shape[0]: - raise ValueError("Mask length must match action space dimensions") - low = env.action_space.low[self.active_dims] - high = env.action_space.high[self.active_dims] - self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype) - - if isinstance(env.action_space, gym.spaces.Tuple): - if len(mask) != env.action_space[0].shape[0]: - raise ValueError("Mask length must match action space 0 dimensions") - - low = env.action_space[0].low[self.active_dims] - high = env.action_space[0].high[self.active_dims] - action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype) - self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1])) - # Create new action space with masked dimensions - - def action(self, action): - """ - Convert masked action back to full action space. - - cfg. - action: Action in masked space. For Tuple spaces, the first element is masked. - - Returns: - Action in original space with masked dims set to 0. - """ - - # Determine whether we are handling a Tuple space or a Box. - if isinstance(self.env.action_space, gym.spaces.Tuple): - # Extract the masked component from the tuple. - masked_action = action[0] if isinstance(action, tuple) else action - # Create a full action for the Box element. - full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype) - full_box_action[self.active_dims] = masked_action - # Return a tuple with the reconstructed Box action and the unchanged remainder. - return (full_box_action, action[1]) - else: - # For Box action spaces. 
- masked_action = action if not isinstance(action, tuple) else action[0] - full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype) - full_action[self.active_dims] = masked_action - return full_action - - def step(self, action): - action = self.action(action) - obs, reward, terminated, truncated, info = self.env.step(action) - if "action_intervention" in info and info["action_intervention"] is not None: - if info["action_intervention"].dim() == 1: - info["action_intervention"] = info["action_intervention"][self.active_dims] - else: - info["action_intervention"] = info["action_intervention"][:, self.active_dims] - return obs, reward, terminated, truncated, info - - class TimeLimitWrapper(gym.Wrapper): def __init__(self, env, control_time_s, fps): self.env = env @@ -565,32 +373,21 @@ class ImageCropResizeWrapper(gym.Wrapper): flattened_spatial_dims = obs[k].view(batch_size, channels, -1) # Calculate standard deviation across spatial dimensions (H, W) - std_per_channel = torch.std(flattened_spatial_dims, dim=2) - # If any channel has std=0, all pixels in that channel have the same value + # This is helpful if one camera mistakenly covered or the image is black + std_per_channel = torch.std(flattened_spatial_dims, dim=2) if (std_per_channel <= 0.02).any(): logging.warning( f"Potential hardware issue detected: All pixels have the same value in observation {k}" ) - # Check for NaNs before processing - if torch.isnan(obs[k]).any(): - logging.error(f"NaN values detected in observation {k} before crop and resize") if device == torch.device("mps:0"): obs[k] = obs[k].cpu() obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) obs[k] = F.resize(obs[k], self.resize_size) - # TODO(michel-aractingi): Bug in resize, it returns values outside [0, 1] + # TODO (michel-aractingi): Bug in resize, it returns values outside [0, 1] obs[k] = obs[k].clamp(0.0, 1.0) - - # import cv2 - # cv2.imwrite(f"tmp_img/{k}.jpg", obs[k].squeeze(0).permute(1, 2, 0).cpu().numpy() * 255) - - # Check for NaNs after processing - if torch.isnan(obs[k]).any(): - logging.error(f"NaN values detected in observation {k} after crop and resize") - obs[k] = obs[k].to(device) return obs, reward, terminated, truncated, info @@ -609,16 +406,17 @@ class ImageCropResizeWrapper(gym.Wrapper): class ConvertToLeRobotObservation(gym.ObservationWrapper): - def __init__(self, env, device): + def __init__(self, env, device: str = "cpu"): super().__init__(env) - if isinstance(device, str): - device = torch.device(device) - self.device = device + self.device = torch.device(device) def observation(self, observation): - observation = preprocess_observation(observation) - + for key in observation: + observation[key] = observation[key].float() + if "image" in key: + observation[key] = observation[key].permute(2, 0, 1) + observation[key] /= 255.0 observation = { key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") for key in observation @@ -627,154 +425,30 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper): return observation -class KeyboardInterfaceWrapper(gym.Wrapper): - def __init__(self, env): - super().__init__(env) - self.listener = None - self.events = { - "exit_early": False, - "pause_policy": False, - "reset_env": False, - "human_intervention_step": False, - "episode_success": False, - } - self.event_lock = Lock() # Thread-safe access to events - self._init_keyboard_listener() - - def _init_keyboard_listener(self): - """Initialize keyboard listener if not in headless mode""" - - if 
is_headless(): - logging.warning( - "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." - ) - return - try: - from pynput import keyboard - - def on_press(key): - with self.event_lock: - try: - if key == keyboard.Key.right or key == keyboard.Key.esc: - print("Right arrow key pressed. Exiting loop...") - self.events["exit_early"] = True - return - if hasattr(key, "char") and key.char == "s": - print("Key 's' pressed. Episode success triggered.") - self.events["episode_success"] = True - return - if key == keyboard.Key.space and not self.events["exit_early"]: - if not self.events["pause_policy"]: - print( - "Space key pressed. Human intervention required.\n" - "Place the leader in similar pose to the follower and press space again." - ) - self.events["pause_policy"] = True - log_say( - "Human intervention stage. Get ready to take over.", - play_sounds=True, - ) - return - if self.events["pause_policy"] and not self.events["human_intervention_step"]: - self.events["human_intervention_step"] = True - print("Space key pressed. Human intervention starting.") - log_say("Starting human intervention.", play_sounds=True) - return - if self.events["pause_policy"] and self.events["human_intervention_step"]: - self.events["pause_policy"] = False - self.events["human_intervention_step"] = False - print("Space key pressed for a third time.") - log_say("Continuing with policy actions.", play_sounds=True) - return - except Exception as e: - print(f"Error handling key press: {e}") - - self.listener = keyboard.Listener(on_press=on_press) - self.listener.start() - except ImportError: - logging.warning("Could not import pynput. Keyboard interface will not be available.") - self.listener = None - - def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]: - is_intervention = False - terminated_by_keyboard = False - - # Extract policy_action if needed - if isinstance(self.env.action_space, gym.spaces.Tuple): - policy_action = action[0] - - # Check the event flags without holding the lock for too long. 
- with self.event_lock: - if self.events["exit_early"]: - terminated_by_keyboard = True - pause_policy = self.events["pause_policy"] - - if pause_policy: - # Now, wait for human_intervention_step without holding the lock - while True: - with self.event_lock: - if self.events["human_intervention_step"]: - is_intervention = True - break - time.sleep(0.1) # Check more frequently if desired - - # Execute the step in the underlying environment - obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention)) - - # Override reward and termination if episode success event triggered - with self.event_lock: - if self.events["episode_success"]: - reward = 1 - terminated_by_keyboard = True - - return obs, reward, terminated or terminated_by_keyboard, truncated, info - - def reset(self, **kwargs) -> Tuple[Any, Dict]: - """ - Reset the environment and clear any pending events - """ - with self.event_lock: - self.events = dict.fromkeys(self.events, False) - return self.env.reset(**kwargs) - - def close(self): - """ - Properly clean up the keyboard listener when the environment is closed - """ - if self.listener is not None: - self.listener.stop() - super().close() - - class ResetWrapper(gym.Wrapper): def __init__( self, - env: HILSerlRobotEnv, + env: RobotEnv, reset_pose: np.ndarray | None = None, reset_time_s: float = 5, - open_gripper_on_reset: bool = False, ): super().__init__(env) self.reset_time_s = reset_time_s self.reset_pose = reset_pose self.robot = self.unwrapped.robot - self.open_gripper_on_reset = open_gripper_on_reset def reset(self, *, seed=None, options=None): + start_time = time.perf_counter() if self.reset_pose is not None: - start_time = time.perf_counter() log_say("Reset the environment.", play_sounds=True) - reset_follower_position(self.robot, self.reset_pose) - busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) + reset_follower_position(self.robot.follower_arms["main"], self.reset_pose) log_say("Reset the environment done.", play_sounds=True) - if self.open_gripper_on_reset: - current_joint_pos = self.robot.follower_arms["main"].read("Present_Position") - current_joint_pos[-1] = MAX_GRIPPER_COMMAND - self.robot.send_action(torch.from_numpy(current_joint_pos)) - busy_wait(0.1) - current_joint_pos[-1] = 0.0 - self.robot.send_action(torch.from_numpy(current_joint_pos)) - busy_wait(0.2) + + if len(self.robot.leader_arms) > 0: + self.robot.leader_arms["main"].write("Torque_Enable", 1) + log_say("Reset the leader robot.", play_sounds=True) + reset_follower_position(self.robot.leader_arms["main"], self.reset_pose) + log_say("Reset the leader robot done.", play_sounds=True) else: log_say( f"Manually reset the environment for {self.reset_time_s} seconds.", @@ -785,6 +459,9 @@ class ResetWrapper(gym.Wrapper): self.robot.teleop_step() log_say("Manual reset of the environment done.", play_sounds=True) + + busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) + return super().reset(seed=seed, options=options) @@ -804,10 +481,9 @@ class BatchCompatibleWrapper(gym.ObservationWrapper): class GripperPenaltyWrapper(gym.RewardWrapper): - def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True): + def __init__(self, env, penalty: float = -0.1): super().__init__(env) self.penalty = penalty - self.gripper_penalty_in_reward = gripper_penalty_in_reward self.last_gripper_state = None def reward(self, reward, action): @@ -823,22 +499,18 @@ class GripperPenaltyWrapper(gym.RewardWrapper): def step(self, action): 
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] - gripper_action = action[0][-1] if isinstance(action, tuple) else action[-1] + gripper_action = action[-1] obs, reward, terminated, truncated, info = self.env.step(action) gripper_penalty = self.reward(reward, gripper_action) - if self.gripper_penalty_in_reward: - reward += gripper_penalty - else: - info["discrete_penalty"] = gripper_penalty + info["discrete_penalty"] = gripper_penalty return obs, reward, terminated, truncated, info def reset(self, **kwargs): self.last_gripper_state = None obs, info = super().reset(**kwargs) - if self.gripper_penalty_in_reward: - info["gripper_penalty"] = 0.0 + info["gripper_penalty"] = 0.0 return obs, info @@ -851,10 +523,6 @@ class GripperActionWrapper(gym.ActionWrapper): self.last_gripper_action = None def action(self, action): - is_intervention = False - if isinstance(action, tuple): - action, is_intervention = action - if self.gripper_sleep > 0.0: if ( self.last_gripper_action is not None @@ -879,7 +547,7 @@ class GripperActionWrapper(gym.ActionWrapper): gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) action[-1] = gripper_action.item() - return action, is_intervention + return action def reset(self, **kwargs): obs, info = super().reset(**kwargs) @@ -910,24 +578,21 @@ class EEActionWrapper(gym.ActionWrapper): # gripper actions open at 2.0, and closed at 0.0 min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]]) max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]]) - ee_action_space = gym.spaces.Box( + else: + min_action_space_bounds = -action_space_bounds + max_action_space_bounds = action_space_bounds + + self.action_space = gym.spaces.Box( low=min_action_space_bounds, high=max_action_space_bounds, shape=(3 + int(self.use_gripper),), dtype=np.float32, ) - if isinstance(self.action_space, gym.spaces.Tuple): - self.action_space = gym.spaces.Tuple((ee_action_space, self.action_space[1])) - else: - self.action_space = ee_action_space self.bounds = ee_action_space_params.bounds def action(self, action): - is_intervention = False desired_ee_pos = np.eye(4) - if isinstance(action, tuple): - action, _ = action if self.use_gripper: gripper_command = action[-1] @@ -935,8 +600,6 @@ class EEActionWrapper(gym.ActionWrapper): current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position") current_ee_pos = self.fk_function(current_joint_pos) - if isinstance(action, torch.Tensor): - action = action.cpu().numpy() desired_ee_pos[:3, 3] = np.clip( current_ee_pos[:3, 3] + action, self.bounds["min"], @@ -951,7 +614,7 @@ class EEActionWrapper(gym.ActionWrapper): if self.use_gripper: target_joint_pos[-1] = gripper_command - return target_joint_pos, is_intervention + return target_joint_pos class EEObservationWrapper(gym.ObservationWrapper): @@ -986,6 +649,298 @@ class EEObservationWrapper(gym.ObservationWrapper): return observation +########################################################### +# Wrappers related to human intervention and input devices +########################################################### + + +class BaseLeaderControlWrapper(gym.Wrapper): + """Base class for leader-follower robot control wrappers.""" + + def __init__( + self, env, use_geared_leader_arm: bool = False, ee_action_space_params=None, use_gripper=False + ): + super().__init__(env) + self.robot_leader = 
env.unwrapped.robot.leader_arms["main"] + self.robot_follower = env.unwrapped.robot.follower_arms["main"] + self.use_geared_leader_arm = use_geared_leader_arm + self.ee_action_space_params = ee_action_space_params + self.use_ee_action_space = ee_action_space_params is not None + self.use_gripper: bool = use_gripper + + # Set up keyboard event tracking + self._init_keyboard_events() + self.event_lock = Lock() # Thread-safe access to events + + # Initialize robot control + robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100") + self.kinematics = RobotKinematics(robot_type) + self.prev_leader_ee = None + self.prev_leader_pos = None + self.leader_torque_enabled = True + + # Configure leader arm + # NOTE: Lower the gains of leader arm for automatic take-over + # With lower gains we can manually move the leader arm without risk of injury to ourselves or the robot + # With higher gains, it would be dangerous and difficult to modify the leader's pose while torque is enabled + # Default value for P_coeff is 32 + self.robot_leader.write("Torque_Enable", 1) + self.robot_leader.write("P_Coefficient", 4) + self.robot_leader.write("I_Coefficient", 0) + self.robot_leader.write("D_Coefficient", 4) + + self._init_keyboard_listener() + + def _init_keyboard_events(self): + """Initialize the keyboard events dictionary - override in subclasses.""" + self.keyboard_events = { + "episode_success": False, + "episode_end": False, + "rerecord_episode": False, + } + + def _handle_key_press(self, key, keyboard): + """Handle key presses - override in subclasses for additional keys.""" + try: + if key == keyboard.Key.esc: + self.keyboard_events["episode_end"] = True + return + if key == keyboard.Key.left: + self.keyboard_events["rerecord_episode"] = True + return + if hasattr(key, "char") and key.char == "s": + logging.info("Key 's' pressed. Episode success triggered.") + self.keyboard_events["episode_success"] = True + return + except Exception as e: + logging.error(f"Error handling key press: {e}") + + def _init_keyboard_listener(self): + """Initialize keyboard listener if not in headless mode""" + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + return + try: + from pynput import keyboard + + def on_press(key): + with self.event_lock: + self._handle_key_press(key, keyboard) + + self.listener = keyboard.Listener(on_press=on_press) + self.listener.start() + + except ImportError: + logging.warning("Could not import pynput. 
Keyboard interface will not be available.") + self.listener = None + + def _check_intervention(self): + """Check if intervention is needed - override in subclasses.""" + return False + + def _handle_intervention(self, action): + """Process actions during intervention mode.""" + if self.leader_torque_enabled: + self.robot_leader.write("Torque_Enable", 0) + self.leader_torque_enabled = False + + leader_pos = self.robot_leader.read("Present_Position") + follower_pos = self.robot_follower.read("Present_Position") + + # [:3, 3] Last column of the transformation matrix corresponds to the xyz translation + leader_ee = self.kinematics.fk_gripper_tip(leader_pos)[:3, 3] + follower_ee = self.kinematics.fk_gripper_tip(follower_pos)[:3, 3] + + if self.prev_leader_ee is None: + self.prev_leader_ee = leader_ee + + # NOTE: Using the leader's position delta for teleoperation is too noisy + # Instead, we move the follower to match the leader's absolute position, + # and record the leader's position changes as the intervention action + action = leader_ee - follower_ee + action_intervention = leader_ee - self.prev_leader_ee + self.prev_leader_ee = leader_ee + + if self.use_gripper: + # Get gripper action delta based on leader pose + leader_gripper = leader_pos[-1] + follower_gripper = follower_pos[-1] + gripper_delta = leader_gripper - follower_gripper + + # Normalize by max angle and quantize to {0,1,2} + normalized_delta = gripper_delta / MAX_GRIPPER_COMMAND + if normalized_delta > 0.3: + gripper_action = 2 + elif normalized_delta < -0.3: + gripper_action = 0 + else: + gripper_action = 1 + + action = np.append(action, gripper_action) + action_intervention = np.append(action_intervention, gripper_delta) + + return action, action_intervention + + def _handle_leader_teleoperation(self): + """Handle leader teleoperation (non-intervention) operation.""" + if not self.leader_torque_enabled: + self.robot_leader.write("Torque_Enable", 1) + self.leader_torque_enabled = True + + follower_pos = self.robot_follower.read("Present_Position") + self.robot_leader.write("Goal_Position", follower_pos) + + def step(self, action): + """Execute environment step with possible intervention.""" + is_intervention = self._check_intervention() + action_intervention = None + + # NOTE: + if is_intervention: + action, action_intervention = self._handle_intervention(action) + else: + self._handle_leader_teleoperation() + + # NOTE: + obs, reward, terminated, truncated, info = self.env.step(action) + + # Add intervention info + info["is_intervention"] = is_intervention + info["action_intervention"] = action_intervention if is_intervention else None + + # Check for success or manual termination + success = self.keyboard_events["episode_success"] + terminated = terminated or self.keyboard_events["episode_end"] or success + + if success: + reward = 1.0 + logging.info("Episode ended successfully with reward 1.0") + + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + """Reset the environment and internal state.""" + self.prev_leader_ee = None + self.prev_leader_pos = None + self.keyboard_events = dict.fromkeys(self.keyboard_events, False) + return super().reset(**kwargs) + + def close(self): + """Clean up resources.""" + if hasattr(self, "listener") and self.listener is not None: + self.listener.stop() + return self.env.close() + + +class GearedLeaderControlWrapper(BaseLeaderControlWrapper): + """Wrapper that enables manual intervention via keyboard.""" + + def _init_keyboard_events(self): + """Initialize keyboard 
events including human intervention flag.""" + super()._init_keyboard_events() + self.keyboard_events["human_intervention_step"] = False + + def _handle_key_press(self, key, keyboard): + """Handle key presses including space for intervention toggle.""" + super()._handle_key_press(key, keyboard) + if key == keyboard.Key.space: + if not self.keyboard_events["human_intervention_step"]: + logging.info( + "Space key pressed. Human intervention required.\n" + "Place the leader in similar pose to the follower and press space again." + ) + self.keyboard_events["human_intervention_step"] = True + log_say("Human intervention step.", play_sounds=True) + else: + self.keyboard_events["human_intervention_step"] = False + logging.info("Space key pressed for a second time.\nContinuing with policy actions.") + log_say("Continuing with policy actions.", play_sounds=True) + + def _check_intervention(self): + """Check if human intervention is active.""" + return self.keyboard_events["human_intervention_step"] + + +class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper): + """Wrapper with automatic intervention based on error thresholds.""" + + def __init__( + self, + env, + ee_action_space_params=None, + use_gripper=False, + intervention_threshold=1.7, + release_threshold=0.01, + queue_size=10, + ): + super().__init__(env, ee_action_space_params=ee_action_space_params, use_gripper=use_gripper) + + # Error tracking parameters + self.intervention_threshold = intervention_threshold # Threshold to trigger intervention + self.release_threshold = release_threshold # Threshold to release intervention + self.queue_size = queue_size # Number of error measurements to keep + + # Error tracking variables + self.error_queue = deque(maxlen=self.queue_size) + self.error_over_time_queue = deque(maxlen=self.queue_size) + self.previous_error = 0.0 + self.is_intervention_active = False + self.start_time = time.perf_counter() + + def _check_intervention(self): + """Determine if intervention should occur based on leader-follower error.""" + # Skip intervention logic for the first few steps to collect data + if time.perf_counter() - self.start_time < 1.0: # Wait 1 second before enabling + return False + + # Get current positions + leader_positions = self.robot_leader.read("Present_Position") + follower_positions = self.robot_follower.read("Present_Position") + + # Calculate error and error rate + error = np.linalg.norm(leader_positions - follower_positions) + error_over_time = np.abs(error - self.previous_error) + + # Add to queue for running average + self.error_queue.append(error) + self.error_over_time_queue.append(error_over_time) + + # Update previous error + self.previous_error = error + + # Calculate averages if we have enough data + if len(self.error_over_time_queue) >= self.queue_size: + avg_error_over_time = np.mean(self.error_over_time_queue) + + # Debug info + if self.is_intervention_active: + logging.debug(f"Error rate during intervention: {avg_error_over_time:.4f}") + + # Determine if intervention should start or stop + if not self.is_intervention_active and avg_error_over_time > self.intervention_threshold: + # Transition to intervention mode + self.is_intervention_active = True + logging.info(f"Starting automatic intervention: error rate {avg_error_over_time:.4f}") + + elif self.is_intervention_active and avg_error_over_time < self.release_threshold: + # End intervention mode + self.is_intervention_active = False + logging.info(f"Ending automatic intervention: error rate {avg_error_over_time:.4f}") + + 
return self.is_intervention_active + + def reset(self, **kwargs): + """Reset error tracking on environment reset.""" + self.error_queue.clear() + self.error_over_time_queue.clear() + self.previous_error = 0.0 + self.is_intervention_active = False + self.start_time = time.perf_counter() + return super().reset(**kwargs) + + class GamepadControlWrapper(gym.Wrapper): """ Wrapper that allows controlling a gym environment with a gamepad. @@ -1049,7 +1004,9 @@ class GamepadControlWrapper(gym.Wrapper): print(" Y/Triangle button: End episode (SUCCESS)") print(" B/Circle button: Exit program") - def get_gamepad_action(self): + def get_gamepad_action( + self, + ) -> Tuple[bool, np.ndarray, bool, bool, bool]: """ Get the current action from the gamepad if any input is active. @@ -1115,20 +1072,10 @@ class GamepadControlWrapper(gym.Wrapper): logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}") # Only override the action if gamepad is active - if is_intervention: - # Format according to the expected action type - if isinstance(self.action_space, gym.spaces.Tuple): - # For environments that use (action, is_intervention) tuples - final_action = (torch.from_numpy(gamepad_action), False) - else: - final_action = torch.from_numpy(gamepad_action) - - else: - # Use the original action - final_action = action + action = gamepad_action if is_intervention else action # Step the environment - obs, reward, terminated, truncated, info = self.env.step(final_action) + obs, reward, terminated, truncated, info = self.env.step(action) # Add episode ending if requested via gamepad terminated = terminated or truncated or terminate_episode @@ -1137,11 +1084,11 @@ class GamepadControlWrapper(gym.Wrapper): reward = 1.0 logging.info("Episode ended successfully with reward 1.0") + if isinstance(action, np.ndarray): + action = torch.from_numpy(action) + info["is_intervention"] = is_intervention - action_intervention = final_action[0] if isinstance(final_action, Tuple) else final_action - if isinstance(action_intervention, np.ndarray): - action_intervention = torch.from_numpy(action_intervention) - info["action_intervention"] = action_intervention + info["action_intervention"] = action info["rerecord_episode"] = rerecord_episode # If episode ended, reset the state @@ -1166,26 +1113,73 @@ class GamepadControlWrapper(gym.Wrapper): return self.env.close() -class ActionScaleWrapper(gym.ActionWrapper): - def __init__(self, env, ee_action_space_params=None): - super().__init__(env) - assert ee_action_space_params is not None, "TODO: method implemented for ee action space only so far" - self.scale_vector = np.array( - [ - [ - ee_action_space_params.x_step_size, - ee_action_space_params.y_step_size, - ee_action_space_params.z_step_size, - ] - ] +class TorchBox(gym.spaces.Box): + """A version of gym.spaces.Box that handles PyTorch tensors. + + This class extends gym.spaces.Box to work with PyTorch tensors, + providing compatibility between NumPy arrays and PyTorch tensors. 
+ """ + + def __init__( + self, + low: float | Sequence[float] | np.ndarray, + high: float | Sequence[float] | np.ndarray, + shape: Sequence[int] | None = None, + np_dtype: np.dtype | type = np.float32, + torch_dtype: torch.dtype = torch.float32, + device: str = "cpu", + seed: int | np.random.Generator | None = None, + ) -> None: + super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed) + self.torch_dtype = torch_dtype + self.device = device + + def sample(self) -> torch.Tensor: + arr = super().sample() + return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device) + + def contains(self, x: torch.Tensor) -> bool: + # Move to CPU/numpy and cast to the internal dtype + arr = x.detach().cpu().numpy().astype(self.dtype, copy=False) + return super().contains(arr) + + def seed(self, seed: int | np.random.Generator | None = None): + super().seed(seed) + return [seed] + + def __repr__(self) -> str: + return ( + f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, " + f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})" ) - def action(self, action): - is_intervention = False - if isinstance(action, tuple): - action, is_intervention = action - return action * self.scale_vector, is_intervention +class TorchActionWrapper(gym.Wrapper): + """ + The goal of this wrapper is to change the action_space.sample() + to torch tensors. + """ + + def __init__(self, env: gym.Env, device: str): + super().__init__(env) + self.action_space = TorchBox( + low=env.action_space.low, + high=env.action_space.high, + shape=env.action_space.shape, + torch_dtype=torch.float32, + device=torch.device("cpu"), + ) + + def step(self, action: torch.Tensor): + if action.dim() == 2: + action = action.squeeze(0) + action = action.detach().cpu().numpy() + return self.env.step(action) + + +########################################################### +# Factory functions +########################################################### def make_robot_env(cfg) -> gym.vector.VectorEnv: @@ -1200,22 +1194,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: Returns: A vectorized gym environment with all the necessary wrappers applied. 
""" - if "maniskill" in cfg.name: - from lerobot.scripts.server.maniskill_manipulator import make_maniskill - - logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN") - env = make_maniskill( - cfg=cfg, - n_envs=1, - ) - return env robot = make_robot_from_config(cfg.robot) # Create base environment - env = HILSerlRobotEnv( + env = RobotEnv( robot=robot, display_cameras=cfg.wrapper.display_cameras, - use_delta_action_space=cfg.wrapper.use_relative_joint_positions - and cfg.wrapper.ee_action_space_params is None, ) # Add observation and image processing @@ -1246,18 +1229,15 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: env = GripperPenaltyWrapper( env=env, penalty=cfg.wrapper.gripper_penalty, - gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward, ) - if cfg.wrapper.ee_action_space_params is not None: - env = EEActionWrapper( - env=env, - ee_action_space_params=cfg.wrapper.ee_action_space_params, - use_gripper=cfg.wrapper.use_gripper, - ) + env = EEActionWrapper( + env=env, + ee_action_space_params=cfg.wrapper.ee_action_space_params, + use_gripper=cfg.wrapper.use_gripper, + ) - if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad: - # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params) + if cfg.wrapper.ee_action_space_params.control_mode == "gamepad": env = GamepadControlWrapper( env=env, x_step_size=cfg.wrapper.ee_action_space_params.x_step_size, @@ -1265,18 +1245,28 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: z_step_size=cfg.wrapper.ee_action_space_params.z_step_size, use_gripper=cfg.wrapper.use_gripper, ) + elif cfg.wrapper.ee_action_space_params.control_mode == "leader": + env = GearedLeaderControlWrapper( + env=env, + ee_action_space_params=cfg.wrapper.ee_action_space_params, + use_gripper=cfg.wrapper.use_gripper, + ) + elif cfg.wrapper.ee_action_space_params.control_mode == "leader_automatic": + env = GearedLeaderAutomaticControlWrapper( + env=env, + ee_action_space_params=cfg.wrapper.ee_action_space_params, + use_gripper=cfg.wrapper.use_gripper, + ) else: - env = KeyboardInterfaceWrapper(env=env) + raise ValueError(f"Invalid control mode: {cfg.wrapper.ee_action_space_params.control_mode}") env = ResetWrapper( env=env, reset_pose=cfg.wrapper.fixed_reset_joint_positions, reset_time_s=cfg.wrapper.reset_time_s, - open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset, ) - if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: - env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = BatchCompatibleWrapper(env=env) + env = TorchActionWrapper(env=env, device=cfg.device) return env @@ -1311,6 +1301,11 @@ def init_reward_classifier(cfg): return classifier +########################################################### +# Record and replay functions +########################################################### + + def record_dataset(env, policy, cfg): """ Record a dataset of robot interactions using either a policy or teleop. 
@@ -1329,9 +1324,7 @@ def record_dataset(env, policy, cfg): from lerobot.common.datasets.lerobot_dataset import LeRobotDataset # Setup initial action (zero action if using teleop) - dummy_action = env.action_space.sample() - dummy_action = (torch.from_numpy(dummy_action[0] * 0.0), False) - action = dummy_action + action = env.action_space.sample() * 0.0 # Configure dataset features based on environment spaces features = { @@ -1342,7 +1335,7 @@ def record_dataset(env, policy, cfg): }, "action": { "dtype": "float32", - "shape": env.action_space[0].shape, + "shape": env.action_space.shape, "names": None, }, "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, @@ -1442,8 +1435,7 @@ def replay_episode(env, cfg): start_episode_t = time.perf_counter() action = actions[idx]["action"] - env.step((action, False)) - # env.step((action / env.unwrapped.delta, False)) + env.step(action) dt_s = time.perf_counter() - start_episode_t busy_wait(1 / 10 - dt_s) @@ -1464,7 +1456,7 @@ def main(cfg: EnvConfig): record_dataset( env, - policy=None, + policy=policy, cfg=cfg, ) exit() @@ -1478,11 +1470,8 @@ def main(cfg: EnvConfig): env.reset() - # Retrieve the robot's action space for joint commands. - action_space_robot = env.action_space.spaces[0] - # Initialize the smoothed action as a random sample. - smoothed_action = action_space_robot.sample() + smoothed_action = env.action_space.sample() # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. @@ -1493,12 +1482,12 @@ def main(cfg: EnvConfig): while num_episode < 20: start_loop_s = time.perf_counter() # Sample a new random action from the robot's action space. - new_random_action = action_space_robot.sample() + new_random_action = env.action_space.sample() # Update the smoothed action using an exponential moving average. smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action # Execute the step: wrap the NumPy action in a torch tensor. 
- obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False)) + obs, reward, terminated, truncated, info = env.step(smoothed_action) if terminated or truncated: successes.append(reward) env.reset() diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index b9247fa8..777cb92e 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -23,8 +23,6 @@ from pathlib import Path from pprint import pformat import grpc - -# Import generated stubs import hilserl_pb2_grpc # type: ignore import torch from termcolor import colored @@ -39,8 +37,6 @@ from lerobot.common.constants import ( TRAINING_STATE_DIR, ) from lerobot.common.datasets.factory import make_dataset - -# TODO: Remove the import of maniskill from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.policies.factory import make_policy from lerobot.common.policies.sac.modeling_sac import SACPolicy @@ -62,16 +58,17 @@ from lerobot.common.utils.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig from lerobot.scripts.server import learner_service -from lerobot.scripts.server.buffer import ( - ReplayBuffer, +from lerobot.scripts.server.buffer import ReplayBuffer, concatenate_batch_transitions +from lerobot.scripts.server.network_utils import ( bytes_to_python_object, bytes_to_transitions, - concatenate_batch_transitions, - move_state_dict_to_device, - move_transition_to_device, state_to_bytes, ) -from lerobot.scripts.server.utils import setup_process_handlers +from lerobot.scripts.server.utils import ( + move_state_dict_to_device, + move_transition_to_device, + setup_process_handlers, +) LOG_PREFIX = "[LEARNER]" @@ -307,17 +304,10 @@ def add_actor_information_and_train( offline_replay_buffer = None if cfg.dataset is not None: - active_action_dims = None - # TODO: FIX THIS - if cfg.env.wrapper.joint_masking_action_space is not None: - active_action_dims = [ - i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask - ] offline_replay_buffer = initialize_offline_replay_buffer( cfg=cfg, device=device, storage_device=storage_device, - active_action_dims=active_action_dims, ) batch_size: int = batch_size // 2 # We will sample from both replay buffer @@ -342,7 +332,6 @@ def add_actor_information_and_train( break # Process all available transitions to the replay buffer, send by the actor server - logging.debug("[LEARNER] Waiting for transitions") process_transitions( transition_queue=transition_queue, replay_buffer=replay_buffer, @@ -351,35 +340,29 @@ def add_actor_information_and_train( dataset_repo_id=dataset_repo_id, shutdown_event=shutdown_event, ) - logging.debug("[LEARNER] Received transitions") # Process all available interaction messages sent by the actor server - logging.debug("[LEARNER] Waiting for interactions") interaction_message = process_interaction_messages( interaction_message_queue=interaction_message_queue, interaction_step_shift=interaction_step_shift, wandb_logger=wandb_logger, shutdown_event=shutdown_event, ) - logging.debug("[LEARNER] Received interactions") # Wait until the replay buffer has enough samples to start training if len(replay_buffer) < online_step_before_learning: continue if online_iterator is None: - logging.debug("[LEARNER] Initializing online replay buffer iterator") online_iterator = replay_buffer.get_iterator( batch_size=batch_size, async_prefetch=async_prefetch, 
queue_size=2 ) if offline_replay_buffer is not None and offline_iterator is None: - logging.debug("[LEARNER] Initializing offline replay buffer iterator") offline_iterator = offline_replay_buffer.get_iterator( batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 ) - logging.debug("[LEARNER] Starting optimization loop") time_for_one_optimization_step = time.time() for _ in range(utd_ratio - 1): # Sample from the iterators @@ -967,7 +950,6 @@ def initialize_offline_replay_buffer( cfg: TrainPipelineConfig, device: str, storage_device: str, - active_action_dims: list[int] | None = None, ) -> ReplayBuffer: """ Initialize an offline replay buffer from a dataset. @@ -976,7 +958,6 @@ def initialize_offline_replay_buffer( cfg (TrainPipelineConfig): Training configuration device (str): Device to store tensors on storage_device (str): Device for storage optimization - active_action_dims (list[int] | None): Active action dimensions for masking Returns: ReplayBuffer: Initialized offline replay buffer @@ -997,7 +978,6 @@ def initialize_offline_replay_buffer( offline_dataset, device=device, state_keys=cfg.policy.input_features.keys(), - action_mask=active_action_dims, storage_device=storage_device, optimize_memory=True, capacity=cfg.policy.offline_buffer_capacity, @@ -1096,44 +1076,6 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): parameters_queue.put(state_bytes) -def check_weight_gradients(module: nn.Module) -> dict[str, bool]: - """ - Checks whether each parameter in the module has a gradient. - - Args: - module (nn.Module): A PyTorch module whose parameters will be inspected. - - Returns: - dict[str, bool]: A dictionary where each key is the parameter name and the value is - True if the parameter has an associated gradient (i.e. .grad is not None), - otherwise False. - """ - grad_status = {} - for name, param in module.named_parameters(): - grad_status[name] = param.grad is not None - return grad_status - - -def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]: - """ - Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary. - - Args: - actor (nn.Module): The actor model. - grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate - whether each parameter has a gradient. - - Returns: - dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status. - """ - # Get actor parameter names as a set. - model_param_names = {name for name, _ in model.named_parameters()} - - # Intersect parameter names between actor and grad_status. 
- overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names} - return overlapping - - def process_interaction_message( message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None ): diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py deleted file mode 100644 index b42d347b..00000000 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ /dev/null @@ -1,221 +0,0 @@ -from typing import Any - -import einops -import gymnasium as gym -import numpy as np -import torch -from mani_skill.utils.wrappers.record import RecordEpisode -from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv - -from lerobot.common.envs.configs import ManiskillEnvConfig - - -def preprocess_maniskill_observation( - observations: dict[str, np.ndarray], -) -> dict[str, torch.Tensor]: - """Convert environment observation to LeRobot format observation. - Args: - observation: Dictionary of observation batches from a Gym vector environment. - Returns: - Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. - """ - # map to expected inputs for the policy - return_observations = {} - # TODO: You have to merge all tensors from agent key and extra key - # You don't keep sensor param key in the observation - # And you keep sensor data rgb - q_pos = observations["agent"]["qpos"] - q_vel = observations["agent"]["qvel"] - tcp_pos = observations["extra"]["tcp_pose"] - img = observations["sensor_data"]["base_camera"]["rgb"] - - _, h, w, c = img.shape - assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" - - # sanity check that images are uint8 - assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" - - # convert to channel first of type float32 in range [0,1] - img = einops.rearrange(img, "b h w c -> b c h w").contiguous() - img = img.type(torch.float32) - img /= 255 - - state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1) - - return_observations["observation.image"] = img - return_observations["observation.state"] = state - return return_observations - - -class ManiSkillObservationWrapper(gym.ObservationWrapper): - def __init__(self, env, device: torch.device = "cuda"): - super().__init__(env) - if isinstance(device, str): - device = torch.device(device) - self.device = device - - def observation(self, observation): - observation = preprocess_maniskill_observation(observation) - observation = {k: v.to(self.device) for k, v in observation.items()} - return observation - - -class ManiSkillCompat(gym.Wrapper): - def __init__(self, env): - super().__init__(env) - new_action_space_shape = env.action_space.shape[-1] - new_low = np.squeeze(env.action_space.low, axis=0) - new_high = np.squeeze(env.action_space.high, axis=0) - self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,)) - - def reset( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[Any, dict[str, Any]]: - options = {} - return super().reset(seed=seed, options=options) - - def step(self, action): - obs, reward, terminated, truncated, info = self.env.step(action) - reward = reward.item() - terminated = terminated.item() - truncated = truncated.item() - return obs, reward, terminated, truncated, info - - -class ManiSkillActionWrapper(gym.ActionWrapper): - def __init__(self, env): - super().__init__(env) - self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2))) 
- - def action(self, action): - action, telop = action - return action - - -class ManiSkillMultiplyActionWrapper(gym.Wrapper): - def __init__(self, env, multiply_factor: float = 1): - super().__init__(env) - self.multiply_factor = multiply_factor - action_space_agent: gym.spaces.Box = env.action_space[0] - action_space_agent.low = action_space_agent.low * multiply_factor - action_space_agent.high = action_space_agent.high * multiply_factor - self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2))) - - def step(self, action): - if isinstance(action, tuple): - action, telop = action - else: - telop = 0 - action = action / self.multiply_factor - obs, reward, terminated, truncated, info = self.env.step((action, telop)) - return obs, reward, terminated, truncated, info - - -class BatchCompatibleWrapper(gym.ObservationWrapper): - """Ensures observations are batch-compatible by adding a batch dimension if necessary.""" - - def __init__(self, env): - super().__init__(env) - - def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - for key in observation: - if "image" in key and observation[key].dim() == 3: - observation[key] = observation[key].unsqueeze(0) - if "state" in key and observation[key].dim() == 1: - observation[key] = observation[key].unsqueeze(0) - return observation - - -class TimeLimitWrapper(gym.Wrapper): - """Adds a time limit to the environment based on fps and control_time.""" - - def __init__(self, env, control_time_s, fps): - super().__init__(env) - self.control_time_s = control_time_s - self.fps = fps - self.max_episode_steps = int(self.control_time_s * self.fps) - self.current_step = 0 - - def step(self, action): - obs, reward, terminated, truncated, info = self.env.step(action) - self.current_step += 1 - - if self.current_step >= self.max_episode_steps: - terminated = True - - return obs, reward, terminated, truncated, info - - def reset(self, *, seed=None, options=None): - self.current_step = 0 - return super().reset(seed=seed, options=options) - - -class ManiskillMockGripperWrapper(gym.Wrapper): - def __init__(self, env, nb_discrete_actions: int = 3): - super().__init__(env) - new_shape = env.action_space[0].shape[0] + 1 - new_low = np.concatenate([env.action_space[0].low, [0]]) - new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]]) - action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,)) - self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1])) - - def step(self, action): - if isinstance(action, tuple): - action_agent, telop_action = action - else: - telop_action = 0 - action_agent = action - real_action = action_agent[:-1] - final_action = (real_action, telop_action) - obs, reward, terminated, truncated, info = self.env.step(final_action) - return obs, reward, terminated, truncated, info - - -def make_maniskill( - cfg: ManiskillEnvConfig, - n_envs: int | None = None, -) -> gym.Env: - """ - Factory function to create a ManiSkill environment with standard wrappers. 
- - Args: - cfg: Configuration for the ManiSkill environment - n_envs: Number of parallel environments - - Returns: - A wrapped ManiSkill environment - """ - env = gym.make( - cfg.task, - obs_mode=cfg.obs_type, - control_mode=cfg.control_mode, - render_mode=cfg.render_mode, - sensor_configs={"width": cfg.image_size, "height": cfg.image_size}, - num_envs=n_envs, - ) - - # Add video recording if enabled - if cfg.video_record.enabled: - env = RecordEpisode( - env, - output_dir=cfg.video_record.record_dir, - save_trajectory=True, - trajectory_name=cfg.video_record.trajectory_name, - save_video=True, - video_fps=30, - ) - - # Add observation and image processing - env = ManiSkillObservationWrapper(env, device=cfg.device) - env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False) - env._max_episode_steps = env.max_episode_steps = cfg.episode_length - env.unwrapped.metadata["render_fps"] = cfg.fps - - # Add compatibility wrappers - env = ManiSkillCompat(env) - env = ManiSkillActionWrapper(env) - env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control - if cfg.mock_gripper: - env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3) - - return env diff --git a/lerobot/scripts/server/network_utils.py b/lerobot/scripts/server/network_utils.py index 78b9e5db..b24c8f17 100644 --- a/lerobot/scripts/server/network_utils.py +++ b/lerobot/scripts/server/network_utils.py @@ -17,10 +17,14 @@ import io import logging +import pickle # nosec B403: Safe usage for internal serialization only from multiprocessing import Event, Queue from typing import Any +import torch + from lerobot.scripts.server import hilserl_pb2 +from lerobot.scripts.server.utils import Transition CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB @@ -60,7 +64,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "" logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") -def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): +def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore bytes_buffer = io.BytesIO() step = 0 @@ -93,3 +97,44 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p step = 0 logging.debug(f"{log_prefix} Queue updated") + + +def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: + """Convert model state dict to flat array for transmission""" + buffer = io.BytesIO() + + torch.save(state_dict, buffer) + + return buffer.getvalue() + + +def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + return torch.load(buffer) # nosec B614: Safe usage of torch.load + + +def python_object_to_bytes(python_object: Any) -> bytes: + return pickle.dumps(python_object) + + +def bytes_to_python_object(buffer: bytes) -> Any: + buffer = io.BytesIO(buffer) + buffer.seek(0) + obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load + # Add validation checks here + return obj + + +def bytes_to_transitions(buffer: bytes) -> list[Transition]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load + # Add validation checks here + return transitions + + +def transitions_to_bytes(transitions: list[Transition]) -> bytes: + buffer = io.BytesIO() + torch.save(transitions, buffer) + return buffer.getvalue() diff --git a/lerobot/scripts/server/utils.py b/lerobot/scripts/server/utils.py index 
2ce4e57f..a9486b6c 100644
--- a/lerobot/scripts/server/utils.py
+++ b/lerobot/scripts/server/utils.py
@@ -19,7 +19,9 @@ import logging
 import signal
 import sys
 from queue import Empty
+from typing import TypedDict

+import torch
 from torch.multiprocessing import Queue

 shutdown_event_counter = 0
@@ -71,3 +73,69 @@ def get_last_item_from_queue(queue: Queue):
     logging.debug(f"Drained {counter} items from queue")

     return item
+
+
+class Transition(TypedDict):
+    state: dict[str, torch.Tensor]
+    action: torch.Tensor
+    reward: float
+    next_state: dict[str, torch.Tensor]
+    done: bool
+    truncated: bool
+    complementary_info: dict[str, torch.Tensor | float | int] | None
+
+
+def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
+    device = torch.device(device)
+    non_blocking = device.type == "cuda"
+
+    # Move state tensors to device
+    transition["state"] = {
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
+    }
+
+    # Move action to device
+    transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
+
+    # Move reward, done and truncated if they are tensors
+    if isinstance(transition["reward"], torch.Tensor):
+        transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
+
+    if isinstance(transition["done"], torch.Tensor):
+        transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
+
+    if isinstance(transition["truncated"], torch.Tensor):
+        transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
+
+    # Move next_state tensors to device
+    transition["next_state"] = {
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
+    }
+
+    # Move complementary_info tensors if present
+    if transition.get("complementary_info") is not None:
+        for key, val in transition["complementary_info"].items():
+            if isinstance(val, torch.Tensor):
+                transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
+            elif isinstance(val, (int, float, bool)):
+                transition["complementary_info"][key] = torch.tensor(val, device=device)
+            else:
+                raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
+    return transition
+
+
+def move_state_dict_to_device(state_dict, device="cpu"):
+    """
+    Recursively move all tensors in a (potentially) nested
+    dict/list/tuple structure to the given device.
+    """
+    if isinstance(state_dict, torch.Tensor):
+        return state_dict.to(device)
+    elif isinstance(state_dict, dict):
+        return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
+    elif isinstance(state_dict, list):
+        return [move_state_dict_to_device(v, device=device) for v in state_dict]
+    elif isinstance(state_dict, tuple):
+        return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
+    else:
+        return state_dict
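Taken together, the relocated helpers cover the actor-to-learner transition path: serialize on the actor, deserialize and move to the training device on the learner. A minimal round-trip sketch (tensor shapes and observation keys are illustrative placeholders, not values from the patch):

    import torch

    from lerobot.scripts.server.network_utils import bytes_to_transitions, transitions_to_bytes
    from lerobot.scripts.server.utils import Transition, move_transition_to_device

    transition: Transition = {
        "state": {"observation.state": torch.zeros(6)},
        "action": torch.zeros(4),
        "reward": 0.0,
        "next_state": {"observation.state": torch.zeros(6)},
        "done": False,
        "truncated": False,
        "complementary_info": None,
    }
    payload = transitions_to_bytes([transition])   # actor side, before the gRPC send
    restored = bytes_to_transitions(payload)       # learner side, after the chunks are reassembled
    restored = [move_transition_to_device(t, device="cpu") for t in restored]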