WIP Add stretch

This commit is contained in:
Simon Alibert
2025-03-04 11:42:07 +01:00
parent e2d13ba7e4
commit 7ed7570b17
2 changed files with 83 additions and 99 deletions

View File

@@ -1,7 +1,10 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.common.cameras.configs import CameraConfig, IntelRealSenseCameraConfig, OpenCVCameraConfig from lerobot.common.cameras import CameraConfig
from lerobot.common.robots.config_abc import RobotConfig from lerobot.common.cameras.intel import RealSenseCameraConfig
from lerobot.common.cameras.opencv import OpenCVCameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("stretch") @RobotConfig.register_subclass("stretch")
@@ -12,6 +15,7 @@ class StretchRobotConfig(RobotConfig):
# the number of motors in your follower arms. # the number of motors in your follower arms.
max_relative_target: int | None = None max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field( cameras: dict[str, CameraConfig] = field(
default_factory=lambda: { default_factory=lambda: {
"navigation": OpenCVCameraConfig( "navigation": OpenCVCameraConfig(
@@ -21,14 +25,14 @@ class StretchRobotConfig(RobotConfig):
height=720, height=720,
rotation=-90, rotation=-90,
), ),
"head": IntelRealSenseCameraConfig( "head": RealSenseCameraConfig(
name="Intel RealSense D435I", name="Intel RealSense D435I",
fps=30, fps=30,
width=640, width=640,
height=480, height=480,
rotation=90, rotation=90,
), ),
"wrist": IntelRealSenseCameraConfig( "wrist": RealSenseCameraConfig(
name="Intel RealSense D405", name="Intel RealSense D405",
fps=30, fps=30,
width=640, width=640,

View File

@@ -15,33 +15,55 @@
# limitations under the License. # limitations under the License.
import time import time
from dataclasses import replace
import numpy as np
import torch import torch
from stretch_body.gamepad_teleop import GamePadTeleop from stretch_body.gamepad_teleop import GamePadTeleop
from stretch_body.robot import Robot as StretchAPI from stretch_body.robot import Robot as StretchAPI
from stretch_body.robot_params import RobotParams from stretch_body.robot_params import RobotParams
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.datasets.utils import get_nested_item
from ..robot import Robot
from .configuration_stretch3 import StretchRobotConfig from .configuration_stretch3 import StretchRobotConfig
# Maps LeRobot motor keys to their location in the Stretch API status dict:
# {lerobot_key: stretch_api_key}.
# Each value is a "."-separated path of nested status keys, resolved with
# `get_nested_item(status, path, sep=".")` when the robot state is read.
# The key order also defines the order of the "motors" names in `state_feature`.
STRETCH_MOTORS = {
"head_pan.pos": "head.head_pan.pos",
"head_tilt.pos": "head.head_tilt.pos",
"lift.pos": "lift.pos",
"arm.pos": "arm.pos",
"wrist_pitch.pos": "end_of_arm.wrist_pitch.pos",
"wrist_roll.pos": "end_of_arm.wrist_roll.pos",
"wrist_yaw.pos": "end_of_arm.wrist_yaw.pos",
"gripper.pos": "end_of_arm.stretch_gripper.pos",
"base_x.vel": "base.x_vel",
"base_y.vel": "base.y_vel",
"base_theta.vel": "base.theta_vel",
}
class StretchRobot(StretchAPI):
"""Wrapper of stretch_body.robot.Robot"""
def __init__(self, config: StretchRobotConfig | None = None, **kwargs): class Stretch3Robot(Robot):
super().__init__() """[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot."""
if config is None:
self.config = StretchRobotConfig(**kwargs)
else:
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
config_class = StretchRobotConfig
name = "stretch3"
def __init__(self, config: StretchRobotConfig):
super().__init__(config)
self.config = config
self.robot_type = self.config.type self.robot_type = self.config.type
self.cameras = self.config.cameras
self.api = StretchAPI()
self.cameras = make_cameras_from_configs(config.cameras)
self.is_connected = False self.is_connected = False
self.teleop = None
self.logs = {} self.logs = {}
self.teleop = None # TODO remove
# TODO(aliberts): test this # TODO(aliberts): test this
RobotParams.set_logging_level("WARNING") RobotParams.set_logging_level("WARNING")
RobotParams.set_logging_formatter("brief_console_formatter") RobotParams.set_logging_formatter("brief_console_formatter")
@@ -49,94 +71,58 @@ class StretchRobot(StretchAPI):
self.state_keys = None self.state_keys = None
self.action_keys = None self.action_keys = None
@property
def state_feature(self) -> dict:
    """Describe the robot state vector.

    Returns:
        A dict with the dtype ("float32"), a 1-D shape holding one entry per
        key of `STRETCH_MOTORS`, and the motor names in `STRETCH_MOTORS` order.
    """
    motor_names = [*STRETCH_MOTORS]
    return {
        "dtype": "float32",
        "shape": (len(motor_names),),
        "names": {"motors": motor_names},
    }
@property
def action_feature(self) -> dict:
    """Describe the action vector.

    Returns:
        Whatever `state_feature` returns: actions share the state description.
    """
    shared_spec = self.state_feature
    return shared_spec
@property
def camera_features(self) -> dict[str, dict]:
    """Describe the image produced by each camera.

    Returns:
        A dict with the same keys as `self.cameras`. Each value holds the
        image shape as (height, width, channels) read from the camera
        object, the matching axis names, and an "info" entry set to None.
    """
    return {
        camera_key: {
            "shape": (camera.height, camera.width, camera.channels),
            "names": ["height", "width", "channels"],
            "info": None,
        }
        for camera_key, camera in self.cameras.items()
    }
def connect(self) -> None: def connect(self) -> None:
self.is_connected = self.startup() self.is_connected = self.api.startup()
if not self.is_connected: if not self.is_connected:
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'") print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
raise ConnectionError() raise ConnectionError()
for name in self.cameras: for cam in self.cameras.values():
self.cameras[name].connect() cam.connect()
self.is_connected = self.is_connected and self.cameras[name].is_connected self.is_connected = self.is_connected and cam.is_connected
if not self.is_connected: if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.") print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError() raise ConnectionError()
self.run_calibration() self.calibrate()
def run_calibration(self) -> None: def calibrate(self) -> None:
if not self.is_homed(): if not self.api.is_homed():
self.home() self.api.home()
def teleop_step( def _get_state(self) -> dict:
self, record_data=False status = self.api.get_status()
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()}
# TODO(aliberts): return ndarrays instead of torch.Tensors
if not self.is_connected:
raise ConnectionError()
if self.teleop is None: def get_observation(self) -> dict[str, np.ndarray]:
self.teleop = GamePadTeleop(robot_instance=False) obs_dict = {}
self.teleop.startup(robot=self)
before_read_t = time.perf_counter() before_read_t = time.perf_counter()
state = self.get_state() state = self._get_state()
action = self.teleop.gamepad_controller.get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
before_write_t = time.perf_counter()
self.teleop.do_motion(robot=self)
self.push_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
if self.state_keys is None:
self.state_keys = list(state)
if not record_data:
return
state = torch.as_tensor(list(state.values()))
action = torch.as_tensor(list(action.values()))
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict, action_dict
def get_state(self) -> dict:
    """Flatten the nested status from `self.get_status()` into a flat dict.

    Returns:
        A dict keyed by "<joint>.pos" / "base_<axis>.vel" names, each holding
        the value found at the corresponding nested path of the status dict.

    Raises:
        KeyError: If an expected key is missing from the status dict.
    """
    status = self.get_status()
    # Output key -> sequence of nested keys to follow inside `status`.
    status_paths = {
        "head_pan.pos": ("head", "head_pan", "pos"),
        "head_tilt.pos": ("head", "head_tilt", "pos"),
        "lift.pos": ("lift", "pos"),
        "arm.pos": ("arm", "pos"),
        "wrist_pitch.pos": ("end_of_arm", "wrist_pitch", "pos"),
        "wrist_roll.pos": ("end_of_arm", "wrist_roll", "pos"),
        "wrist_yaw.pos": ("end_of_arm", "wrist_yaw", "pos"),
        "gripper.pos": ("end_of_arm", "stretch_gripper", "pos"),
        "base_x.vel": ("base", "x_vel"),
        "base_y.vel": ("base", "y_vel"),
        "base_theta.vel": ("base", "theta_vel"),
    }
    flat_state = {}
    for state_key, path in status_paths.items():
        node = status
        for part in path:
            node = node[part]
        flat_state[state_key] = node
    return flat_state
def capture_observation(self) -> dict:
# TODO(aliberts): return ndarrays instead of torch.Tensors
before_read_t = time.perf_counter()
state = self.get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
if self.state_keys is None: if self.state_keys is None:
@@ -149,7 +135,6 @@ class StretchRobot(StretchAPI):
for name in self.cameras: for name in self.cameras:
before_camread_t = time.perf_counter() before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
@@ -161,8 +146,7 @@ class StretchRobot(StretchAPI):
return obs_dict return obs_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor: def send_action(self, action: np.ndarray) -> np.ndarray:
# TODO(aliberts): return ndarrays instead of torch.Tensors
if not self.is_connected: if not self.is_connected:
raise ConnectionError() raise ConnectionError()
@@ -193,16 +177,12 @@ class StretchRobot(StretchAPI):
self.teleop._safety_stop(robot=self) self.teleop._safety_stop(robot=self)
def disconnect(self) -> None: def disconnect(self) -> None:
self.stop() self.api.stop()
if self.teleop is not None: if self.teleop is not None:
self.teleop.gamepad_controller.stop() self.teleop.gamepad_controller.stop()
self.teleop.stop() self.teleop.stop()
if len(self.cameras) > 0: for cam in self.cameras.values():
for cam in self.cameras.values(): cam.disconnect()
cam.disconnect()
self.is_connected = False self.is_connected = False
def __del__(self):
    """Finalizer: call `disconnect()` when the object is being destroyed."""
    self.disconnect()