From 67e3383ffc65b8339b8e1dc89f50ac63becaecc0 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Mon, 2 Jun 2025 19:40:48 +0200 Subject: [PATCH] Refactor LeKiwi (#1136) Co-authored-by: Simon Alibert Co-authored-by: Steven Palma Co-authored-by: Steven Palma Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- examples/robots/lekiwi_client_app.py | 28 +- lerobot/common/motors/feetech/tables.py | 1 + lerobot/common/robots/lekiwi/lekiwi.py | 219 ++++++++++++--- lerobot/common/robots/lekiwi/lekiwi_client.py | 265 ++++-------------- 4 files changed, 261 insertions(+), 252 deletions(-) diff --git a/examples/robots/lekiwi_client_app.py b/examples/robots/lekiwi_client_app.py index b1bccd9c9..d13c910b3 100755 --- a/examples/robots/lekiwi_client_app.py +++ b/examples/robots/lekiwi_client_app.py @@ -16,42 +16,38 @@ import logging import time from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import hw_to_dataset_features from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig -from lerobot.common.robots.lekiwi.lekiwi_client import OBS_STATE, LeKiwiClient +from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig -from lerobot.common.teleoperators.so100 import SO100Leader, SO100LeaderConfig +from lerobot.common.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig NB_CYCLES_CLIENT_CONNECTION = 250 def main(): logging.info("Configuring Teleop Devices") - leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem58760434171") + leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem58760433331") leader_arm = SO100Leader(leader_arm_config) keyboard_config = KeyboardTeleopConfig() keyboard = KeyboardTeleop(keyboard_config) logging.info("Configuring LeKiwi Client") - robot_config = LeKiwiClientConfig(remote_ip="192.0.2.42", id="lekiwi") + robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") robot = LeKiwiClient(robot_config) logging.info("Creating LeRobot Dataset") - # The observations that we get are expected to be in body frame (x,y,theta) - obs_dict = {f"{OBS_STATE}." + key: value for key, value in robot.state_feature.items()} - # The actions that we send are expected to be in wheel frame (motor encoders) - act_dict = {"action." + key: value for key, value in robot.action_feature.items()} + action_features = hw_to_dataset_features(robot.action_features, "action") + obs_features = hw_to_dataset_features(robot.observation_features, "observation") + dataset_features = {**action_features, **obs_features} - features_dict = { - **act_dict, - **obs_dict, - **robot.camera_features, - } dataset = LeRobotDataset.create( repo_id="user/lekiwi" + str(int(time.time())), fps=10, - features=features_dict, + features=dataset_features, + robot_type=robot.name, ) logging.info("Connecting Teleop Devices") @@ -76,10 +72,10 @@ def main(): observation = robot.get_observation() frame = {**action_sent, **observation} - frame.update({"task": "Dummy Example Task Dataset"}) + task = "Dummy Example Task Dataset" logging.info("Saved a frame into the dataset") - dataset.add_frame(frame) + dataset.add_frame(frame, task) i += 1 logging.info("Disconnecting Teleop Devices and LeKiwi Client") diff --git a/lerobot/common/motors/feetech/tables.py b/lerobot/common/motors/feetech/tables.py index 6fa9e97de..0a2f2659f 100644 --- a/lerobot/common/motors/feetech/tables.py +++ b/lerobot/common/motors/feetech/tables.py @@ -207,6 +207,7 @@ MODEL_BAUDRATE_TABLE = { STS_SMS_SERIES_ENCODINGS_TABLE = { "Homing_Offset": 11, "Goal_Velocity": 15, + "Present_Velocity": 15, } MODEL_ENCODING_TABLE = { diff --git a/lerobot/common/robots/lekiwi/lekiwi.py b/lerobot/common/robots/lekiwi/lekiwi.py index 85298e21b..d670d9ae9 100644 --- a/lerobot/common/robots/lekiwi/lekiwi.py +++ b/lerobot/common/robots/lekiwi/lekiwi.py @@ -16,9 +16,12 @@ import logging import time +from functools import cached_property from itertools import chain from typing import Any +import numpy as np + from lerobot.common.cameras.utils import make_cameras_from_configs from lerobot.common.constants import OBS_IMAGES, OBS_STATE from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError @@ -71,34 +74,35 @@ class LeKiwi(Robot): self.cameras = make_cameras_from_configs(config.cameras) @property - def state_feature(self) -> dict: - state_ft = { - "arm_shoulder_pan": {"dtype": "float32"}, - "arm_shoulder_lift": {"dtype": "float32"}, - "arm_elbow_flex": {"dtype": "float32"}, - "arm_wrist_flex": {"dtype": "float32"}, - "arm_wrist_roll": {"dtype": "float32"}, - "arm_gripper": {"dtype": "float32"}, - "base_left_wheel": {"dtype": "float32"}, - "base_right_wheel": {"dtype": "float32"}, - "base_back_wheel": {"dtype": "float32"}, + def _state_ft(self) -> dict[str, type]: + return dict.fromkeys( + ( + "arm_shoulder_pan.pos", + "arm_shoulder_lift.pos", + "arm_elbow_flex.pos", + "arm_wrist_flex.pos", + "arm_wrist_roll.pos", + "arm_gripper.pos", + "x.vel", + "y.vel", + "theta.vel", + ), + float, + ) + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras } - return state_ft - @property - def action_feature(self) -> dict: - return self.state_feature + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._state_ft, **self._cameras_ft} - @property - def camera_features(self) -> dict[str, dict]: - cam_ft = {} - for cam_key, cam in self.cameras.items(): - cam_ft[cam_key] = { - "shape": (cam.height, cam.width, cam.channels), - "names": ["height", "width", "channels"], - "info": None, - } - return cam_ft + @cached_property + def action_features(self) -> dict[str, type]: + return self._state_ft @property def is_connected(self) -> bool: @@ -189,6 +193,142 @@ class LeKiwi(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @staticmethod + def _degps_to_raw(degps: float) -> int: + steps_per_deg = 4096.0 / 360.0 + speed_in_steps = degps * steps_per_deg + speed_int = int(round(speed_in_steps)) + # Cap the value to fit within signed 16-bit range (-32768 to 32767) + if speed_int > 0x7FFF: + speed_int = 0x7FFF # 32767 -> maximum positive value + elif speed_int < -0x8000: + speed_int = -0x8000 # -32768 -> minimum negative value + return speed_int + + @staticmethod + def _raw_to_degps(raw_speed: int) -> float: + steps_per_deg = 4096.0 / 360.0 + magnitude = raw_speed + degps = magnitude / steps_per_deg + return degps + + def _body_to_wheel_raw( + self, + x: float, + y: float, + theta: float, + wheel_radius: float = 0.05, + base_radius: float = 0.125, + max_raw: int = 3000, + ) -> dict: + """ + Convert desired body-frame velocities into wheel raw commands. + + Parameters: + x_cmd : Linear velocity in x (m/s). + y_cmd : Linear velocity in y (m/s). + theta_cmd : Rotational velocity (deg/s). + wheel_radius: Radius of each wheel (meters). + base_radius : Distance from the center of rotation to each wheel (meters). + max_raw : Maximum allowed raw command (ticks) per wheel. + + Returns: + A dictionary with wheel raw commands: + {"base_left_wheel": value, "base_back_wheel": value, "base_right_wheel": value}. + + Notes: + - Internally, the method converts theta_cmd to rad/s for the kinematics. + - The raw command is computed from the wheels angular speed in deg/s + using _degps_to_raw(). If any command exceeds max_raw, all commands + are scaled down proportionally. + """ + # Convert rotational velocity from deg/s to rad/s. + theta_rad = theta * (np.pi / 180.0) + # Create the body velocity vector [x, y, theta_rad]. + velocity_vector = np.array([x, y, theta_rad]) + + # Define the wheel mounting angles with a -90° offset. + angles = np.radians(np.array([240, 120, 0]) - 90) + # Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed. + # The third column (base_radius) accounts for the effect of rotation. + m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) + + # Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s). + wheel_linear_speeds = m.dot(velocity_vector) + wheel_angular_speeds = wheel_linear_speeds / wheel_radius + + # Convert wheel angular speeds from rad/s to deg/s. + wheel_degps = wheel_angular_speeds * (180.0 / np.pi) + + # Scaling + steps_per_deg = 4096.0 / 360.0 + raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps] + max_raw_computed = max(raw_floats) + if max_raw_computed > max_raw: + scale = max_raw / max_raw_computed + wheel_degps = wheel_degps * scale + + # Convert each wheel’s angular speed (deg/s) to a raw integer. + wheel_raw = [self._degps_to_raw(deg) for deg in wheel_degps] + + return { + "base_left_wheel": wheel_raw[0], + "base_back_wheel": wheel_raw[1], + "base_right_wheel": wheel_raw[2], + } + + def _wheel_raw_to_body( + self, + left_wheel_speed, + back_wheel_speed, + right_wheel_speed, + wheel_radius: float = 0.05, + base_radius: float = 0.125, + ) -> dict[str, Any]: + """ + Convert wheel raw command feedback back into body-frame velocities. + + Parameters: + wheel_raw : Vector with raw wheel commands ("base_left_wheel", "base_back_wheel", "base_right_wheel"). + wheel_radius: Radius of each wheel (meters). + base_radius : Distance from the robot center to each wheel (meters). + + Returns: + A dict (x_cmd, y_cmd, theta_cmd) where: + OBS_STATE.x_cmd : Linear velocity in x (m/s). + OBS_STATE.y_cmd : Linear velocity in y (m/s). + OBS_STATE.theta_cmd : Rotational velocity in deg/s. + """ + + # Convert each raw command back to an angular speed in deg/s. + wheel_degps = np.array( + [ + self._raw_to_degps(left_wheel_speed), + self._raw_to_degps(back_wheel_speed), + self._raw_to_degps(right_wheel_speed), + ] + ) + + # Convert from deg/s to rad/s. + wheel_radps = wheel_degps * (np.pi / 180.0) + # Compute each wheel’s linear speed (m/s) from its angular speed. + wheel_linear_speeds = wheel_radps * wheel_radius + + # Define the wheel mounting angles with a -90° offset. + angles = np.radians(np.array([240, 120, 0]) - 90) + m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) + + # Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds. + m_inv = np.linalg.inv(m) + velocity_vector = m_inv.dot(wheel_linear_speeds) + x, y, theta_rad = velocity_vector + theta = theta_rad * (180.0 / np.pi) + return { + "x.vel": x, + "y.vel": y, + "theta.vel": theta, + } # m/s and deg/s + def get_observation(self) -> dict[str, Any]: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -196,9 +336,19 @@ class LeKiwi(Robot): # Read actuators position for arm and vel for base start = time.perf_counter() arm_pos = self.bus.sync_read("Present_Position", self.arm_motors) - base_vel = self.bus.sync_read("Present_Velocity", self.base_motors) - obs_dict = {**arm_pos, **base_vel} - obs_dict = {f"{OBS_STATE}." + key: value for key, value in obs_dict.items()} + base_wheel_vel = self.bus.sync_read("Present_Velocity", self.base_motors) + + base_vel = self._wheel_raw_to_body( + base_wheel_vel["base_left_wheel"], + base_wheel_vel["base_back_wheel"], + base_wheel_vel["base_right_wheel"], + ) + + arm_state = {f"{k}.pos": v for k, v in arm_pos.items()} + + flat_states = {**arm_state, **base_vel} + + obs_dict = {f"{OBS_STATE}": flat_states} dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read state: {dt_ms:.1f}ms") @@ -227,8 +377,12 @@ class LeKiwi(Robot): if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - arm_goal_pos = {k: v for k, v in action.items() if k in self.arm_motors} - base_goal_vel = {k: v for k, v in action.items() if k in self.base_motors} + arm_goal_pos = {k: v for k, v in action.items() if k.endswith(".pos")} + base_goal_vel = {k: v for k, v in action.items() if k.endswith(".vel")} + + base_wheel_goal_vel = self._body_to_wheel_raw( + base_goal_vel["x.vel"], base_goal_vel["y.vel"], base_goal_vel["theta.vel"] + ) # Cap goal position when too far away from present position. # /!\ Slower fps expected due to reading from the follower. @@ -239,8 +393,9 @@ class LeKiwi(Robot): arm_goal_pos = arm_safe_goal_pos # Send goal position to the actuators - self.bus.sync_write("Goal_Position", arm_goal_pos) - self.bus.sync_write("Goal_Velocity", base_goal_vel) + arm_goal_pos_raw = {k.replace(".pos", ""): v for k, v in arm_goal_pos.items()} + self.bus.sync_write("Goal_Position", arm_goal_pos_raw) + self.bus.sync_write("Goal_Velocity", base_wheel_goal_vel) return {**arm_goal_pos, **base_goal_vel} diff --git a/lerobot/common/robots/lekiwi/lekiwi_client.py b/lerobot/common/robots/lekiwi/lekiwi_client.py index a4e791000..a7d73d535 100644 --- a/lerobot/common/robots/lekiwi/lekiwi_client.py +++ b/lerobot/common/robots/lekiwi/lekiwi_client.py @@ -15,6 +15,7 @@ import base64 import json import logging +from functools import cached_property from typing import Any, Dict, Optional, Tuple import cv2 @@ -54,8 +55,7 @@ class LeKiwiClient(Robot): self.last_frames = {} - self.last_remote_arm_state = {} - self.last_remote_base_state = {"base_left_wheel": 0, "base_back_wheel": 0, "base_right_wheel": 0} + self.last_remote_state = {} # Define three speed levels and a current index self.speed_levels = [ @@ -68,53 +68,41 @@ class LeKiwiClient(Robot): self._is_connected = False self.logs = {} - @property - def state_feature(self) -> dict: - state_ft = { - "arm_shoulder_pan": {"shape": (1,), "info": None, "dtype": "float32"}, - "arm_shoulder_lift": {"shape": (1,), "info": None, "dtype": "float32"}, - "arm_elbow_flex": {"shape": (1,), "info": None, "dtype": "float32"}, - "arm_wrist_flex": {"shape": (1,), "info": None, "dtype": "float32"}, - "arm_wrist_roll": {"shape": (1,), "info": None, "dtype": "float32"}, - "arm_gripper": {"shape": (1,), "info": None, "dtype": "float32"}, - "x_cmd": {"shape": (1,), "info": None, "dtype": "float32"}, - "y_cmd": {"shape": (1,), "info": None, "dtype": "float32"}, - "theta_cmd": {"shape": (1,), "info": None, "dtype": "float32"}, - } - return state_ft + @cached_property + def _state_ft(self) -> dict[str, type]: + return dict.fromkeys( + ( + "arm_shoulder_pan.pos", + "arm_shoulder_lift.pos", + "arm_elbow_flex.pos", + "arm_wrist_flex.pos", + "arm_wrist_roll.pos", + "arm_gripper.pos", + "x.vel", + "y.vel", + "theta.vel", + ), + float, + ) - @property - def action_feature(self) -> dict: - action_ft = { - "arm_shoulder_pan": {"shape": (1,), "info": None, "dtype": "float32"}, - "arm_shoulder_lift": {"shape": (1,), "info": None, "dtype": "float32"}, - "arm_elbow_flex": {"shape": (1,), "info": None, "dtype": "float32"}, - "arm_wrist_flex": {"shape": (1,), "info": None, "dtype": "float32"}, - "arm_wrist_roll": {"shape": (1,), "info": None, "dtype": "float32"}, - "arm_gripper": {"shape": (1,), "info": None, "dtype": "float32"}, - "base_left_wheel": {"shape": (1,), "info": None, "dtype": "float32"}, - "base_right_wheel": {"shape": (1,), "info": None, "dtype": "float32"}, - "base_back_wheel": {"shape": (1,), "info": None, "dtype": "float32"}, - } - return action_ft + @cached_property + def _state_order(self) -> tuple[str, ...]: + return tuple(self._state_ft.keys()) - @property - def camera_features(self) -> dict[str, dict]: - cam_ft = { - f"{OBS_IMAGES}.front": { - "shape": (480, 640, 3), - "names": ["height", "width", "channels"], - "info": None, - "dtype": "image", - }, - f"{OBS_IMAGES}.wrist": { - "shape": (480, 640, 3), - "names": ["height", "width", "channels"], - "dtype": "image", - "info": None, - }, + @cached_property + def _cameras_ft(self) -> dict[str, tuple]: + return { + "front": (480, 640, 3), + "wrist": (640, 480, 3), } - return cam_ft + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._state_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._state_ft @property def is_connected(self) -> bool: @@ -154,130 +142,6 @@ class LeKiwiClient(Robot): def calibrate(self) -> None: pass - @staticmethod - def _degps_to_raw(degps: float) -> int: - steps_per_deg = 4096.0 / 360.0 - speed_in_steps = degps * steps_per_deg - speed_int = int(round(speed_in_steps)) - # Cap the value to fit within signed 16-bit range (-32768 to 32767) - if speed_int > 0x7FFF: - speed_int = 0x7FFF # 32767 -> maximum positive value - elif speed_int < -0x8000: - speed_int = -0x8000 # -32768 -> minimum negative value - return speed_int - - @staticmethod - def _raw_to_degps(raw_speed: int) -> float: - steps_per_deg = 4096.0 / 360.0 - magnitude = raw_speed - degps = magnitude / steps_per_deg - return degps - - def _body_to_wheel_raw( - self, - x_cmd: float, - y_cmd: float, - theta_cmd: float, - wheel_radius: float = 0.05, - base_radius: float = 0.125, - max_raw: int = 3000, - ) -> dict: - """ - Convert desired body-frame velocities into wheel raw commands. - - Parameters: - x_cmd : Linear velocity in x (m/s). - y_cmd : Linear velocity in y (m/s). - theta_cmd : Rotational velocity (deg/s). - wheel_radius: Radius of each wheel (meters). - base_radius : Distance from the center of rotation to each wheel (meters). - max_raw : Maximum allowed raw command (ticks) per wheel. - - Returns: - A dictionary with wheel raw commands: - {"base_left_wheel": value, "base_back_wheel": value, "base_right_wheel": value}. - - Notes: - - Internally, the method converts theta_cmd to rad/s for the kinematics. - - The raw command is computed from the wheels angular speed in deg/s - using _degps_to_raw(). If any command exceeds max_raw, all commands - are scaled down proportionally. - """ - # Convert rotational velocity from deg/s to rad/s. - theta_rad = theta_cmd * (np.pi / 180.0) - # Create the body velocity vector [x, y, theta_rad]. - velocity_vector = np.array([x_cmd, y_cmd, theta_rad]) - - # Define the wheel mounting angles with a -90° offset. - angles = np.radians(np.array([240, 120, 0]) - 90) - # Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed. - # The third column (base_radius) accounts for the effect of rotation. - m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) - - # Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s). - wheel_linear_speeds = m.dot(velocity_vector) - wheel_angular_speeds = wheel_linear_speeds / wheel_radius - - # Convert wheel angular speeds from rad/s to deg/s. - wheel_degps = wheel_angular_speeds * (180.0 / np.pi) - - # Scaling - steps_per_deg = 4096.0 / 360.0 - raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps] - max_raw_computed = max(raw_floats) - if max_raw_computed > max_raw: - scale = max_raw / max_raw_computed - wheel_degps = wheel_degps * scale - - # Convert each wheel’s angular speed (deg/s) to a raw integer. - wheel_raw = [self._degps_to_raw(deg) for deg in wheel_degps] - - return { - "base_left_wheel": wheel_raw[0], - "base_back_wheel": wheel_raw[1], - "base_right_wheel": wheel_raw[2], - } - - def _wheel_raw_to_body( - self, wheel_raw: dict[str, Any], wheel_radius: float = 0.05, base_radius: float = 0.125 - ) -> dict[str, Any]: - """ - Convert wheel raw command feedback back into body-frame velocities. - - Parameters: - wheel_raw : Vector with raw wheel commands ("base_left_wheel", "base_back_wheel", "base_right_wheel"). - wheel_radius: Radius of each wheel (meters). - base_radius : Distance from the robot center to each wheel (meters). - - Returns: - A dict (x_cmd, y_cmd, theta_cmd) where: - OBS_STATE.x_cmd : Linear velocity in x (m/s). - OBS_STATE.y_cmd : Linear velocity in y (m/s). - OBS_STATE.theta_cmd : Rotational velocity in deg/s. - """ - - # Convert each raw command back to an angular speed in deg/s. - wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(v)) for _, v in wheel_raw.items()]) - # Convert from deg/s to rad/s. - wheel_radps = wheel_degps * (np.pi / 180.0) - # Compute each wheel’s linear speed (m/s) from its angular speed. - wheel_linear_speeds = wheel_radps * wheel_radius - - # Define the wheel mounting angles with a -90° offset. - angles = np.radians(np.array([240, 120, 0]) - 90) - m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) - - # Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds. - m_inv = np.linalg.inv(m) - velocity_vector = m_inv.dot(wheel_linear_speeds) - x_cmd, y_cmd, theta_rad = velocity_vector - theta_cmd = theta_rad * (180.0 / np.pi) - return { - f"{OBS_STATE}.x_cmd": x_cmd * 1000, - f"{OBS_STATE}.y_cmd": y_cmd * 1000, - f"{OBS_STATE}.theta_cmd": theta_cmd, - } # Convert to mm/s - def _poll_and_get_latest_message(self) -> Optional[str]: """Polls the ZMQ socket for a limited time and returns the latest message string.""" poller = zmq.Poller() @@ -331,25 +195,24 @@ class LeKiwiClient(Robot): def _remote_state_from_obs( self, observation: Dict[str, Any] - ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, Any]]: - """Extracts frames, speed, and arm state from the parsed observation.""" + ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """Extracts frames, and state from the parsed observation.""" + flat_state = observation[OBS_STATE] - # Separate image and state data - image_observation = {k: v for k, v in observation.items() if k.startswith(OBS_IMAGES)} - state_observation = {k: v for k, v in observation.items() if k.startswith(OBS_STATE)} + state_vec = np.array( + [flat_state.get(k, 0.0) for k in self._state_order], + dtype=np.float32, + ) # Decode images + image_observation = {k: v for k, v in observation.items() if k.startswith(OBS_IMAGES)} current_frames: Dict[str, np.ndarray] = {} for cam_name, image_b64 in image_observation.items(): frame = self._decode_image_from_b64(image_b64) if frame is not None: current_frames[cam_name] = frame - # Extract state components - current_arm_state = {k: v for k, v in state_observation.items() if k.startswith(f"{OBS_STATE}.arm")} - current_base_state = {k: v for k, v in state_observation.items() if k.startswith(f"{OBS_STATE}.base")} - - return current_frames, current_arm_state, current_base_state + return current_frames, {"observation.state": state_vec} def _get_data(self) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, Any]]: """ @@ -365,27 +228,26 @@ class LeKiwiClient(Robot): # 2. If no message, return cached data if latest_message_str is None: - return self.last_frames, self.last_remote_arm_state, self.last_remote_base_state + return self.last_frames, self.last_remote_state # 3. Parse the JSON message observation = self._parse_observation_json(latest_message_str) # 4. If JSON parsing failed, return cached data if observation is None: - return self.last_frames, self.last_remote_arm_state, self.last_remote_base_state + return self.last_frames, self.last_remote_state # 5. Process the valid observation data try: - new_frames, new_arm_state, new_base_state = self._remote_state_from_obs(observation) + new_frames, new_state = self._remote_state_from_obs(observation) except Exception as e: logging.error(f"Error processing observation data, serving last observation: {e}") - return self.last_frames, self.last_remote_arm_state, self.last_remote_base_state + return self.last_frames, self.last_remote_state self.last_frames = new_frames - self.last_remote_arm_state = new_arm_state - self.last_remote_base_state = new_base_state + self.last_remote_state = new_state - return new_frames, new_arm_state, new_base_state + return new_frames, new_state def get_observation(self) -> dict[str, Any]: """ @@ -396,13 +258,7 @@ class LeKiwiClient(Robot): if not self._is_connected: raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.") - frames, remote_arm_state, remote_base_state = self._get_data() - remote_body_state = self._wheel_raw_to_body(remote_base_state) - - obs_dict = {**remote_arm_state, **remote_body_state} - - # TODO(Steven): Remove this when it is possible to record a non-numpy array value - obs_dict = {k: np.array([v], dtype=np.float32) for k, v in obs_dict.items()} + frames, obs_dict = self._get_data() # Loop over each configured camera for cam_name, frame in frames.items(): @@ -413,7 +269,7 @@ class LeKiwiClient(Robot): return obs_dict - def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray): + def _from_keyboard_to_base_action(self, pressed_keys: np.ndarray): # Speed control if self.teleop_keys["speed_up"] in pressed_keys: self.speed_index = min(self.speed_index + 1, 2) @@ -439,7 +295,11 @@ class LeKiwiClient(Robot): theta_cmd += theta_speed if self.teleop_keys["rotate_right"] in pressed_keys: theta_cmd -= theta_speed - return self._body_to_wheel_raw(x_cmd, y_cmd, theta_cmd) + return { + "x.vel": x_cmd, + "y.vel": y_cmd, + "theta.vel": theta_cmd, + } def configure(self): pass @@ -461,26 +321,23 @@ class LeKiwiClient(Robot): "ManipulatorRobot is not connected. You need to run `robot.connect()`." ) - goal_pos = {} - common_keys = [ key for key in action - if key in (motor.replace("arm_", "") for motor, _ in self.action_feature.items()) + if key in (motor.replace("arm_", "") for motor, _ in self.action_features.items()) ] arm_actions = {"arm_" + arm_motor: action[arm_motor] for arm_motor in common_keys} - goal_pos = arm_actions keyboard_keys = np.array(list(set(action.keys()) - set(common_keys))) - wheel_actions = self._from_keyboard_to_wheel_action(keyboard_keys) - goal_pos = {**arm_actions, **wheel_actions} + base_actions = self._from_keyboard_to_base_action(keyboard_keys) + goal_pos = {**arm_actions, **base_actions} self.zmq_cmd_socket.send_string(json.dumps(goal_pos)) # action is in motor space # TODO(Steven): Remove the np conversion when it is possible to record a non-numpy array value - goal_pos = {"action." + k: np.array([v], dtype=np.float32) for k, v in goal_pos.items()} - return goal_pos + actions = np.array([goal_pos.get(k, 0.0) for k in self._state_order], dtype=np.float32) + return {"action.state": actions} def disconnect(self): """Cleans ZMQ comms"""