From 7ed7570b17b3d896c490a1408efff9810318560f Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 4 Mar 2025 11:42:07 +0100 Subject: [PATCH] WIP Add stretch --- .../robots/stretch3/configuration_stretch3.py | 12 +- .../common/robots/stretch3/robot_stretch3.py | 170 ++++++++---------- 2 files changed, 83 insertions(+), 99 deletions(-) diff --git a/lerobot/common/robots/stretch3/configuration_stretch3.py b/lerobot/common/robots/stretch3/configuration_stretch3.py index 8520c471..79d54670 100644 --- a/lerobot/common/robots/stretch3/configuration_stretch3.py +++ b/lerobot/common/robots/stretch3/configuration_stretch3.py @@ -1,7 +1,10 @@ from dataclasses import dataclass, field -from lerobot.common.cameras.configs import CameraConfig, IntelRealSenseCameraConfig, OpenCVCameraConfig -from lerobot.common.robots.config_abc import RobotConfig +from lerobot.common.cameras import CameraConfig +from lerobot.common.cameras.intel import RealSenseCameraConfig +from lerobot.common.cameras.opencv import OpenCVCameraConfig + +from ..config import RobotConfig @RobotConfig.register_subclass("stretch") @@ -12,6 +15,7 @@ class StretchRobotConfig(RobotConfig): # the number of motors in your follower arms. max_relative_target: int | None = None + # cameras cameras: dict[str, CameraConfig] = field( default_factory=lambda: { "navigation": OpenCVCameraConfig( @@ -21,14 +25,14 @@ class StretchRobotConfig(RobotConfig): height=720, rotation=-90, ), - "head": IntelRealSenseCameraConfig( + "head": RealSenseCameraConfig( name="Intel RealSense D435I", fps=30, width=640, height=480, rotation=90, ), - "wrist": IntelRealSenseCameraConfig( + "wrist": RealSenseCameraConfig( name="Intel RealSense D405", fps=30, width=640, diff --git a/lerobot/common/robots/stretch3/robot_stretch3.py b/lerobot/common/robots/stretch3/robot_stretch3.py index c3be2d9c..ffbd6078 100644 --- a/lerobot/common/robots/stretch3/robot_stretch3.py +++ b/lerobot/common/robots/stretch3/robot_stretch3.py @@ -15,33 +15,55 @@ # limitations under the License. import time -from dataclasses import replace +import numpy as np import torch from stretch_body.gamepad_teleop import GamePadTeleop from stretch_body.robot import Robot as StretchAPI from stretch_body.robot_params import RobotParams +from lerobot.common.cameras.utils import make_cameras_from_configs +from lerobot.common.datasets.utils import get_nested_item + +from ..robot import Robot from .configuration_stretch3 import StretchRobotConfig +# {lerobot_keys: stretch.api.keys} +STRETCH_MOTORS = { + "head_pan.pos": "head.head_pan.pos", + "head_tilt.pos": "head.head_tilt.pos", + "lift.pos": "lift.pos", + "arm.pos": "arm.pos", + "wrist_pitch.pos": "end_of_arm.wrist_pitch.pos", + "wrist_roll.pos": "end_of_arm.wrist_roll.pos", + "wrist_yaw.pos": "end_of_arm.wrist_yaw.pos", + "gripper.pos": "end_of_arm.stretch_gripper.pos", + "base_x.vel": "base.x_vel", + "base_y.vel": "base.y_vel", + "base_theta.vel": "base.theta_vel", +} -class StretchRobot(StretchAPI): - """Wrapper of stretch_body.robot.Robot""" - def __init__(self, config: StretchRobotConfig | None = None, **kwargs): - super().__init__() - if config is None: - self.config = StretchRobotConfig(**kwargs) - else: - # Overwrite config arguments using kwargs - self.config = replace(config, **kwargs) +class Stretch3Robot(Robot): + """[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot.""" + config_class = StretchRobotConfig + name = "stretch3" + + def __init__(self, config: StretchRobotConfig): + super().__init__(config) + + self.config = config self.robot_type = self.config.type - self.cameras = self.config.cameras + + self.api = StretchAPI() + self.cameras = make_cameras_from_configs(config.cameras) + self.is_connected = False - self.teleop = None self.logs = {} + self.teleop = None # TODO remove + # TODO(aliberts): test this RobotParams.set_logging_level("WARNING") RobotParams.set_logging_formatter("brief_console_formatter") @@ -49,94 +71,58 @@ class StretchRobot(StretchAPI): self.state_keys = None self.action_keys = None + @property + def state_feature(self) -> dict: + return { + "dtype": "float32", + "shape": (len(STRETCH_MOTORS),), + "names": {"motors": list(STRETCH_MOTORS)}, + } + + @property + def action_feature(self) -> dict: + return self.state_feature + + @property + def camera_features(self) -> dict[str, dict]: + cam_ft = {} + for cam_key, cam in self.cameras.items(): + cam_ft[cam_key] = { + "shape": (cam.height, cam.width, cam.channels), + "names": ["height", "width", "channels"], + "info": None, + } + return cam_ft + def connect(self) -> None: - self.is_connected = self.startup() + self.is_connected = self.api.startup() if not self.is_connected: print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'") raise ConnectionError() - for name in self.cameras: - self.cameras[name].connect() - self.is_connected = self.is_connected and self.cameras[name].is_connected + for cam in self.cameras.values(): + cam.connect() + self.is_connected = self.is_connected and cam.is_connected if not self.is_connected: print("Could not connect to the cameras, check that all cameras are plugged-in.") raise ConnectionError() - self.run_calibration() + self.calibrate() - def run_calibration(self) -> None: - if not self.is_homed(): - self.home() + def calibrate(self) -> None: + if not self.api.is_homed(): + self.api.home() - def teleop_step( - self, record_data=False - ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - # TODO(aliberts): return ndarrays instead of torch.Tensors - if not self.is_connected: - raise ConnectionError() + def _get_state(self) -> dict: + status = self.api.get_status() + return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()} - if self.teleop is None: - self.teleop = GamePadTeleop(robot_instance=False) - self.teleop.startup(robot=self) + def get_observation(self) -> dict[str, np.ndarray]: + obs_dict = {} before_read_t = time.perf_counter() - state = self.get_state() - action = self.teleop.gamepad_controller.get_state() - self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t - - before_write_t = time.perf_counter() - self.teleop.do_motion(robot=self) - self.push_command() - self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t - - if self.state_keys is None: - self.state_keys = list(state) - - if not record_data: - return - - state = torch.as_tensor(list(state.values())) - action = torch.as_tensor(list(action.values())) - - # Capture images from cameras - images = {} - for name in self.cameras: - before_camread_t = time.perf_counter() - images[name] = self.cameras[name].async_read() - images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t - - # Populate output dictionaries - obs_dict, action_dict = {}, {} - obs_dict["observation.state"] = state - action_dict["action"] = action - for name in self.cameras: - obs_dict[f"observation.images.{name}"] = images[name] - - return obs_dict, action_dict - - def get_state(self) -> dict: - status = self.get_status() - return { - "head_pan.pos": status["head"]["head_pan"]["pos"], - "head_tilt.pos": status["head"]["head_tilt"]["pos"], - "lift.pos": status["lift"]["pos"], - "arm.pos": status["arm"]["pos"], - "wrist_pitch.pos": status["end_of_arm"]["wrist_pitch"]["pos"], - "wrist_roll.pos": status["end_of_arm"]["wrist_roll"]["pos"], - "wrist_yaw.pos": status["end_of_arm"]["wrist_yaw"]["pos"], - "gripper.pos": status["end_of_arm"]["stretch_gripper"]["pos"], - "base_x.vel": status["base"]["x_vel"], - "base_y.vel": status["base"]["y_vel"], - "base_theta.vel": status["base"]["theta_vel"], - } - - def capture_observation(self) -> dict: - # TODO(aliberts): return ndarrays instead of torch.Tensors - before_read_t = time.perf_counter() - state = self.get_state() + state = self._get_state() self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t if self.state_keys is None: @@ -149,7 +135,6 @@ class StretchRobot(StretchAPI): for name in self.cameras: before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() - images[name] = torch.from_numpy(images[name]) self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t @@ -161,8 +146,7 @@ class StretchRobot(StretchAPI): return obs_dict - def send_action(self, action: torch.Tensor) -> torch.Tensor: - # TODO(aliberts): return ndarrays instead of torch.Tensors + def send_action(self, action: np.ndarray) -> np.ndarray: if not self.is_connected: raise ConnectionError() @@ -193,16 +177,12 @@ class StretchRobot(StretchAPI): self.teleop._safety_stop(robot=self) def disconnect(self) -> None: - self.stop() + self.api.stop() if self.teleop is not None: self.teleop.gamepad_controller.stop() self.teleop.stop() - if len(self.cameras) > 0: - for cam in self.cameras.values(): - cam.disconnect() + for cam in self.cameras.values(): + cam.disconnect() self.is_connected = False - - def __del__(self): - self.disconnect()