From e998dddcfa3339830348689c83a5976a9b8a6e7c Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 8 Apr 2025 10:46:29 +0200 Subject: [PATCH] Add support for feetech scs series + various fixes --- lerobot/common/motors/feetech/feetech.py | 114 ++++++++++----- lerobot/common/motors/feetech/tables.py | 141 ++++++++++++++++--- lerobot/common/motors/motors_bus.py | 170 ++++++++++------------- tests/mocks/mock_feetech.py | 10 +- tests/motors/test_feetech.py | 163 ++-------------------- 5 files changed, 291 insertions(+), 307 deletions(-) diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index f97a8a98..064927b0 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -21,19 +21,22 @@ from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_si from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value from .tables import ( - AVAILABLE_BAUDRATES, - ENCODINGS, + FIRMWARE_VERSION, MODEL_BAUDRATE_TABLE, MODEL_CONTROL_TABLE, + MODEL_ENCODING_TABLE, MODEL_NUMBER, + MODEL_NUMBER_TABLE, MODEL_RESOLUTION, - NORMALIZATION_REQUIRED, + SCAN_BAUDRATES, ) PROTOCOL_VERSION = 0 BAUDRATE = 1_000_000 DEFAULT_TIMEOUT_MS = 1000 +NORMALIZED_DATA = ["Goal_Position", "Present_Position"] + logger = logging.getLogger(__name__) @@ -80,16 +83,14 @@ class FeetechMotorsBus(MotorsBus): python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk. """ - available_baudrates = deepcopy(AVAILABLE_BAUDRATES) + available_baudrates = deepcopy(SCAN_BAUDRATES) default_timeout = DEFAULT_TIMEOUT_MS model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE) model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) - model_number_table = deepcopy(MODEL_NUMBER) + model_encoding_table = deepcopy(MODEL_ENCODING_TABLE) + model_number_table = deepcopy(MODEL_NUMBER_TABLE) model_resolution_table = deepcopy(MODEL_RESOLUTION) - normalization_required = deepcopy(NORMALIZATION_REQUIRED) - - # Feetech specific - encodings = deepcopy(ENCODINGS) + normalized_data = deepcopy(NORMALIZED_DATA) def __init__( self, @@ -140,13 +141,25 @@ class FeetechMotorsBus(MotorsBus): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value) self.write("Lock", motor, 1) - def _encode_value(self, value: int, data_name: str | None = None, n_bytes: int | None = None) -> int: - sign_bit = self.encodings.get(data_name) - return encode_sign_magnitude(value, sign_bit) if sign_bit is not None else value + def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: + for id_ in ids_values: + model = self._id_to_model(id_) + encoding_table = self.model_encoding_table.get(model) + if encoding_table and data_name in encoding_table: + sign_bit = encoding_table[data_name] + ids_values[id_] = encode_sign_magnitude(ids_values[id_], sign_bit) - def _decode_value(self, value: int, data_name: str | None = None, n_bytes: int | None = None) -> int: - sign_bit = self.encodings.get(data_name) - return decode_sign_magnitude(value, sign_bit) if sign_bit is not None else value + return ids_values + + def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: + for id_ in ids_values: + model = self._id_to_model(id_) + encoding_table = self.model_encoding_table.get(model) + if encoding_table and data_name in encoding_table: + sign_bit = encoding_table[data_name] + ids_values[id_] = decode_sign_magnitude(ids_values[id_], sign_bit) + + return ids_values @staticmethod def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]: @@ -220,15 +233,15 @@ class FeetechMotorsBus(MotorsBus): return data_list, scs.COMM_RX_CORRUPT # find packet header - for id_ in range(0, (rx_length - 1)): - if (rxpacket[id_] == 0xFF) and (rxpacket[id_ + 1] == 0xFF): + for idx in range(0, (rx_length - 1)): + if (rxpacket[idx] == 0xFF) and (rxpacket[idx + 1] == 0xFF): break - if id_ == 0: # found at the beginning of the packet + if idx == 0: # found at the beginning of the packet # calculate checksum checksum = 0 - for id_ in range(2, status_length - 1): # except header & checksum - checksum += rxpacket[id_] + for idx in range(2, status_length - 1): # except header & checksum + checksum += rxpacket[idx] checksum = scs.SCS_LOBYTE(~checksum) if rxpacket[status_length - 1] == checksum: @@ -247,34 +260,71 @@ class FeetechMotorsBus(MotorsBus): rx_length = rx_length - 2 else: # remove unnecessary packets - del rxpacket[0:id_] - rx_length = rx_length - id_ + del rxpacket[0:idx] + rx_length = rx_length - idx def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: for n_try in range(1 + num_retry): ids_status, comm = self._broadcast_ping() if self._is_comm_success(comm): break - logger.debug(f"Broadcast failed on port '{self.port}' ({n_try=})") - logger.debug(self.packet_handler.getRxPacketError(comm)) + logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})") + logger.debug(self.packet_handler.getTxRxResult(comm)) if not self._is_comm_success(comm): if raise_on_error: - raise ConnectionError(self.packet_handler.getRxPacketError(comm)) - - return ids_status if ids_status else None + raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + return ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} if ids_errors: display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()} logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}") - comm, model_numbers = self._sync_read( - "Model_Number", list(ids_status), model="scs_series", num_retry=num_retry - ) + + return self._get_model_number(list(ids_status), raise_on_error) + + def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: + # comm, major = self._sync_read(*FIRMWARE_MAJOR_VERSION, motor_ids) + # if not self._is_comm_success(comm): + # if raise_on_error: + # raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + # return + + # comm, minor = self._sync_read(*FIRMWARE_MINOR_VERSION, motor_ids) + # if not self._is_comm_success(comm): + # if raise_on_error: + # raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + # return + + # return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids} + + comm, firmware_versions = self._sync_read(*FIRMWARE_VERSION, motor_ids) if not self._is_comm_success(comm): if raise_on_error: - raise ConnectionError(self.packet_handler.getRxPacketError(comm)) + raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + return - return model_numbers if model_numbers else None + return firmware_versions + + def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: + # comm, major = self._sync_read(*MODEL_MAJOR_VERSION, motor_ids) + # if not self._is_comm_success(comm): + # if raise_on_error: + # raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + # return + + # comm, minor = self._sync_read(*MODEL_MINOR_VERSION, motor_ids) + # if not self._is_comm_success(comm): + # if raise_on_error: + # raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + # return + + # return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids} + + comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids) + if not self._is_comm_success(comm): + if raise_on_error: + raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + return return model_numbers diff --git a/lerobot/common/motors/feetech/tables.py b/lerobot/common/motors/feetech/tables.py index 5df1ea59..0fa2fa84 100644 --- a/lerobot/common/motors/feetech/tables.py +++ b/lerobot/common/motors/feetech/tables.py @@ -1,10 +1,22 @@ +FIRMWARE_MAJOR_VERSION = (0, 1) +FIRMWARE_MINOR_VERSION = (1, 1) +MODEL_MAJOR_VERSION = (3, 1) +MODEL_MINOR_VERSION = (4, 1) + +FIRMWARE_VERSION = (0, 2) +MODEL_NUMBER = (3, 2) + # See this link for STS3215 Memory Table: # https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true # data_name: (address, size_byte) -SCS_SERIES_CONTROL_TABLE = { +STS_SMS_SERIES_CONTROL_TABLE = { # EPROM - "Firmware_Version": (0, 2), - "Model_Number": (3, 2), + "Firmware_Version": FIRMWARE_VERSION, # read-only + "Model_Number": MODEL_NUMBER, # read-only + # "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only + # "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only + # "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only + # "Model_Minor_Version": MODEL_MINOR_VERSION, "ID": (5, 1), "Baud_Rate": (6, 1), "Return_Delay_Time": (7, 1), @@ -42,18 +54,75 @@ SCS_SERIES_CONTROL_TABLE = { "Goal_Speed": (46, 2), "Torque_Limit": (48, 2), "Lock": (55, 1), - "Present_Position": (56, 2), - "Present_Speed": (58, 2), - "Present_Load": (60, 2), - "Present_Voltage": (62, 1), - "Present_Temperature": (63, 1), - "Status": (65, 1), - "Moving": (66, 1), - "Present_Current": (69, 2), + "Present_Position": (56, 2), # read-only + "Present_Speed": (58, 2), # read-only + "Present_Load": (60, 2), # read-only + "Present_Voltage": (62, 1), # read-only + "Present_Temperature": (63, 1), # read-only + "Status": (65, 1), # read-only + "Moving": (66, 1), # read-only + "Present_Current": (69, 2), # read-only # Not in the Memory Table "Maximum_Acceleration": (85, 2), } +SCS_SERIES_CONTROL_TABLE = { + # EPROM + "Firmware_Version": FIRMWARE_VERSION, # read-only + "Model_Number": MODEL_NUMBER, # read-only + # "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only + # "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only + # "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only + # "Model_Minor_Version": MODEL_MINOR_VERSION, + "ID": (5, 1), + "Baud_Rate": (6, 1), + "Return_Delay": (7, 1), + "Response_Status_Level": (8, 1), + "Min_Position_Limit": (9, 2), + "Max_Position_Limit": (11, 2), + "Max_Temperature_Limit": (13, 1), + "Max_Voltage_Limit": (14, 1), + "Min_Voltage_Limit": (15, 1), + "Max_Torque_Limit": (16, 2), + "Phase": (18, 1), + "Unloading_Condition": (19, 1), + "LED_Alarm_Condition": (20, 1), + "P_Coefficient": (21, 1), + "D_Coefficient": (22, 1), + "I_Coefficient": (23, 1), + "Minimum_Startup_Force": (24, 2), + "CW_Dead_Zone": (26, 1), + "CCW_Dead_Zone": (27, 1), + "Protective_Torque": (37, 1), + "Protection_Time": (38, 1), + # SRAM + "Torque_Enable": (40, 1), + "Acceleration": (41, 1), + "Goal_Position": (42, 2), + "Running_Time": (44, 2), + "Goal_Speed": (46, 2), + "Lock": (48, 1), + "Present_Position": (56, 2), # read-only + "Present_Speed": (58, 2), # read-only + "Present_Load": (60, 2), # read-only + "Present_Voltage": (62, 1), # read-only + "Present_Temperature": (63, 1), # read-only + "Sync_Write_Flag": (64, 1), # read-only + "Status": (65, 1), # read-only + "Moving": (66, 1), # read-only +} + +STS_SMS_SERIES_BAUDRATE_TABLE = { + 0: 1_000_000, + 1: 500_000, + 2: 250_000, + 3: 128_000, + 4: 115_200, + 5: 57_600, + 6: 38_400, + 7: 19_200, +} + SCS_SERIES_BAUDRATE_TABLE = { 0: 1_000_000, 1: 500_000, @@ -66,34 +135,52 @@ SCS_SERIES_BAUDRATE_TABLE = { } MODEL_CONTROL_TABLE = { + "sts_series": STS_SMS_SERIES_CONTROL_TABLE, "scs_series": SCS_SERIES_CONTROL_TABLE, - "sts3215": SCS_SERIES_CONTROL_TABLE, + "sms_series": STS_SMS_SERIES_CONTROL_TABLE, + "sts3215": STS_SMS_SERIES_CONTROL_TABLE, + "sts3250": STS_SMS_SERIES_CONTROL_TABLE, + "scs0009": SCS_SERIES_CONTROL_TABLE, + "sm8512bl": STS_SMS_SERIES_CONTROL_TABLE, } MODEL_RESOLUTION = { - "scs_series": 4096, + "sts_series": 4096, + "sms_series": 4096, + "scs_series": 1024, "sts3215": 4096, -} - -# {model: model_number} -MODEL_NUMBER = { - "sts3215": 777, + "sts3250": 4096, + "sm8512bl": 4096, + "scs0009": 1024, } MODEL_BAUDRATE_TABLE = { + "sts_series": STS_SMS_SERIES_BAUDRATE_TABLE, + "sms_series": STS_SMS_SERIES_BAUDRATE_TABLE, "scs_series": SCS_SERIES_BAUDRATE_TABLE, - "sts3215": SCS_SERIES_BAUDRATE_TABLE, + "sm8512bl": STS_SMS_SERIES_BAUDRATE_TABLE, + "sts3215": STS_SMS_SERIES_BAUDRATE_TABLE, + "sts3250": STS_SMS_SERIES_BAUDRATE_TABLE, + "scs0009": SCS_SERIES_BAUDRATE_TABLE, } -NORMALIZATION_REQUIRED = ["Goal_Position", "Present_Position"] - # Sign-Magnitude encoding bits -ENCODINGS = { +STS_SMS_SERIES_ENCODINGS_TABLE = { "Homing_Offset": 11, "Goal_Speed": 15, } -AVAILABLE_BAUDRATES = [ +MODEL_ENCODING_TABLE = { + "sts_series": STS_SMS_SERIES_ENCODINGS_TABLE, + "sms_series": STS_SMS_SERIES_ENCODINGS_TABLE, + "scs_series": {}, + "sts3215": STS_SMS_SERIES_ENCODINGS_TABLE, + "sts3250": STS_SMS_SERIES_ENCODINGS_TABLE, + "sm8512bl": STS_SMS_SERIES_ENCODINGS_TABLE, + "scs0009": {}, +} + +SCAN_BAUDRATES = [ 4_800, 9_600, 14_400, @@ -106,3 +193,11 @@ AVAILABLE_BAUDRATES = [ 500_000, 1_000_000, ] + +# {model: model_number} TODO +MODEL_NUMBER_TABLE = { + "sts3215": 777, + "sts3250": None, + "sm8512bl": None, + "scs0009": None, +} diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 96cbb7c3..03f54452 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -25,7 +25,7 @@ from dataclasses import dataclass from enum import Enum from functools import cached_property from pprint import pformat -from typing import Protocol, TypeAlias, overload +from typing import Protocol, TypeAlias import serial from deepdiff import DeepDiff @@ -257,9 +257,10 @@ class MotorsBus(abc.ABC): default_timeout: int model_baudrate_table: dict[str, dict] model_ctrl_table: dict[str, dict] + model_encoding_table: dict[str, dict] model_number_table: dict[str, int] model_resolution_table: dict[str, int] - normalization_required: list[str] + normalized_data: list[str] def __init__( self, @@ -340,6 +341,24 @@ class MotorsBus(abc.ABC): else: raise TypeError(f"'{motor}' should be int, str.") + def _get_names_list(self, motors: str | list[str] | None) -> list[str]: + if motors is None: + return self.names + elif isinstance(motors, str): + return [motors] + elif isinstance(motors, list): + return motors.copy() + else: + raise TypeError(motors) + + def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: + if isinstance(values, (int, float)): + return {id_: values for id_ in self.ids} + elif isinstance(values, dict): + return {self.motors[motor].id: val for motor, val in values.items()} + else: + raise TypeError(f"'values' is expected to be a single value or a dict. Got {values}") + def _validate_motors(self) -> None: if len(self.ids) != len(set(self.ids)): raise ValueError(f"Some motors have the same id!\n{self}") @@ -632,15 +651,11 @@ class MotorsBus(abc.ABC): return unnormalized_values @abc.abstractmethod - def _encode_value( - self, value: int, data_name: str | None = None, n_bytes: int | None = None - ) -> dict[int, int]: + def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: pass @abc.abstractmethod - def _decode_value( - self, value: int, data_name: str | None = None, n_bytes: int | None = None - ) -> dict[int, int]: + def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: pass @staticmethod @@ -691,7 +706,7 @@ class MotorsBus(abc.ABC): if not self._is_comm_success(comm): if raise_on_error: - raise ConnectionError(self.packet_handler.getRxPacketError(comm)) + raise ConnectionError(self.packet_handler.getTxRxResult(comm)) else: return if self._is_error(error): @@ -708,87 +723,60 @@ class MotorsBus(abc.ABC): ) -> dict[int, list[int, str]] | None: pass - @overload - def sync_read( - self, data_name: str, motors: None = ..., *, normalize: bool = ..., num_retry: int = ... - ) -> dict[str, Value]: ... - @overload def sync_read( self, data_name: str, - motors: NameOrID | list[NameOrID], - *, - normalize: bool = ..., - num_retry: int = ..., - ) -> dict[NameOrID, Value]: ... - def sync_read( - self, - data_name: str, - motors: NameOrID | list[NameOrID] | None = None, + motors: str | list[str] | None = None, *, normalize: bool = True, num_retry: int = 0, - ) -> dict[NameOrID, Value]: + ) -> dict[str, Value]: if not self.is_connected: raise DeviceNotConnectedError( f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." ) - id_key_map: dict[int, NameOrID] = {} - if motors is None: - id_key_map = {m.id: name for name, m in self.motors.items()} - elif isinstance(motors, (str, int)): - id_key_map = {self._get_motor_id(motors): motors} - elif isinstance(motors, list): - id_key_map = {self._get_motor_id(m): m for m in motors} - else: - raise TypeError(motors) + names = self._get_names_list(motors) + ids = [self.motors[name].id for name in names] + models = [self.motors[name].model for name in names] - motor_ids = list(id_key_map) + if self._has_different_ctrl_tables: + assert_same_address(self.model_ctrl_table, models, data_name) - comm, ids_values = self._sync_read(data_name, motor_ids, num_retry=num_retry) + model = next(iter(models)) + addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + + comm, ids_values = self._sync_read(addr, n_bytes, ids, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( - f"Failed to sync read '{data_name}' on {motor_ids=} after {num_retry + 1} tries." + f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." f"{self.packet_handler.getTxRxResult(comm)}" ) - if normalize and data_name in self.normalization_required: + ids_values = self._decode_sign(data_name, ids_values) + + if normalize and data_name in self.normalized_data: ids_values = self._normalize(data_name, ids_values) - return {id_key_map[id_]: val for id_, val in ids_values.items()} + return {self._id_to_name(id_): value for id_, value in ids_values.items()} def _sync_read( - self, data_name: str, motor_ids: list[str], model: str | None = None, num_retry: int = 0 + self, addr: int, n_bytes: int, motor_ids: list[int], num_retry: int = 0 ) -> tuple[int, dict[int, int]]: - if self._has_different_ctrl_tables: - models = [self._id_to_model(id_) for id_ in motor_ids] - assert_same_address(self.model_ctrl_table, models, data_name) - - model = self._id_to_model(next(iter(motor_ids))) if model is None else model - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) self._setup_sync_reader(motor_ids, addr, n_bytes) - - # FIXME(aliberts, pkooij): We should probably not have to do this. - # Let's try to see if we can do with better comm status handling instead. - # self.port_handler.ser.reset_output_buffer() - # self.port_handler.ser.reset_input_buffer() - for n_try in range(1 + num_retry): comm = self.sync_reader.txRxPacket() if self._is_comm_success(comm): break - logger.debug(f"Failed to sync read '{data_name}' ({addr=} {n_bytes=}) on {motor_ids=} ({n_try=})") - logger.debug(self.packet_handler.getRxPacketError(comm)) - - values = {} - for id_ in motor_ids: - val = self.sync_reader.getData(id_, addr, n_bytes) - values[id_] = self._decode_value(val, data_name, n_bytes) + logger.debug( + f"Failed to sync read @{addr=} ({n_bytes=}) on {motor_ids=} ({n_try=}): " + + self.packet_handler.getTxRxResult(comm) + ) + values = {id_: self.sync_reader.getData(id_, addr, n_bytes) for id_ in motor_ids} return comm, values - def _setup_sync_reader(self, motor_ids: list[str], addr: int, n_bytes: int) -> None: + def _setup_sync_reader(self, motor_ids: list[int], addr: int, n_bytes: int) -> None: self.sync_reader.clearParam() self.sync_reader.start_address = addr self.sync_reader.data_length = n_bytes @@ -799,7 +787,7 @@ class MotorsBus(abc.ABC): # Would have to handle the logic of checking if a packet has been sent previously though but doable. # This could be at the cost of increase latency between the moment the data is produced by the motors and # the moment it is used by a policy. - # def _async_read(self, motor_ids: list[str], address: int, n_bytes: int): + # def _async_read(self, motor_ids: list[int], address: int, n_bytes: int): # if self.sync_reader.start_address != address or self.sync_reader.data_length != n_bytes or ...: # self._setup_sync_reader(motor_ids, address, n_bytes) # else: @@ -812,7 +800,7 @@ class MotorsBus(abc.ABC): def sync_write( self, data_name: str, - values: Value | dict[NameOrID, Value], + values: Value | dict[str, Value], *, normalize: bool = True, num_retry: int = 0, @@ -822,41 +810,36 @@ class MotorsBus(abc.ABC): f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." ) - if isinstance(values, int): - ids_values = {id_: values for id_ in self.ids} - elif isinstance(values, dict): - ids_values = {self._get_motor_id(motor): val for motor, val in values.items()} - else: - raise TypeError(f"'values' is expected to be a single value or a dict. Got {values}") + ids_values = self._get_ids_values_dict(values) + models = [self._id_to_model(id_) for id_ in ids_values] + if self._has_different_ctrl_tables: + assert_same_address(self.model_ctrl_table, models, data_name) - if normalize and data_name in self.normalization_required and self.calibration is not None: + model = next(iter(models)) + addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + + if normalize and data_name in self.normalized_data: ids_values = self._unnormalize(data_name, ids_values) - comm = self._sync_write(data_name, ids_values, num_retry=num_retry) + ids_values = self._encode_sign(data_name, ids_values) + + comm = self._sync_write(addr, n_bytes, ids_values, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." f"\n{self.packet_handler.getTxRxResult(comm)}" ) - def _sync_write(self, data_name: str, ids_values: dict[int, int], num_retry: int = 0) -> int: - if self._has_different_ctrl_tables: - models = [self._id_to_model(id_) for id_ in ids_values] - assert_same_address(self.model_ctrl_table, models, data_name) - - model = self._id_to_model(next(iter(ids_values))) - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) - ids_values = {id_: self._encode_value(value, data_name, n_bytes) for id_, value in ids_values.items()} + def _sync_write(self, addr: int, n_bytes: int, ids_values: dict[int, int], num_retry: int = 0) -> int: self._setup_sync_writer(ids_values, addr, n_bytes) - for n_try in range(1 + num_retry): comm = self.sync_writer.txPacket() if self._is_comm_success(comm): break logger.debug( - f"Failed to sync write '{data_name}' ({addr=} {n_bytes=}) with {ids_values=} ({n_try=})" + f"Failed to sync write @{addr=} ({n_bytes=}) with {ids_values=} ({n_try=}): " + + self.packet_handler.getTxRxResult(comm) ) - logger.debug(self.packet_handler.getRxPacketError(comm)) return comm @@ -869,20 +852,23 @@ class MotorsBus(abc.ABC): self.sync_writer.addParam(id_, data) def write( - self, data_name: str, motor: NameOrID, value: Value, *, normalize: bool = True, num_retry: int = 0 + self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0 ) -> None: if not self.is_connected: raise DeviceNotConnectedError( f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." ) - id_ = self._get_motor_id(motor) + id_ = self.motors[motor].id + model = self.motors[motor].model + addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) - if normalize and data_name in self.normalization_required and self.calibration is not None: - id_value = self._unnormalize(data_name, {id_: value}) - value = id_value[id_] + if normalize and data_name in self.normalized_data: + value = self._unnormalize(data_name, {id_: value})[id_] - comm, error = self._write(data_name, id_, value, num_retry=num_retry) + value = self._encode_sign(data_name, {id_: value})[id_] + + comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." @@ -894,20 +880,18 @@ class MotorsBus(abc.ABC): f"\n{self.packet_handler.getRxPacketError(error)}" ) - def _write(self, data_name: str, motor_id: int, value: int, num_retry: int = 0) -> tuple[int, int]: - model = self._id_to_model(motor_id) - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) - value = self._encode_value(value, data_name, n_bytes) + def _write( + self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0 + ) -> tuple[int, int]: data = self._split_int_to_bytes(value, n_bytes) - for n_try in range(1 + num_retry): comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) if self._is_comm_success(comm): break logger.debug( - f"Failed to write '{data_name}' ({addr=} {n_bytes=}) on {motor_id=} with '{value}' ({n_try=})" + f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): " + + self.packet_handler.getTxRxResult(comm) ) - logger.debug(self.packet_handler.getRxPacketError(comm)) return comm, error diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 1636c113..56437b02 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -5,7 +5,7 @@ import scservo_sdk as scs import serial from mock_serial import MockSerial -from lerobot.common.motors.feetech import SCS_SERIES_CONTROL_TABLE, FeetechMotorsBus +from lerobot.common.motors.feetech import STS_SMS_SERIES_CONTROL_TABLE, FeetechMotorsBus from lerobot.common.motors.feetech.feetech import patch_setPacketTimeout from .mock_serial_patch import WaitableStub @@ -297,7 +297,7 @@ class MockMotors(MockSerial): instruction packets. It is meant to test MotorsBus classes. """ - ctrl_table = SCS_SERIES_CONTROL_TABLE + ctrl_table = STS_SMS_SERIES_CONTROL_TABLE def __init__(self): super().__init__() @@ -338,12 +338,6 @@ class MockMotors(MockSerial): def build_read_stub( self, data_name: str, scs_id: int, value: int | None = None, num_invalid_try: int = 0 ) -> str: - """ - 'data_name' supported: - - Model_Number - """ - if data_name != "Model_Number": - raise NotImplementedError address, length = self.ctrl_table[data_name] read_request = MockInstructionPacket.read(scs_id, address, length) return_packet = MockStatusPacket.read(scs_id, value, length) diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 871f102e..5372c37a 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -6,7 +6,7 @@ import pytest import scservo_sdk as scs from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode -from lerobot.common.motors.feetech import MODEL_NUMBER, FeetechMotorsBus +from lerobot.common.motors.feetech import MODEL_NUMBER_TABLE, FeetechMotorsBus from lerobot.common.utils.encoding_utils import encode_sign_magnitude from tests.mocks.mock_feetech import MockMotors, MockPortHandler @@ -129,7 +129,7 @@ def test_scan_port(mock_motors): @pytest.mark.parametrize("id_", [1, 2, 3]) def test_ping(id_, mock_motors, dummy_motors): - expected_model_nb = MODEL_NUMBER[dummy_motors[f"dummy_{id_}"].model] + expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] ping_stub = mock_motors.build_ping_stub(id_) mobel_nb_stub = mock_motors.build_read_stub("Model_Number", id_, expected_model_nb) motors_bus = FeetechMotorsBus( @@ -147,7 +147,7 @@ def test_ping(id_, mock_motors, dummy_motors): def test_broadcast_ping(mock_motors, dummy_motors): models = {m.id: m.model for m in dummy_motors.values()} - expected_model_nbs = {id_: MODEL_NUMBER[model] for id_, model in models.items()} + expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()} ping_stub = mock_motors.build_broadcast_ping_stub(list(models)) mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs) motors_bus = FeetechMotorsBus( @@ -191,55 +191,7 @@ def test_sync_read_none(mock_motors, dummy_motors): (3, 4016), ], ) -def test_sync_read_by_id(id_, position, mock_motors, dummy_motors): - expected_position = {id_: position} - stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_position) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - read_position = motors_bus.sync_read("Present_Position", id_, normalize=False) - - assert mock_motors.stubs[stub_name].called - assert read_position == expected_position - - -@pytest.mark.parametrize( - "ids, positions", - [ - ([1], [1337]), - ([1, 2], [1337, 42]), - ([1, 2, 3], [1337, 42, 4016]), - ], - ids=["1 motor", "2 motors", "3 motors"], -) # fmt: skip -def test_sync_read_by_ids(ids, positions, mock_motors, dummy_motors): - assert len(ids) == len(positions) - expected_positions = dict(zip(ids, positions, strict=True)) - stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_positions) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - read_positions = motors_bus.sync_read("Present_Position", ids, normalize=False) - - assert mock_motors.stubs[stub_name].called - assert read_positions == expected_positions - - -@pytest.mark.parametrize( - "id_, position", - [ - (1, 1337), - (2, 42), - (3, 4016), - ], -) -def test_sync_read_by_name(id_, position, mock_motors, dummy_motors): +def test_sync_read_single_value(id_, position, mock_motors, dummy_motors): expected_position = {f"dummy_{id_}": position} stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position}) motors_bus = FeetechMotorsBus( @@ -263,7 +215,7 @@ def test_sync_read_by_name(id_, position, mock_motors, dummy_motors): ], ids=["1 motor", "2 motors", "3 motors"], ) # fmt: skip -def test_sync_read_by_names(ids, positions, mock_motors, dummy_motors): +def test_sync_read(ids, positions, mock_motors, dummy_motors): assert len(ids) == len(positions) names = [f"dummy_{dxl_id}" for dxl_id in ids] expected_positions = dict(zip(names, positions, strict=True)) @@ -291,9 +243,9 @@ def test_sync_read_by_names(ids, positions, mock_motors, dummy_motors): ], ) def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors): - expected_position = {1: pos} + expected_position = {"dummy_1": pos} stub_name = mock_motors.build_sync_read_stub( - "Present_Position", expected_position, num_invalid_try=num_invalid_try + "Present_Position", {1: pos}, num_invalid_try=num_invalid_try ) motors_bus = FeetechMotorsBus( port=mock_motors.port, @@ -302,11 +254,11 @@ def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy motors_bus.connect(assert_motors_exist=False) if num_retry >= num_invalid_try: - pos_dict = motors_bus.sync_read("Present_Position", 1, normalize=False, num_retry=num_retry) - assert pos_dict == {1: pos} + pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry) + assert pos_dict == expected_position else: with pytest.raises(ConnectionError): - _ = motors_bus.sync_read("Present_Position", 1, normalize=False, num_retry=num_retry) + _ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry) expected_calls = min(1 + num_retry, 1 + num_invalid_try) assert mock_motors.stubs[stub_name].calls == expected_calls @@ -335,28 +287,6 @@ def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors): assert mock_motors.stubs[stub_name].wait_called() -@pytest.mark.parametrize( - "id_, position", - [ - (1, 1337), - (2, 42), - (3, 4016), - ], -) -def test_sync_write_by_id(id_, position, mock_motors, dummy_motors): - value = {id_: position} - stub_name = mock_motors.build_sync_write_stub("Goal_Position", value) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - motors_bus.sync_write("Goal_Position", value, normalize=False) - - assert mock_motors.stubs[stub_name].wait_called() - - @pytest.mark.parametrize( "ids, positions", [ @@ -366,54 +296,7 @@ def test_sync_write_by_id(id_, position, mock_motors, dummy_motors): ], ids=["1 motor", "2 motors", "3 motors"], ) # fmt: skip -def test_sync_write_by_ids(ids, positions, mock_motors, dummy_motors): - assert len(ids) == len(positions) - values = dict(zip(ids, positions, strict=True)) - stub_name = mock_motors.build_sync_write_stub("Goal_Position", values) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - motors_bus.sync_write("Goal_Position", values, normalize=False) - - assert mock_motors.stubs[stub_name].wait_called() - - -@pytest.mark.parametrize( - "id_, position", - [ - (1, 1337), - (2, 42), - (3, 4016), - ], -) -def test_sync_write_by_name(id_, position, mock_motors, dummy_motors): - id_value = {id_: position} - stub_name = mock_motors.build_sync_write_stub("Goal_Position", id_value) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - write_value = {f"dummy_{id_}": position} - motors_bus.sync_write("Goal_Position", write_value, normalize=False) - - assert mock_motors.stubs[stub_name].wait_called() - - -@pytest.mark.parametrize( - "ids, positions", - [ - ([1], [1337]), - ([1, 2], [1337, 42]), - ([1, 2, 3], [1337, 42, 4016]), - ], - ids=["1 motor", "2 motors", "3 motors"], -) # fmt: skip -def test_sync_write_by_names(ids, positions, mock_motors, dummy_motors): +def test_sync_write(ids, positions, mock_motors, dummy_motors): assert len(ids) == len(positions) ids_values = dict(zip(ids, positions, strict=True)) stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values) @@ -438,29 +321,7 @@ def test_sync_write_by_names(ids, positions, mock_motors, dummy_motors): ("Goal_Position", 3, 42), ], ) -def test_write_by_id(data_name, dxl_id, value, mock_motors, dummy_motors): - stub_name = mock_motors.build_write_stub(data_name, dxl_id, value) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - motors_bus.write(data_name, dxl_id, value, normalize=False) - - assert mock_motors.stubs[stub_name].called - - -@pytest.mark.parametrize( - "data_name, dxl_id, value", - [ - ("Torque_Enable", 1, 0), - ("Torque_Enable", 1, 1), - ("Goal_Position", 2, 1337), - ("Goal_Position", 3, 42), - ], -) -def test_write_by_name(data_name, dxl_id, value, mock_motors, dummy_motors): +def test_write(data_name, dxl_id, value, mock_motors, dummy_motors): stub_name = mock_motors.build_write_stub(data_name, dxl_id, value) motors_bus = FeetechMotorsBus( port=mock_motors.port,