diff --git a/lerobot/common/teleoperators/stretch3/teleop_stretch3.py b/lerobot/common/teleoperators/stretch3/teleop_stretch3.py index adab08428..b23bd4027 100644 --- a/lerobot/common/teleoperators/stretch3/teleop_stretch3.py +++ b/lerobot/common/teleoperators/stretch3/teleop_stretch3.py @@ -20,26 +20,35 @@ import numpy as np from stretch_body.gamepad_teleop import GamePadTeleop from stretch_body.robot_params import RobotParams -from lerobot.common.constants import OBS_IMAGES, OBS_STATE -from lerobot.common.datasets.utils import get_nested_item +from lerobot.common.errors import DeviceAlreadyConnectedError from ..teleoperator import Teleoperator from .configuration_stretch3 import Stretch3GamePadConfig -# {lerobot_keys: stretch.api.keys} -STRETCH_MOTORS = { - "head_pan.pos": "head.head_pan.pos", - "head_tilt.pos": "head.head_tilt.pos", - "lift.pos": "lift.pos", - "arm.pos": "arm.pos", - "wrist_pitch.pos": "end_of_arm.wrist_pitch.pos", - "wrist_roll.pos": "end_of_arm.wrist_roll.pos", - "wrist_yaw.pos": "end_of_arm.wrist_yaw.pos", - "gripper.pos": "end_of_arm.stretch_gripper.pos", - "base_x.vel": "base.x_vel", - "base_y.vel": "base.y_vel", - "base_theta.vel": "base.theta_vel", -} +# from stretch_body.gamepad_controller.GamePadController +GAMEPAD_BUTTONS = [ + "middle_led_ring_button_pressed", + "left_stick_x", + "left_stick_y", + "right_stick_x", + "right_stick_y", + "left_stick_button_pressed", + "right_stick_button_pressed", + "bottom_button_pressed", + "top_button_pressed", + "left_button_pressed", + "right_button_pressed", + "left_shoulder_button_pressed", + "right_shoulder_button_pressed", + "select_button_pressed", + "start_button_pressed", + "left_trigger_pulled", + "right_trigger_pulled", + "bottom_pad_pressed", + "top_pad_pressed", + "left_pad_pressed", + "right_pad_pressed", +] class Stretch3GamePad(Teleoperator): @@ -59,122 +68,53 @@ class Stretch3GamePad(Teleoperator): self.is_connected = False self.logs = {} - self.teleop = None # TODO remove - # TODO(aliberts): test this RobotParams.set_logging_level("WARNING") RobotParams.set_logging_formatter("brief_console_formatter") - self.state_keys = None - self.action_keys = None - @property - def state_feature(self) -> dict: + def action_feature(self) -> dict: return { "dtype": "float32", - "shape": (len(STRETCH_MOTORS),), - "names": {"motors": list(STRETCH_MOTORS)}, + "shape": (len(GAMEPAD_BUTTONS),), + "names": {"buttons": GAMEPAD_BUTTONS}, } @property - def action_feature(self) -> dict: - return self.state_feature - - @property - def camera_features(self) -> dict[str, dict]: - cam_ft = {} - for cam_key, cam in self.cameras.items(): - cam_ft[cam_key] = { - "shape": (cam.height, cam.width, cam.channels), - "names": ["height", "width", "channels"], - "info": None, - } - return cam_ft + def feedback_feature(self) -> dict: + return {} def connect(self) -> None: - self.is_connected = self.api.startup() - if not self.is_connected: - print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'") - raise ConnectionError() + if self.is_connected: + raise DeviceAlreadyConnectedError( + "ManipulatorRobot is already connected. Do not run `robot.connect()` twice." + ) - for cam in self.cameras.values(): - cam.connect() - self.is_connected = self.is_connected and cam.is_connected - - if not self.is_connected: - print("Could not connect to the cameras, check that all cameras are plugged-in.") - raise ConnectionError() - - self.calibrate() + self.api.startup() + self.api._update_state() # Check controller can be read & written + self.api._update_modes() + self.is_connected = True def calibrate(self) -> None: - if not self.api.is_homed(): - self.api.home() - - def _get_state(self) -> dict: - status = self.api.get_status() - return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()} - - def get_observation(self) -> dict[str, np.ndarray]: - obs_dict = {} + pass + def get_action(self) -> np.ndarray: # Read Stretch state before_read_t = time.perf_counter() - state = self._get_state() + action = self.api.gamepad_controller.get_state() self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t - if self.state_keys is None: - self.state_keys = list(state) + action = np.asarray(list(action.values())) - state = np.asarray(list(state.values())) - obs_dict[OBS_STATE] = state - - # Capture images from cameras - for cam_key, cam in self.cameras.items(): - before_camread_t = time.perf_counter() - obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read() - self.logs[f"read_camera_{cam_key}_dt_s"] = cam.logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t - - return obs_dict - - def send_action(self, action: np.ndarray) -> np.ndarray: - if not self.is_connected: - raise ConnectionError() - - if self.teleop is None: - self.teleop = GamePadTeleop(robot_instance=False) - self.teleop.startup(robot=self) - - if self.action_keys is None: - dummy_action = self.teleop.gamepad_controller.get_state() - self.action_keys = list(dummy_action.keys()) - - action_dict = dict(zip(self.action_keys, action.tolist(), strict=True)) - - before_write_t = time.perf_counter() - self.teleop.do_motion(state=action_dict, robot=self) - self.push_command() - self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t - - # TODO(aliberts): return action_sent when motion is limited return action + def send_feedback(self, feedback: np.ndarray) -> None: + pass + def print_logs(self) -> None: pass # TODO(aliberts): move robot-specific logs logic here - def teleop_safety_stop(self) -> None: - if self.teleop is not None: - self.teleop._safety_stop(robot=self) - def disconnect(self) -> None: self.api.stop() - if self.teleop is not None: - self.teleop.gamepad_controller.stop() - self.teleop.stop() - - for cam in self.cameras.values(): - cam.disconnect() - self.is_connected = False