diff --git a/lerobot/common/robot_devices/motors/feetech.py b/lerobot/common/robot_devices/motors/feetech.py index 030ab399f..3d7bb790c 100644 --- a/lerobot/common/robot_devices/motors/feetech.py +++ b/lerobot/common/robot_devices/motors/feetech.py @@ -7,17 +7,6 @@ from copy import deepcopy import numpy as np import tqdm -from scservo_sdk import ( - COMM_SUCCESS, - SCS_HIBYTE, - SCS_HIWORD, - SCS_LOBYTE, - SCS_LOWORD, - GroupSyncRead, - GroupSyncWrite, - PacketHandler, - PortHandler, -) from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError from lerobot.common.utils.utils import capture_timestamp_utc @@ -144,24 +133,29 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str return steps -def convert_to_bytes(value, bytes): +def convert_to_bytes(value, bytes, mock=False): + if mock: + return value + + import scservo_sdk as scs + # Note: No need to convert back into unsigned int, since this byte preprocessing # already handles it for us. if bytes == 1: data = [ - SCS_LOBYTE(SCS_LOWORD(value)), + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), ] elif bytes == 2: data = [ - SCS_LOBYTE(SCS_LOWORD(value)), - SCS_HIBYTE(SCS_LOWORD(value)), + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), + scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), ] elif bytes == 4: data = [ - SCS_LOBYTE(SCS_LOWORD(value)), - SCS_HIBYTE(SCS_LOWORD(value)), - SCS_LOBYTE(SCS_HIWORD(value)), - SCS_HIBYTE(SCS_HIWORD(value)), + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), + scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), + scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), + scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), ] else: raise NotImplementedError( @@ -281,9 +275,11 @@ class FeetechMotorsBus: motors: dict[str, tuple[int, str]], extra_model_control_table: dict[str, list[tuple]] | None = None, extra_model_resolution: dict[str, int] | None = None, + mock=False, ): self.port = port self.motors = motors + self.mock = mock self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) if extra_model_control_table: @@ -319,8 +315,13 @@ class FeetechMotorsBus: f"FeetechMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice." ) - self.port_handler = PortHandler(self.port) - self.packet_handler = PacketHandler(PROTOCOL_VERSION) + if self.mock: + import tests.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + + self.port_handler = scs.PortHandler(self.port) + self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION) try: if not self.port_handler.openPort(): @@ -338,10 +339,17 @@ class FeetechMotorsBus: self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS) def reconnect(self): - self.port_handler = PortHandler(self.port) - self.packet_handler = PacketHandler(PROTOCOL_VERSION) + if self.mock: + import tests.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + + self.port_handler = scs.PortHandler(self.port) + self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION) + if not self.port_handler.openPort(): raise OSError(f"Failed to open port '{self.port}'.") + self.is_connected = True def are_motors_configured(self): @@ -658,6 +666,11 @@ class FeetechMotorsBus: return values def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): + if self.mock: + import tests.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + return_list = True if not isinstance(motor_ids, list): return_list = False @@ -665,16 +678,16 @@ class FeetechMotorsBus: assert_same_address(self.model_ctrl_table, self.motor_models, data_name) addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] - group = GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes) + group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes) for idx in motor_ids: group.addParam(idx) for _ in range(num_retry): comm = group.txRxPacket() - if comm == COMM_SUCCESS: + if comm == scs.COMM_SUCCESS: break - if comm != COMM_SUCCESS: + if comm != scs.COMM_SUCCESS: raise ConnectionError( f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: " f"{self.packet_handler.getTxRxResult(comm)}" @@ -691,6 +704,11 @@ class FeetechMotorsBus: return values[0] def read(self, data_name, motor_names: str | list[str] | None = None): + if self.mock: + import tests.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + if not self.is_connected: raise RobotDeviceNotConnectedError( f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." @@ -717,16 +735,18 @@ class FeetechMotorsBus: if data_name not in self.group_readers: # create new group reader - self.group_readers[group_key] = GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes) + self.group_readers[group_key] = scs.GroupSyncRead( + self.port_handler, self.packet_handler, addr, bytes + ) for idx in motor_ids: self.group_readers[group_key].addParam(idx) for _ in range(NUM_READ_RETRY): comm = self.group_readers[group_key].txRxPacket() - if comm == COMM_SUCCESS: + if comm == scs.COMM_SUCCESS: break - if comm != COMM_SUCCESS: + if comm != scs.COMM_SUCCESS: raise ConnectionError( f"Read failed due to communication error on port {self.port} for group_key {group_key}: " f"{self.packet_handler.getTxRxResult(comm)}" @@ -760,6 +780,11 @@ class FeetechMotorsBus: return values def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): + if self.mock: + import tests.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + if not isinstance(motor_ids, list): motor_ids = [motor_ids] if not isinstance(values, list): @@ -767,17 +792,17 @@ class FeetechMotorsBus: assert_same_address(self.model_ctrl_table, motor_models, data_name) addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] - group = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes) + group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes) for idx, value in zip(motor_ids, values, strict=True): - data = convert_to_bytes(value, bytes) + data = convert_to_bytes(value, bytes, self.mock) group.addParam(idx, data) for _ in range(num_retry): comm = group.txPacket() - if comm == COMM_SUCCESS: + if comm == scs.COMM_SUCCESS: break - if comm != COMM_SUCCESS: + if comm != scs.COMM_SUCCESS: raise ConnectionError( f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: " f"{self.packet_handler.getTxRxResult(comm)}" @@ -791,6 +816,11 @@ class FeetechMotorsBus: start_time = time.perf_counter() + if self.mock: + import tests.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + if motor_names is None: motor_names = self.motor_names @@ -820,19 +850,19 @@ class FeetechMotorsBus: init_group = data_name not in self.group_readers if init_group: - self.group_writers[group_key] = GroupSyncWrite( + self.group_writers[group_key] = scs.GroupSyncWrite( self.port_handler, self.packet_handler, addr, bytes ) for idx, value in zip(motor_ids, values, strict=True): - data = convert_to_bytes(value, bytes) + data = convert_to_bytes(value, bytes, self.mock) if init_group: self.group_writers[group_key].addParam(idx, data) else: self.group_writers[group_key].changeParam(idx, data) comm = self.group_writers[group_key].txPacket() - if comm != COMM_SUCCESS: + if comm != scs.COMM_SUCCESS: raise ConnectionError( f"Write failed due to communication error on port {self.port} for group_key {group_key}: " f"{self.packet_handler.getTxRxResult(comm)}" diff --git a/tests/mock_scservo_sdk.py b/tests/mock_scservo_sdk.py new file mode 100644 index 000000000..06c4283a6 --- /dev/null +++ b/tests/mock_scservo_sdk.py @@ -0,0 +1,87 @@ +"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration +and testing code logic that requires hardware and devices (e.g. robot arms, cameras) + +Warning: These mocked versions are minimalist. They do not exactly mock every behaviors +from the original classes and functions (e.g. return types might be None instead of boolean). +""" + +# from dynamixel_sdk import COMM_SUCCESS + +DEFAULT_BAUDRATE = 1_000_000 +COMM_SUCCESS = 0 # tx or rx packet communication success + + +def convert_to_bytes(value, bytes): + # TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform + # `convert_bytes_to_value` + del bytes # unused + return value + + +class PortHandler: + def __init__(self, port): + self.port = port + # factory default baudrate + self.baudrate = DEFAULT_BAUDRATE + + def openPort(self): # noqa: N802 + return True + + def closePort(self): # noqa: N802 + pass + + def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802 + del timeout_ms # unused + + def getBaudRate(self): # noqa: N802 + return self.baudrate + + def setBaudRate(self, baudrate): # noqa: N802 + self.baudrate = baudrate + + +class PacketHandler: + def __init__(self, protocol_version): + del protocol_version # unused + # Use packet_handler.data to communicate across Read and Write + self.data = {} + + +class GroupSyncRead: + def __init__(self, port_handler, packet_handler, address, bytes): + self.packet_handler = packet_handler + + def addParam(self, motor_index): # noqa: N802 + if motor_index not in self.packet_handler.data: + # Initialize motor default values + self.packet_handler.data[motor_index] = { + # Key (int) are from X_SERIES_CONTROL_TABLE + 7: motor_index, # ID + 8: DEFAULT_BAUDRATE, # Baud_rate + 10: 0, # Drive_Mode + 64: 0, # Torque_Enable + # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144 + # For other joints, 2560 will be autocorrected to be in calibration range + 132: 2560, # Present_Position + } + + def txRxPacket(self): # noqa: N802 + return COMM_SUCCESS + + def getData(self, index, address, bytes): # noqa: N802 + return self.packet_handler.data[index][address] + + +class GroupSyncWrite: + def __init__(self, port_handler, packet_handler, address, bytes): + self.packet_handler = packet_handler + self.address = address + + def addParam(self, index, data): # noqa: N802 + self.changeParam(index, data) + + def txPacket(self): # noqa: N802 + return COMM_SUCCESS + + def changeParam(self, index, data): # noqa: N802 + self.packet_handler.data[index][self.address] = data diff --git a/tests/test_motors.py b/tests/test_motors.py index 14cb3b478..2f668926c 100644 --- a/tests/test_motors.py +++ b/tests/test_motors.py @@ -52,12 +52,24 @@ def test_configure_motors_all_ids_1(request, motor_type, mock): if mock: request.getfixturevalue("patch_builtins_input") + if motor_type == "dynamixel": + # see X_SERIES_BAUDRATE_TABLE + smaller_baudrate = 9_600 + smaller_baudrate_value = 0 + elif motor_type == "feetech": + # see SCS_SERIES_BAUDRATE_TABLE + smaller_baudrate = 19_200 + smaller_baudrate_value = 7 + else: + raise ValueError(motor_type) + input("Are you sure you want to re-configure the motors? Press enter to continue...") # This test expect the configuration was already correct. motors_bus = make_motors_bus(motor_type, mock=mock) motors_bus.connect() - motors_bus.write("Baud_Rate", [0] * len(motors_bus.motors)) - motors_bus.set_bus_baudrate(9_600) + motors_bus.write("Baud_Rate", [smaller_baudrate_value] * len(motors_bus.motors)) + + motors_bus.set_bus_baudrate(smaller_baudrate) motors_bus.write("ID", [1] * len(motors_bus.motors)) del motors_bus diff --git a/tests/utils.py b/tests/utils.py index 0c4b94d89..da5c06dc1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -52,7 +52,7 @@ for motor_type in available_motors: OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0)) INTELREALSENSE_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_CAMERA_INDEX", 128422271614)) -DYNAMIXEL_PORT = "/dev/tty.usbmodem575E0032081" +DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081") DYNAMIXEL_MOTORS = { "shoulder_pan": [1, "xl430-w250"], "shoulder_lift": [2, "xl430-w250"], @@ -62,6 +62,16 @@ DYNAMIXEL_MOTORS = { "gripper": [6, "xl330-m288"], } +FEETECH_PORT = os.environ.get("LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971") +FEETECH_MOTORS = { + "shoulder_pan": [1, "sts3215"], + "shoulder_lift": [2, "sts3215"], + "elbow_flex": [3, "sts3215"], + "wrist_flex": [4, "sts3215"], + "wrist_roll": [5, "sts3215"], + "gripper": [6, "sts3215"], +} + def require_x86_64_kernel(func): """ @@ -277,7 +287,7 @@ def make_robot(robot_type: str, overrides: list[str] | None = None, mock=False) # Explicitely add mock argument to the cameras and set it to true # TODO(rcadene, aliberts): redesign when we drop hydra - if robot_type == "koch": + if robot_type in ["koch", "so100", "moss"]: overrides.append("+leader_arms.main.mock=true") overrides.append("+follower_arms.main.mock=true") if "~cameras" not in overrides: @@ -338,5 +348,12 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: motors = kwargs.pop("motors", DYNAMIXEL_MOTORS) return DynamixelMotorsBus(port, motors, **kwargs) + elif motor_type == "feetech": + from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus + + port = kwargs.pop("port", FEETECH_PORT) + motors = kwargs.pop("motors", FEETECH_MOTORS) + return FeetechMotorsBus(port, motors, **kwargs) + else: raise ValueError(f"The motor type '{motor_type}' is not valid.")