WIP Add stretch

This commit is contained in:
Simon Alibert
2025-03-04 11:42:07 +01:00
parent e2d13ba7e4
commit 7ed7570b17
2 changed files with 83 additions and 99 deletions

View File

@@ -1,7 +1,10 @@
from dataclasses import dataclass, field
from lerobot.common.cameras.configs import CameraConfig, IntelRealSenseCameraConfig, OpenCVCameraConfig
from lerobot.common.robots.config_abc import RobotConfig
from lerobot.common.cameras import CameraConfig
from lerobot.common.cameras.intel import RealSenseCameraConfig
from lerobot.common.cameras.opencv import OpenCVCameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("stretch")
@@ -12,6 +15,7 @@ class StretchRobotConfig(RobotConfig):
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"navigation": OpenCVCameraConfig(
@@ -21,14 +25,14 @@ class StretchRobotConfig(RobotConfig):
height=720,
rotation=-90,
),
"head": IntelRealSenseCameraConfig(
"head": RealSenseCameraConfig(
name="Intel RealSense D435I",
fps=30,
width=640,
height=480,
rotation=90,
),
"wrist": IntelRealSenseCameraConfig(
"wrist": RealSenseCameraConfig(
name="Intel RealSense D405",
fps=30,
width=640,

View File

@@ -15,33 +15,55 @@
# limitations under the License.
import time
from dataclasses import replace
import numpy as np
import torch
from stretch_body.gamepad_teleop import GamePadTeleop
from stretch_body.robot import Robot as StretchAPI
from stretch_body.robot_params import RobotParams
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.datasets.utils import get_nested_item
from ..robot import Robot
from .configuration_stretch3 import StretchRobotConfig
# {lerobot_keys: stretch.api.keys}
STRETCH_MOTORS = {
"head_pan.pos": "head.head_pan.pos",
"head_tilt.pos": "head.head_tilt.pos",
"lift.pos": "lift.pos",
"arm.pos": "arm.pos",
"wrist_pitch.pos": "end_of_arm.wrist_pitch.pos",
"wrist_roll.pos": "end_of_arm.wrist_roll.pos",
"wrist_yaw.pos": "end_of_arm.wrist_yaw.pos",
"gripper.pos": "end_of_arm.stretch_gripper.pos",
"base_x.vel": "base.x_vel",
"base_y.vel": "base.y_vel",
"base_theta.vel": "base.theta_vel",
}
class StretchRobot(StretchAPI):
"""Wrapper of stretch_body.robot.Robot"""
def __init__(self, config: StretchRobotConfig | None = None, **kwargs):
super().__init__()
if config is None:
self.config = StretchRobotConfig(**kwargs)
else:
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
class Stretch3Robot(Robot):
"""[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot."""
config_class = StretchRobotConfig
name = "stretch3"
def __init__(self, config: StretchRobotConfig):
super().__init__(config)
self.config = config
self.robot_type = self.config.type
self.cameras = self.config.cameras
self.api = StretchAPI()
self.cameras = make_cameras_from_configs(config.cameras)
self.is_connected = False
self.teleop = None
self.logs = {}
self.teleop = None # TODO remove
# TODO(aliberts): test this
RobotParams.set_logging_level("WARNING")
RobotParams.set_logging_formatter("brief_console_formatter")
@@ -49,94 +71,58 @@ class StretchRobot(StretchAPI):
self.state_keys = None
self.action_keys = None
@property
def state_feature(self) -> dict:
return {
"dtype": "float32",
"shape": (len(STRETCH_MOTORS),),
"names": {"motors": list(STRETCH_MOTORS)},
}
@property
def action_feature(self) -> dict:
return self.state_feature
@property
def camera_features(self) -> dict[str, dict]:
cam_ft = {}
for cam_key, cam in self.cameras.items():
cam_ft[cam_key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
def connect(self) -> None:
self.is_connected = self.startup()
self.is_connected = self.api.startup()
if not self.is_connected:
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
raise ConnectionError()
for name in self.cameras:
self.cameras[name].connect()
self.is_connected = self.is_connected and self.cameras[name].is_connected
for cam in self.cameras.values():
cam.connect()
self.is_connected = self.is_connected and cam.is_connected
if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()
self.run_calibration()
self.calibrate()
def run_calibration(self) -> None:
if not self.is_homed():
self.home()
def calibrate(self) -> None:
if not self.api.is_homed():
self.api.home()
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
# TODO(aliberts): return ndarrays instead of torch.Tensors
if not self.is_connected:
raise ConnectionError()
def _get_state(self) -> dict:
status = self.api.get_status()
return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()}
if self.teleop is None:
self.teleop = GamePadTeleop(robot_instance=False)
self.teleop.startup(robot=self)
def get_observation(self) -> dict[str, np.ndarray]:
obs_dict = {}
before_read_t = time.perf_counter()
state = self.get_state()
action = self.teleop.gamepad_controller.get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
before_write_t = time.perf_counter()
self.teleop.do_motion(robot=self)
self.push_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
if self.state_keys is None:
self.state_keys = list(state)
if not record_data:
return
state = torch.as_tensor(list(state.values()))
action = torch.as_tensor(list(action.values()))
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict, action_dict
def get_state(self) -> dict:
status = self.get_status()
return {
"head_pan.pos": status["head"]["head_pan"]["pos"],
"head_tilt.pos": status["head"]["head_tilt"]["pos"],
"lift.pos": status["lift"]["pos"],
"arm.pos": status["arm"]["pos"],
"wrist_pitch.pos": status["end_of_arm"]["wrist_pitch"]["pos"],
"wrist_roll.pos": status["end_of_arm"]["wrist_roll"]["pos"],
"wrist_yaw.pos": status["end_of_arm"]["wrist_yaw"]["pos"],
"gripper.pos": status["end_of_arm"]["stretch_gripper"]["pos"],
"base_x.vel": status["base"]["x_vel"],
"base_y.vel": status["base"]["y_vel"],
"base_theta.vel": status["base"]["theta_vel"],
}
def capture_observation(self) -> dict:
# TODO(aliberts): return ndarrays instead of torch.Tensors
before_read_t = time.perf_counter()
state = self.get_state()
state = self._get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
if self.state_keys is None:
@@ -149,7 +135,6 @@ class StretchRobot(StretchAPI):
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
@@ -161,8 +146,7 @@ class StretchRobot(StretchAPI):
return obs_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor:
# TODO(aliberts): return ndarrays instead of torch.Tensors
def send_action(self, action: np.ndarray) -> np.ndarray:
if not self.is_connected:
raise ConnectionError()
@@ -193,16 +177,12 @@ class StretchRobot(StretchAPI):
self.teleop._safety_stop(robot=self)
def disconnect(self) -> None:
self.stop()
self.api.stop()
if self.teleop is not None:
self.teleop.gamepad_controller.stop()
self.teleop.stop()
if len(self.cameras) > 0:
for cam in self.cameras.values():
cam.disconnect()
for cam in self.cameras.values():
cam.disconnect()
self.is_connected = False
def __del__(self):
self.disconnect()