forked from tangger/lerobot
WIP Add stretch
This commit is contained in:
@@ -1,7 +1,10 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from lerobot.common.cameras.configs import CameraConfig, IntelRealSenseCameraConfig, OpenCVCameraConfig
|
from lerobot.common.cameras import CameraConfig
|
||||||
from lerobot.common.robots.config_abc import RobotConfig
|
from lerobot.common.cameras.intel import RealSenseCameraConfig
|
||||||
|
from lerobot.common.cameras.opencv import OpenCVCameraConfig
|
||||||
|
|
||||||
|
from ..config import RobotConfig
|
||||||
|
|
||||||
|
|
||||||
@RobotConfig.register_subclass("stretch")
|
@RobotConfig.register_subclass("stretch")
|
||||||
@@ -12,6 +15,7 @@ class StretchRobotConfig(RobotConfig):
|
|||||||
# the number of motors in your follower arms.
|
# the number of motors in your follower arms.
|
||||||
max_relative_target: int | None = None
|
max_relative_target: int | None = None
|
||||||
|
|
||||||
|
# cameras
|
||||||
cameras: dict[str, CameraConfig] = field(
|
cameras: dict[str, CameraConfig] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"navigation": OpenCVCameraConfig(
|
"navigation": OpenCVCameraConfig(
|
||||||
@@ -21,14 +25,14 @@ class StretchRobotConfig(RobotConfig):
|
|||||||
height=720,
|
height=720,
|
||||||
rotation=-90,
|
rotation=-90,
|
||||||
),
|
),
|
||||||
"head": IntelRealSenseCameraConfig(
|
"head": RealSenseCameraConfig(
|
||||||
name="Intel RealSense D435I",
|
name="Intel RealSense D435I",
|
||||||
fps=30,
|
fps=30,
|
||||||
width=640,
|
width=640,
|
||||||
height=480,
|
height=480,
|
||||||
rotation=90,
|
rotation=90,
|
||||||
),
|
),
|
||||||
"wrist": IntelRealSenseCameraConfig(
|
"wrist": RealSenseCameraConfig(
|
||||||
name="Intel RealSense D405",
|
name="Intel RealSense D405",
|
||||||
fps=30,
|
fps=30,
|
||||||
width=640,
|
width=640,
|
||||||
|
|||||||
@@ -15,33 +15,55 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from dataclasses import replace
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from stretch_body.gamepad_teleop import GamePadTeleop
|
from stretch_body.gamepad_teleop import GamePadTeleop
|
||||||
from stretch_body.robot import Robot as StretchAPI
|
from stretch_body.robot import Robot as StretchAPI
|
||||||
from stretch_body.robot_params import RobotParams
|
from stretch_body.robot_params import RobotParams
|
||||||
|
|
||||||
|
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.common.datasets.utils import get_nested_item
|
||||||
|
|
||||||
|
from ..robot import Robot
|
||||||
from .configuration_stretch3 import StretchRobotConfig
|
from .configuration_stretch3 import StretchRobotConfig
|
||||||
|
|
||||||
|
# {lerobot_keys: stretch.api.keys}
|
||||||
|
STRETCH_MOTORS = {
|
||||||
|
"head_pan.pos": "head.head_pan.pos",
|
||||||
|
"head_tilt.pos": "head.head_tilt.pos",
|
||||||
|
"lift.pos": "lift.pos",
|
||||||
|
"arm.pos": "arm.pos",
|
||||||
|
"wrist_pitch.pos": "end_of_arm.wrist_pitch.pos",
|
||||||
|
"wrist_roll.pos": "end_of_arm.wrist_roll.pos",
|
||||||
|
"wrist_yaw.pos": "end_of_arm.wrist_yaw.pos",
|
||||||
|
"gripper.pos": "end_of_arm.stretch_gripper.pos",
|
||||||
|
"base_x.vel": "base.x_vel",
|
||||||
|
"base_y.vel": "base.y_vel",
|
||||||
|
"base_theta.vel": "base.theta_vel",
|
||||||
|
}
|
||||||
|
|
||||||
class StretchRobot(StretchAPI):
|
|
||||||
"""Wrapper of stretch_body.robot.Robot"""
|
|
||||||
|
|
||||||
def __init__(self, config: StretchRobotConfig | None = None, **kwargs):
|
class Stretch3Robot(Robot):
|
||||||
super().__init__()
|
"""[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot."""
|
||||||
if config is None:
|
|
||||||
self.config = StretchRobotConfig(**kwargs)
|
|
||||||
else:
|
|
||||||
# Overwrite config arguments using kwargs
|
|
||||||
self.config = replace(config, **kwargs)
|
|
||||||
|
|
||||||
|
config_class = StretchRobotConfig
|
||||||
|
name = "stretch3"
|
||||||
|
|
||||||
|
def __init__(self, config: StretchRobotConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
self.robot_type = self.config.type
|
self.robot_type = self.config.type
|
||||||
self.cameras = self.config.cameras
|
|
||||||
|
self.api = StretchAPI()
|
||||||
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
|
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
self.teleop = None
|
|
||||||
self.logs = {}
|
self.logs = {}
|
||||||
|
|
||||||
|
self.teleop = None # TODO remove
|
||||||
|
|
||||||
# TODO(aliberts): test this
|
# TODO(aliberts): test this
|
||||||
RobotParams.set_logging_level("WARNING")
|
RobotParams.set_logging_level("WARNING")
|
||||||
RobotParams.set_logging_formatter("brief_console_formatter")
|
RobotParams.set_logging_formatter("brief_console_formatter")
|
||||||
@@ -49,94 +71,58 @@ class StretchRobot(StretchAPI):
|
|||||||
self.state_keys = None
|
self.state_keys = None
|
||||||
self.action_keys = None
|
self.action_keys = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_feature(self) -> dict:
|
||||||
|
return {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (len(STRETCH_MOTORS),),
|
||||||
|
"names": {"motors": list(STRETCH_MOTORS)},
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_feature(self) -> dict:
|
||||||
|
return self.state_feature
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_features(self) -> dict[str, dict]:
|
||||||
|
cam_ft = {}
|
||||||
|
for cam_key, cam in self.cameras.items():
|
||||||
|
cam_ft[cam_key] = {
|
||||||
|
"shape": (cam.height, cam.width, cam.channels),
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": None,
|
||||||
|
}
|
||||||
|
return cam_ft
|
||||||
|
|
||||||
def connect(self) -> None:
|
def connect(self) -> None:
|
||||||
self.is_connected = self.startup()
|
self.is_connected = self.api.startup()
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
|
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
|
||||||
raise ConnectionError()
|
raise ConnectionError()
|
||||||
|
|
||||||
for name in self.cameras:
|
for cam in self.cameras.values():
|
||||||
self.cameras[name].connect()
|
cam.connect()
|
||||||
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
self.is_connected = self.is_connected and cam.is_connected
|
||||||
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||||
raise ConnectionError()
|
raise ConnectionError()
|
||||||
|
|
||||||
self.run_calibration()
|
self.calibrate()
|
||||||
|
|
||||||
def run_calibration(self) -> None:
|
def calibrate(self) -> None:
|
||||||
if not self.is_homed():
|
if not self.api.is_homed():
|
||||||
self.home()
|
self.api.home()
|
||||||
|
|
||||||
def teleop_step(
|
def _get_state(self) -> dict:
|
||||||
self, record_data=False
|
status = self.api.get_status()
|
||||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()}
|
||||||
# TODO(aliberts): return ndarrays instead of torch.Tensors
|
|
||||||
if not self.is_connected:
|
|
||||||
raise ConnectionError()
|
|
||||||
|
|
||||||
if self.teleop is None:
|
def get_observation(self) -> dict[str, np.ndarray]:
|
||||||
self.teleop = GamePadTeleop(robot_instance=False)
|
obs_dict = {}
|
||||||
self.teleop.startup(robot=self)
|
|
||||||
|
|
||||||
before_read_t = time.perf_counter()
|
before_read_t = time.perf_counter()
|
||||||
state = self.get_state()
|
state = self._get_state()
|
||||||
action = self.teleop.gamepad_controller.get_state()
|
|
||||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
|
||||||
|
|
||||||
before_write_t = time.perf_counter()
|
|
||||||
self.teleop.do_motion(robot=self)
|
|
||||||
self.push_command()
|
|
||||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
|
||||||
|
|
||||||
if self.state_keys is None:
|
|
||||||
self.state_keys = list(state)
|
|
||||||
|
|
||||||
if not record_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
state = torch.as_tensor(list(state.values()))
|
|
||||||
action = torch.as_tensor(list(action.values()))
|
|
||||||
|
|
||||||
# Capture images from cameras
|
|
||||||
images = {}
|
|
||||||
for name in self.cameras:
|
|
||||||
before_camread_t = time.perf_counter()
|
|
||||||
images[name] = self.cameras[name].async_read()
|
|
||||||
images[name] = torch.from_numpy(images[name])
|
|
||||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
|
||||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
|
||||||
|
|
||||||
# Populate output dictionaries
|
|
||||||
obs_dict, action_dict = {}, {}
|
|
||||||
obs_dict["observation.state"] = state
|
|
||||||
action_dict["action"] = action
|
|
||||||
for name in self.cameras:
|
|
||||||
obs_dict[f"observation.images.{name}"] = images[name]
|
|
||||||
|
|
||||||
return obs_dict, action_dict
|
|
||||||
|
|
||||||
def get_state(self) -> dict:
|
|
||||||
status = self.get_status()
|
|
||||||
return {
|
|
||||||
"head_pan.pos": status["head"]["head_pan"]["pos"],
|
|
||||||
"head_tilt.pos": status["head"]["head_tilt"]["pos"],
|
|
||||||
"lift.pos": status["lift"]["pos"],
|
|
||||||
"arm.pos": status["arm"]["pos"],
|
|
||||||
"wrist_pitch.pos": status["end_of_arm"]["wrist_pitch"]["pos"],
|
|
||||||
"wrist_roll.pos": status["end_of_arm"]["wrist_roll"]["pos"],
|
|
||||||
"wrist_yaw.pos": status["end_of_arm"]["wrist_yaw"]["pos"],
|
|
||||||
"gripper.pos": status["end_of_arm"]["stretch_gripper"]["pos"],
|
|
||||||
"base_x.vel": status["base"]["x_vel"],
|
|
||||||
"base_y.vel": status["base"]["y_vel"],
|
|
||||||
"base_theta.vel": status["base"]["theta_vel"],
|
|
||||||
}
|
|
||||||
|
|
||||||
def capture_observation(self) -> dict:
|
|
||||||
# TODO(aliberts): return ndarrays instead of torch.Tensors
|
|
||||||
before_read_t = time.perf_counter()
|
|
||||||
state = self.get_state()
|
|
||||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||||
|
|
||||||
if self.state_keys is None:
|
if self.state_keys is None:
|
||||||
@@ -149,7 +135,6 @@ class StretchRobot(StretchAPI):
|
|||||||
for name in self.cameras:
|
for name in self.cameras:
|
||||||
before_camread_t = time.perf_counter()
|
before_camread_t = time.perf_counter()
|
||||||
images[name] = self.cameras[name].async_read()
|
images[name] = self.cameras[name].async_read()
|
||||||
images[name] = torch.from_numpy(images[name])
|
|
||||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||||
|
|
||||||
@@ -161,8 +146,7 @@ class StretchRobot(StretchAPI):
|
|||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
def send_action(self, action: np.ndarray) -> np.ndarray:
|
||||||
# TODO(aliberts): return ndarrays instead of torch.Tensors
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise ConnectionError()
|
raise ConnectionError()
|
||||||
|
|
||||||
@@ -193,16 +177,12 @@ class StretchRobot(StretchAPI):
|
|||||||
self.teleop._safety_stop(robot=self)
|
self.teleop._safety_stop(robot=self)
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
def disconnect(self) -> None:
|
||||||
self.stop()
|
self.api.stop()
|
||||||
if self.teleop is not None:
|
if self.teleop is not None:
|
||||||
self.teleop.gamepad_controller.stop()
|
self.teleop.gamepad_controller.stop()
|
||||||
self.teleop.stop()
|
self.teleop.stop()
|
||||||
|
|
||||||
if len(self.cameras) > 0:
|
for cam in self.cameras.values():
|
||||||
for cam in self.cameras.values():
|
cam.disconnect()
|
||||||
cam.disconnect()
|
|
||||||
|
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
self.disconnect()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user