Add support for feetech scs series + various fixes

This commit is contained in:
Simon Alibert
2025-04-08 10:46:29 +02:00
parent 99c0938b42
commit e998dddcfa
5 changed files with 291 additions and 307 deletions

View File

@@ -21,19 +21,22 @@ from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_si
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
from .tables import ( from .tables import (
AVAILABLE_BAUDRATES, FIRMWARE_VERSION,
ENCODINGS,
MODEL_BAUDRATE_TABLE, MODEL_BAUDRATE_TABLE,
MODEL_CONTROL_TABLE, MODEL_CONTROL_TABLE,
MODEL_ENCODING_TABLE,
MODEL_NUMBER, MODEL_NUMBER,
MODEL_NUMBER_TABLE,
MODEL_RESOLUTION, MODEL_RESOLUTION,
NORMALIZATION_REQUIRED, SCAN_BAUDRATES,
) )
PROTOCOL_VERSION = 0 PROTOCOL_VERSION = 0
BAUDRATE = 1_000_000 BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000 DEFAULT_TIMEOUT_MS = 1000
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -80,16 +83,14 @@ class FeetechMotorsBus(MotorsBus):
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk. python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
""" """
available_baudrates = deepcopy(AVAILABLE_BAUDRATES) available_baudrates = deepcopy(SCAN_BAUDRATES)
default_timeout = DEFAULT_TIMEOUT_MS default_timeout = DEFAULT_TIMEOUT_MS
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE) model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
model_number_table = deepcopy(MODEL_NUMBER) model_encoding_table = deepcopy(MODEL_ENCODING_TABLE)
model_number_table = deepcopy(MODEL_NUMBER_TABLE)
model_resolution_table = deepcopy(MODEL_RESOLUTION) model_resolution_table = deepcopy(MODEL_RESOLUTION)
normalization_required = deepcopy(NORMALIZATION_REQUIRED) normalized_data = deepcopy(NORMALIZED_DATA)
# Feetech specific
encodings = deepcopy(ENCODINGS)
def __init__( def __init__(
self, self,
@@ -140,13 +141,25 @@ class FeetechMotorsBus(MotorsBus):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value) self.write("Torque_Enable", motor, TorqueMode.ENABLED.value)
self.write("Lock", motor, 1) self.write("Lock", motor, 1)
def _encode_value(self, value: int, data_name: str | None = None, n_bytes: int | None = None) -> int: def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
sign_bit = self.encodings.get(data_name) for id_ in ids_values:
return encode_sign_magnitude(value, sign_bit) if sign_bit is not None else value model = self._id_to_model(id_)
encoding_table = self.model_encoding_table.get(model)
if encoding_table and data_name in encoding_table:
sign_bit = encoding_table[data_name]
ids_values[id_] = encode_sign_magnitude(ids_values[id_], sign_bit)
def _decode_value(self, value: int, data_name: str | None = None, n_bytes: int | None = None) -> int: return ids_values
sign_bit = self.encodings.get(data_name)
return decode_sign_magnitude(value, sign_bit) if sign_bit is not None else value def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
model = self._id_to_model(id_)
encoding_table = self.model_encoding_table.get(model)
if encoding_table and data_name in encoding_table:
sign_bit = encoding_table[data_name]
ids_values[id_] = decode_sign_magnitude(ids_values[id_], sign_bit)
return ids_values
@staticmethod @staticmethod
def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]: def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]:
@@ -220,15 +233,15 @@ class FeetechMotorsBus(MotorsBus):
return data_list, scs.COMM_RX_CORRUPT return data_list, scs.COMM_RX_CORRUPT
# find packet header # find packet header
for id_ in range(0, (rx_length - 1)): for idx in range(0, (rx_length - 1)):
if (rxpacket[id_] == 0xFF) and (rxpacket[id_ + 1] == 0xFF): if (rxpacket[idx] == 0xFF) and (rxpacket[idx + 1] == 0xFF):
break break
if id_ == 0: # found at the beginning of the packet if idx == 0: # found at the beginning of the packet
# calculate checksum # calculate checksum
checksum = 0 checksum = 0
for id_ in range(2, status_length - 1): # except header & checksum for idx in range(2, status_length - 1): # except header & checksum
checksum += rxpacket[id_] checksum += rxpacket[idx]
checksum = scs.SCS_LOBYTE(~checksum) checksum = scs.SCS_LOBYTE(~checksum)
if rxpacket[status_length - 1] == checksum: if rxpacket[status_length - 1] == checksum:
@@ -247,34 +260,71 @@ class FeetechMotorsBus(MotorsBus):
rx_length = rx_length - 2 rx_length = rx_length - 2
else: else:
# remove unnecessary packets # remove unnecessary packets
del rxpacket[0:id_] del rxpacket[0:idx]
rx_length = rx_length - id_ rx_length = rx_length - idx
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
for n_try in range(1 + num_retry): for n_try in range(1 + num_retry):
ids_status, comm = self._broadcast_ping() ids_status, comm = self._broadcast_ping()
if self._is_comm_success(comm): if self._is_comm_success(comm):
break break
logger.debug(f"Broadcast failed on port '{self.port}' ({n_try=})") logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
logger.debug(self.packet_handler.getRxPacketError(comm)) logger.debug(self.packet_handler.getTxRxResult(comm))
if not self._is_comm_success(comm): if not self._is_comm_success(comm):
if raise_on_error: if raise_on_error:
raise ConnectionError(self.packet_handler.getRxPacketError(comm)) raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return ids_status if ids_status else None
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors: if ids_errors:
display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()} display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()}
logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}") logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}")
comm, model_numbers = self._sync_read(
"Model_Number", list(ids_status), model="scs_series", num_retry=num_retry return self._get_model_number(list(ids_status), raise_on_error)
)
def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
# comm, major = self._sync_read(*FIRMWARE_MAJOR_VERSION, motor_ids)
# if not self._is_comm_success(comm):
# if raise_on_error:
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
# return
# comm, minor = self._sync_read(*FIRMWARE_MINOR_VERSION, motor_ids)
# if not self._is_comm_success(comm):
# if raise_on_error:
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
# return
# return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids}
comm, firmware_versions = self._sync_read(*FIRMWARE_VERSION, motor_ids)
if not self._is_comm_success(comm): if not self._is_comm_success(comm):
if raise_on_error: if raise_on_error:
raise ConnectionError(self.packet_handler.getRxPacketError(comm)) raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return model_numbers if model_numbers else None return firmware_versions
def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
# comm, major = self._sync_read(*MODEL_MAJOR_VERSION, motor_ids)
# if not self._is_comm_success(comm):
# if raise_on_error:
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
# return
# comm, minor = self._sync_read(*MODEL_MINOR_VERSION, motor_ids)
# if not self._is_comm_success(comm):
# if raise_on_error:
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
# return
# return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids}
comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids)
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return model_numbers return model_numbers

View File

@@ -1,10 +1,22 @@
FIRMWARE_MAJOR_VERSION = (0, 1)
FIRMWARE_MINOR_VERSION = (1, 1)
MODEL_MAJOR_VERSION = (3, 1)
MODEL_MINOR_VERSION = (4, 1)
FIRMWARE_VERSION = (0, 2)
MODEL_NUMBER = (3, 2)
# See this link for STS3215 Memory Table: # See this link for STS3215 Memory Table:
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true # https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
# data_name: (address, size_byte) # data_name: (address, size_byte)
SCS_SERIES_CONTROL_TABLE = { STS_SMS_SERIES_CONTROL_TABLE = {
# EPROM # EPROM
"Firmware_Version": (0, 2), "Firmware_Version": FIRMWARE_VERSION, # read-only
"Model_Number": (3, 2), "Model_Number": MODEL_NUMBER, # read-only
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
# "Model_Minor_Version": MODEL_MINOR_VERSION,
"ID": (5, 1), "ID": (5, 1),
"Baud_Rate": (6, 1), "Baud_Rate": (6, 1),
"Return_Delay_Time": (7, 1), "Return_Delay_Time": (7, 1),
@@ -42,18 +54,75 @@ SCS_SERIES_CONTROL_TABLE = {
"Goal_Speed": (46, 2), "Goal_Speed": (46, 2),
"Torque_Limit": (48, 2), "Torque_Limit": (48, 2),
"Lock": (55, 1), "Lock": (55, 1),
"Present_Position": (56, 2), "Present_Position": (56, 2), # read-only
"Present_Speed": (58, 2), "Present_Speed": (58, 2), # read-only
"Present_Load": (60, 2), "Present_Load": (60, 2), # read-only
"Present_Voltage": (62, 1), "Present_Voltage": (62, 1), # read-only
"Present_Temperature": (63, 1), "Present_Temperature": (63, 1), # read-only
"Status": (65, 1), "Status": (65, 1), # read-only
"Moving": (66, 1), "Moving": (66, 1), # read-only
"Present_Current": (69, 2), "Present_Current": (69, 2), # read-only
# Not in the Memory Table # Not in the Memory Table
"Maximum_Acceleration": (85, 2), "Maximum_Acceleration": (85, 2),
} }
SCS_SERIES_CONTROL_TABLE = {
# EPROM
"Firmware_Version": FIRMWARE_VERSION, # read-only
"Model_Number": MODEL_NUMBER, # read-only
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
# "Model_Minor_Version": MODEL_MINOR_VERSION,
"ID": (5, 1),
"Baud_Rate": (6, 1),
"Return_Delay": (7, 1),
"Response_Status_Level": (8, 1),
"Min_Position_Limit": (9, 2),
"Max_Position_Limit": (11, 2),
"Max_Temperature_Limit": (13, 1),
"Max_Voltage_Limit": (14, 1),
"Min_Voltage_Limit": (15, 1),
"Max_Torque_Limit": (16, 2),
"Phase": (18, 1),
"Unloading_Condition": (19, 1),
"LED_Alarm_Condition": (20, 1),
"P_Coefficient": (21, 1),
"D_Coefficient": (22, 1),
"I_Coefficient": (23, 1),
"Minimum_Startup_Force": (24, 2),
"CW_Dead_Zone": (26, 1),
"CCW_Dead_Zone": (27, 1),
"Protective_Torque": (37, 1),
"Protection_Time": (38, 1),
# SRAM
"Torque_Enable": (40, 1),
"Acceleration": (41, 1),
"Goal_Position": (42, 2),
"Running_Time": (44, 2),
"Goal_Speed": (46, 2),
"Lock": (48, 1),
"Present_Position": (56, 2), # read-only
"Present_Speed": (58, 2), # read-only
"Present_Load": (60, 2), # read-only
"Present_Voltage": (62, 1), # read-only
"Present_Temperature": (63, 1), # read-only
"Sync_Write_Flag": (64, 1), # read-only
"Status": (65, 1), # read-only
"Moving": (66, 1), # read-only
}
STS_SMS_SERIES_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
3: 128_000,
4: 115_200,
5: 57_600,
6: 38_400,
7: 19_200,
}
SCS_SERIES_BAUDRATE_TABLE = { SCS_SERIES_BAUDRATE_TABLE = {
0: 1_000_000, 0: 1_000_000,
1: 500_000, 1: 500_000,
@@ -66,34 +135,52 @@ SCS_SERIES_BAUDRATE_TABLE = {
} }
MODEL_CONTROL_TABLE = { MODEL_CONTROL_TABLE = {
"sts_series": STS_SMS_SERIES_CONTROL_TABLE,
"scs_series": SCS_SERIES_CONTROL_TABLE, "scs_series": SCS_SERIES_CONTROL_TABLE,
"sts3215": SCS_SERIES_CONTROL_TABLE, "sms_series": STS_SMS_SERIES_CONTROL_TABLE,
"sts3215": STS_SMS_SERIES_CONTROL_TABLE,
"sts3250": STS_SMS_SERIES_CONTROL_TABLE,
"scs0009": SCS_SERIES_CONTROL_TABLE,
"sm8512bl": STS_SMS_SERIES_CONTROL_TABLE,
} }
MODEL_RESOLUTION = { MODEL_RESOLUTION = {
"scs_series": 4096, "sts_series": 4096,
"sms_series": 4096,
"scs_series": 1024,
"sts3215": 4096, "sts3215": 4096,
} "sts3250": 4096,
"sm8512bl": 4096,
# {model: model_number} "scs0009": 1024,
MODEL_NUMBER = {
"sts3215": 777,
} }
MODEL_BAUDRATE_TABLE = { MODEL_BAUDRATE_TABLE = {
"sts_series": STS_SMS_SERIES_BAUDRATE_TABLE,
"sms_series": STS_SMS_SERIES_BAUDRATE_TABLE,
"scs_series": SCS_SERIES_BAUDRATE_TABLE, "scs_series": SCS_SERIES_BAUDRATE_TABLE,
"sts3215": SCS_SERIES_BAUDRATE_TABLE, "sm8512bl": STS_SMS_SERIES_BAUDRATE_TABLE,
"sts3215": STS_SMS_SERIES_BAUDRATE_TABLE,
"sts3250": STS_SMS_SERIES_BAUDRATE_TABLE,
"scs0009": SCS_SERIES_BAUDRATE_TABLE,
} }
NORMALIZATION_REQUIRED = ["Goal_Position", "Present_Position"]
# Sign-Magnitude encoding bits # Sign-Magnitude encoding bits
ENCODINGS = { STS_SMS_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": 11, "Homing_Offset": 11,
"Goal_Speed": 15, "Goal_Speed": 15,
} }
AVAILABLE_BAUDRATES = [ MODEL_ENCODING_TABLE = {
"sts_series": STS_SMS_SERIES_ENCODINGS_TABLE,
"sms_series": STS_SMS_SERIES_ENCODINGS_TABLE,
"scs_series": {},
"sts3215": STS_SMS_SERIES_ENCODINGS_TABLE,
"sts3250": STS_SMS_SERIES_ENCODINGS_TABLE,
"sm8512bl": STS_SMS_SERIES_ENCODINGS_TABLE,
"scs0009": {},
}
SCAN_BAUDRATES = [
4_800, 4_800,
9_600, 9_600,
14_400, 14_400,
@@ -106,3 +193,11 @@ AVAILABLE_BAUDRATES = [
500_000, 500_000,
1_000_000, 1_000_000,
] ]
# {model: model_number} TODO
MODEL_NUMBER_TABLE = {
"sts3215": 777,
"sts3250": None,
"sm8512bl": None,
"scs0009": None,
}

View File

@@ -25,7 +25,7 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from functools import cached_property from functools import cached_property
from pprint import pformat from pprint import pformat
from typing import Protocol, TypeAlias, overload from typing import Protocol, TypeAlias
import serial import serial
from deepdiff import DeepDiff from deepdiff import DeepDiff
@@ -257,9 +257,10 @@ class MotorsBus(abc.ABC):
default_timeout: int default_timeout: int
model_baudrate_table: dict[str, dict] model_baudrate_table: dict[str, dict]
model_ctrl_table: dict[str, dict] model_ctrl_table: dict[str, dict]
model_encoding_table: dict[str, dict]
model_number_table: dict[str, int] model_number_table: dict[str, int]
model_resolution_table: dict[str, int] model_resolution_table: dict[str, int]
normalization_required: list[str] normalized_data: list[str]
def __init__( def __init__(
self, self,
@@ -340,6 +341,24 @@ class MotorsBus(abc.ABC):
else: else:
raise TypeError(f"'{motor}' should be int, str.") raise TypeError(f"'{motor}' should be int, str.")
def _get_names_list(self, motors: str | list[str] | None) -> list[str]:
if motors is None:
return self.names
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors.copy()
else:
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
if isinstance(values, (int, float)):
return {id_: values for id_ in self.ids}
elif isinstance(values, dict):
return {self.motors[motor].id: val for motor, val in values.items()}
else:
raise TypeError(f"'values' is expected to be a single value or a dict. Got {values}")
def _validate_motors(self) -> None: def _validate_motors(self) -> None:
if len(self.ids) != len(set(self.ids)): if len(self.ids) != len(set(self.ids)):
raise ValueError(f"Some motors have the same id!\n{self}") raise ValueError(f"Some motors have the same id!\n{self}")
@@ -632,15 +651,11 @@ class MotorsBus(abc.ABC):
return unnormalized_values return unnormalized_values
@abc.abstractmethod @abc.abstractmethod
def _encode_value( def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
self, value: int, data_name: str | None = None, n_bytes: int | None = None
) -> dict[int, int]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _decode_value( def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
self, value: int, data_name: str | None = None, n_bytes: int | None = None
) -> dict[int, int]:
pass pass
@staticmethod @staticmethod
@@ -691,7 +706,7 @@ class MotorsBus(abc.ABC):
if not self._is_comm_success(comm): if not self._is_comm_success(comm):
if raise_on_error: if raise_on_error:
raise ConnectionError(self.packet_handler.getRxPacketError(comm)) raise ConnectionError(self.packet_handler.getTxRxResult(comm))
else: else:
return return
if self._is_error(error): if self._is_error(error):
@@ -708,87 +723,60 @@ class MotorsBus(abc.ABC):
) -> dict[int, list[int, str]] | None: ) -> dict[int, list[int, str]] | None:
pass pass
@overload
def sync_read(
self, data_name: str, motors: None = ..., *, normalize: bool = ..., num_retry: int = ...
) -> dict[str, Value]: ...
@overload
def sync_read( def sync_read(
self, self,
data_name: str, data_name: str,
motors: NameOrID | list[NameOrID], motors: str | list[str] | None = None,
*,
normalize: bool = ...,
num_retry: int = ...,
) -> dict[NameOrID, Value]: ...
def sync_read(
self,
data_name: str,
motors: NameOrID | list[NameOrID] | None = None,
*, *,
normalize: bool = True, normalize: bool = True,
num_retry: int = 0, num_retry: int = 0,
) -> dict[NameOrID, Value]: ) -> dict[str, Value]:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
) )
id_key_map: dict[int, NameOrID] = {} names = self._get_names_list(motors)
if motors is None: ids = [self.motors[name].id for name in names]
id_key_map = {m.id: name for name, m in self.motors.items()} models = [self.motors[name].model for name in names]
elif isinstance(motors, (str, int)):
id_key_map = {self._get_motor_id(motors): motors}
elif isinstance(motors, list):
id_key_map = {self._get_motor_id(m): m for m in motors}
else:
raise TypeError(motors)
motor_ids = list(id_key_map) if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name)
comm, ids_values = self._sync_read(data_name, motor_ids, num_retry=num_retry) model = next(iter(models))
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
comm, ids_values = self._sync_read(addr, n_bytes, ids, num_retry=num_retry)
if not self._is_comm_success(comm): if not self._is_comm_success(comm):
raise ConnectionError( raise ConnectionError(
f"Failed to sync read '{data_name}' on {motor_ids=} after {num_retry + 1} tries." f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
f"{self.packet_handler.getTxRxResult(comm)}" f"{self.packet_handler.getTxRxResult(comm)}"
) )
if normalize and data_name in self.normalization_required: ids_values = self._decode_sign(data_name, ids_values)
if normalize and data_name in self.normalized_data:
ids_values = self._normalize(data_name, ids_values) ids_values = self._normalize(data_name, ids_values)
return {id_key_map[id_]: val for id_, val in ids_values.items()} return {self._id_to_name(id_): value for id_, value in ids_values.items()}
def _sync_read( def _sync_read(
self, data_name: str, motor_ids: list[str], model: str | None = None, num_retry: int = 0 self, addr: int, n_bytes: int, motor_ids: list[int], num_retry: int = 0
) -> tuple[int, dict[int, int]]: ) -> tuple[int, dict[int, int]]:
if self._has_different_ctrl_tables:
models = [self._id_to_model(id_) for id_ in motor_ids]
assert_same_address(self.model_ctrl_table, models, data_name)
model = self._id_to_model(next(iter(motor_ids))) if model is None else model
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
self._setup_sync_reader(motor_ids, addr, n_bytes) self._setup_sync_reader(motor_ids, addr, n_bytes)
# FIXME(aliberts, pkooij): We should probably not have to do this.
# Let's try to see if we can do with better comm status handling instead.
# self.port_handler.ser.reset_output_buffer()
# self.port_handler.ser.reset_input_buffer()
for n_try in range(1 + num_retry): for n_try in range(1 + num_retry):
comm = self.sync_reader.txRxPacket() comm = self.sync_reader.txRxPacket()
if self._is_comm_success(comm): if self._is_comm_success(comm):
break break
logger.debug(f"Failed to sync read '{data_name}' ({addr=} {n_bytes=}) on {motor_ids=} ({n_try=})") logger.debug(
logger.debug(self.packet_handler.getRxPacketError(comm)) f"Failed to sync read @{addr=} ({n_bytes=}) on {motor_ids=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
values = {} )
for id_ in motor_ids:
val = self.sync_reader.getData(id_, addr, n_bytes)
values[id_] = self._decode_value(val, data_name, n_bytes)
values = {id_: self.sync_reader.getData(id_, addr, n_bytes) for id_ in motor_ids}
return comm, values return comm, values
def _setup_sync_reader(self, motor_ids: list[str], addr: int, n_bytes: int) -> None: def _setup_sync_reader(self, motor_ids: list[int], addr: int, n_bytes: int) -> None:
self.sync_reader.clearParam() self.sync_reader.clearParam()
self.sync_reader.start_address = addr self.sync_reader.start_address = addr
self.sync_reader.data_length = n_bytes self.sync_reader.data_length = n_bytes
@@ -799,7 +787,7 @@ class MotorsBus(abc.ABC):
# Would have to handle the logic of checking if a packet has been sent previously though but doable. # Would have to handle the logic of checking if a packet has been sent previously though but doable.
# This could be at the cost of increase latency between the moment the data is produced by the motors and # This could be at the cost of increase latency between the moment the data is produced by the motors and
# the moment it is used by a policy. # the moment it is used by a policy.
# def _async_read(self, motor_ids: list[str], address: int, n_bytes: int): # def _async_read(self, motor_ids: list[int], address: int, n_bytes: int):
# if self.sync_reader.start_address != address or self.sync_reader.data_length != n_bytes or ...: # if self.sync_reader.start_address != address or self.sync_reader.data_length != n_bytes or ...:
# self._setup_sync_reader(motor_ids, address, n_bytes) # self._setup_sync_reader(motor_ids, address, n_bytes)
# else: # else:
@@ -812,7 +800,7 @@ class MotorsBus(abc.ABC):
def sync_write( def sync_write(
self, self,
data_name: str, data_name: str,
values: Value | dict[NameOrID, Value], values: Value | dict[str, Value],
*, *,
normalize: bool = True, normalize: bool = True,
num_retry: int = 0, num_retry: int = 0,
@@ -822,41 +810,36 @@ class MotorsBus(abc.ABC):
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
) )
if isinstance(values, int): ids_values = self._get_ids_values_dict(values)
ids_values = {id_: values for id_ in self.ids} models = [self._id_to_model(id_) for id_ in ids_values]
elif isinstance(values, dict): if self._has_different_ctrl_tables:
ids_values = {self._get_motor_id(motor): val for motor, val in values.items()} assert_same_address(self.model_ctrl_table, models, data_name)
else:
raise TypeError(f"'values' is expected to be a single value or a dict. Got {values}")
if normalize and data_name in self.normalization_required and self.calibration is not None: model = next(iter(models))
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
if normalize and data_name in self.normalized_data:
ids_values = self._unnormalize(data_name, ids_values) ids_values = self._unnormalize(data_name, ids_values)
comm = self._sync_write(data_name, ids_values, num_retry=num_retry) ids_values = self._encode_sign(data_name, ids_values)
comm = self._sync_write(addr, n_bytes, ids_values, num_retry=num_retry)
if not self._is_comm_success(comm): if not self._is_comm_success(comm):
raise ConnectionError( raise ConnectionError(
f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
f"\n{self.packet_handler.getTxRxResult(comm)}" f"\n{self.packet_handler.getTxRxResult(comm)}"
) )
def _sync_write(self, data_name: str, ids_values: dict[int, int], num_retry: int = 0) -> int: def _sync_write(self, addr: int, n_bytes: int, ids_values: dict[int, int], num_retry: int = 0) -> int:
if self._has_different_ctrl_tables:
models = [self._id_to_model(id_) for id_ in ids_values]
assert_same_address(self.model_ctrl_table, models, data_name)
model = self._id_to_model(next(iter(ids_values)))
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
ids_values = {id_: self._encode_value(value, data_name, n_bytes) for id_, value in ids_values.items()}
self._setup_sync_writer(ids_values, addr, n_bytes) self._setup_sync_writer(ids_values, addr, n_bytes)
for n_try in range(1 + num_retry): for n_try in range(1 + num_retry):
comm = self.sync_writer.txPacket() comm = self.sync_writer.txPacket()
if self._is_comm_success(comm): if self._is_comm_success(comm):
break break
logger.debug( logger.debug(
f"Failed to sync write '{data_name}' ({addr=} {n_bytes=}) with {ids_values=} ({n_try=})" f"Failed to sync write @{addr=} ({n_bytes=}) with {ids_values=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
) )
logger.debug(self.packet_handler.getRxPacketError(comm))
return comm return comm
@@ -869,20 +852,23 @@ class MotorsBus(abc.ABC):
self.sync_writer.addParam(id_, data) self.sync_writer.addParam(id_, data)
def write( def write(
self, data_name: str, motor: NameOrID, value: Value, *, normalize: bool = True, num_retry: int = 0 self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
) -> None: ) -> None:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
) )
id_ = self._get_motor_id(motor) id_ = self.motors[motor].id
model = self.motors[motor].model
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
if normalize and data_name in self.normalization_required and self.calibration is not None: if normalize and data_name in self.normalized_data:
id_value = self._unnormalize(data_name, {id_: value}) value = self._unnormalize(data_name, {id_: value})[id_]
value = id_value[id_]
comm, error = self._write(data_name, id_, value, num_retry=num_retry) value = self._encode_sign(data_name, {id_: value})[id_]
comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry)
if not self._is_comm_success(comm): if not self._is_comm_success(comm):
raise ConnectionError( raise ConnectionError(
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
@@ -894,20 +880,18 @@ class MotorsBus(abc.ABC):
f"\n{self.packet_handler.getRxPacketError(error)}" f"\n{self.packet_handler.getRxPacketError(error)}"
) )
def _write(self, data_name: str, motor_id: int, value: int, num_retry: int = 0) -> tuple[int, int]: def _write(
model = self._id_to_model(motor_id) self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) ) -> tuple[int, int]:
value = self._encode_value(value, data_name, n_bytes)
data = self._split_int_to_bytes(value, n_bytes) data = self._split_int_to_bytes(value, n_bytes)
for n_try in range(1 + num_retry): for n_try in range(1 + num_retry):
comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data)
if self._is_comm_success(comm): if self._is_comm_success(comm):
break break
logger.debug( logger.debug(
f"Failed to write '{data_name}' ({addr=} {n_bytes=}) on {motor_id=} with '{value}' ({n_try=})" f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
) )
logger.debug(self.packet_handler.getRxPacketError(comm))
return comm, error return comm, error

View File

@@ -5,7 +5,7 @@ import scservo_sdk as scs
import serial import serial
from mock_serial import MockSerial from mock_serial import MockSerial
from lerobot.common.motors.feetech import SCS_SERIES_CONTROL_TABLE, FeetechMotorsBus from lerobot.common.motors.feetech import STS_SMS_SERIES_CONTROL_TABLE, FeetechMotorsBus
from lerobot.common.motors.feetech.feetech import patch_setPacketTimeout from lerobot.common.motors.feetech.feetech import patch_setPacketTimeout
from .mock_serial_patch import WaitableStub from .mock_serial_patch import WaitableStub
@@ -297,7 +297,7 @@ class MockMotors(MockSerial):
instruction packets. It is meant to test MotorsBus classes. instruction packets. It is meant to test MotorsBus classes.
""" """
ctrl_table = SCS_SERIES_CONTROL_TABLE ctrl_table = STS_SMS_SERIES_CONTROL_TABLE
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -338,12 +338,6 @@ class MockMotors(MockSerial):
def build_read_stub( def build_read_stub(
self, data_name: str, scs_id: int, value: int | None = None, num_invalid_try: int = 0 self, data_name: str, scs_id: int, value: int | None = None, num_invalid_try: int = 0
) -> str: ) -> str:
"""
'data_name' supported:
- Model_Number
"""
if data_name != "Model_Number":
raise NotImplementedError
address, length = self.ctrl_table[data_name] address, length = self.ctrl_table[data_name]
read_request = MockInstructionPacket.read(scs_id, address, length) read_request = MockInstructionPacket.read(scs_id, address, length)
return_packet = MockStatusPacket.read(scs_id, value, length) return_packet = MockStatusPacket.read(scs_id, value, length)

View File

@@ -6,7 +6,7 @@ import pytest
import scservo_sdk as scs import scservo_sdk as scs
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.common.motors.feetech import MODEL_NUMBER, FeetechMotorsBus from lerobot.common.motors.feetech import MODEL_NUMBER_TABLE, FeetechMotorsBus
from lerobot.common.utils.encoding_utils import encode_sign_magnitude from lerobot.common.utils.encoding_utils import encode_sign_magnitude
from tests.mocks.mock_feetech import MockMotors, MockPortHandler from tests.mocks.mock_feetech import MockMotors, MockPortHandler
@@ -129,7 +129,7 @@ def test_scan_port(mock_motors):
@pytest.mark.parametrize("id_", [1, 2, 3]) @pytest.mark.parametrize("id_", [1, 2, 3])
def test_ping(id_, mock_motors, dummy_motors): def test_ping(id_, mock_motors, dummy_motors):
expected_model_nb = MODEL_NUMBER[dummy_motors[f"dummy_{id_}"].model] expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
ping_stub = mock_motors.build_ping_stub(id_) ping_stub = mock_motors.build_ping_stub(id_)
mobel_nb_stub = mock_motors.build_read_stub("Model_Number", id_, expected_model_nb) mobel_nb_stub = mock_motors.build_read_stub("Model_Number", id_, expected_model_nb)
motors_bus = FeetechMotorsBus( motors_bus = FeetechMotorsBus(
@@ -147,7 +147,7 @@ def test_ping(id_, mock_motors, dummy_motors):
def test_broadcast_ping(mock_motors, dummy_motors): def test_broadcast_ping(mock_motors, dummy_motors):
models = {m.id: m.model for m in dummy_motors.values()} models = {m.id: m.model for m in dummy_motors.values()}
expected_model_nbs = {id_: MODEL_NUMBER[model] for id_, model in models.items()} expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
ping_stub = mock_motors.build_broadcast_ping_stub(list(models)) ping_stub = mock_motors.build_broadcast_ping_stub(list(models))
mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs) mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs)
motors_bus = FeetechMotorsBus( motors_bus = FeetechMotorsBus(
@@ -191,55 +191,7 @@ def test_sync_read_none(mock_motors, dummy_motors):
(3, 4016), (3, 4016),
], ],
) )
def test_sync_read_by_id(id_, position, mock_motors, dummy_motors): def test_sync_read_single_value(id_, position, mock_motors, dummy_motors):
expected_position = {id_: position}
stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_position)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
read_position = motors_bus.sync_read("Present_Position", id_, normalize=False)
assert mock_motors.stubs[stub_name].called
assert read_position == expected_position
@pytest.mark.parametrize(
"ids, positions",
[
([1], [1337]),
([1, 2], [1337, 42]),
([1, 2, 3], [1337, 42, 4016]),
],
ids=["1 motor", "2 motors", "3 motors"],
) # fmt: skip
def test_sync_read_by_ids(ids, positions, mock_motors, dummy_motors):
assert len(ids) == len(positions)
expected_positions = dict(zip(ids, positions, strict=True))
stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_positions)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
read_positions = motors_bus.sync_read("Present_Position", ids, normalize=False)
assert mock_motors.stubs[stub_name].called
assert read_positions == expected_positions
@pytest.mark.parametrize(
"id_, position",
[
(1, 1337),
(2, 42),
(3, 4016),
],
)
def test_sync_read_by_name(id_, position, mock_motors, dummy_motors):
expected_position = {f"dummy_{id_}": position} expected_position = {f"dummy_{id_}": position}
stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position}) stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position})
motors_bus = FeetechMotorsBus( motors_bus = FeetechMotorsBus(
@@ -263,7 +215,7 @@ def test_sync_read_by_name(id_, position, mock_motors, dummy_motors):
], ],
ids=["1 motor", "2 motors", "3 motors"], ids=["1 motor", "2 motors", "3 motors"],
) # fmt: skip ) # fmt: skip
def test_sync_read_by_names(ids, positions, mock_motors, dummy_motors): def test_sync_read(ids, positions, mock_motors, dummy_motors):
assert len(ids) == len(positions) assert len(ids) == len(positions)
names = [f"dummy_{dxl_id}" for dxl_id in ids] names = [f"dummy_{dxl_id}" for dxl_id in ids]
expected_positions = dict(zip(names, positions, strict=True)) expected_positions = dict(zip(names, positions, strict=True))
@@ -291,9 +243,9 @@ def test_sync_read_by_names(ids, positions, mock_motors, dummy_motors):
], ],
) )
def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors): def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors):
expected_position = {1: pos} expected_position = {"dummy_1": pos}
stub_name = mock_motors.build_sync_read_stub( stub_name = mock_motors.build_sync_read_stub(
"Present_Position", expected_position, num_invalid_try=num_invalid_try "Present_Position", {1: pos}, num_invalid_try=num_invalid_try
) )
motors_bus = FeetechMotorsBus( motors_bus = FeetechMotorsBus(
port=mock_motors.port, port=mock_motors.port,
@@ -302,11 +254,11 @@ def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy
motors_bus.connect(assert_motors_exist=False) motors_bus.connect(assert_motors_exist=False)
if num_retry >= num_invalid_try: if num_retry >= num_invalid_try:
pos_dict = motors_bus.sync_read("Present_Position", 1, normalize=False, num_retry=num_retry) pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
assert pos_dict == {1: pos} assert pos_dict == expected_position
else: else:
with pytest.raises(ConnectionError): with pytest.raises(ConnectionError):
_ = motors_bus.sync_read("Present_Position", 1, normalize=False, num_retry=num_retry) _ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
expected_calls = min(1 + num_retry, 1 + num_invalid_try) expected_calls = min(1 + num_retry, 1 + num_invalid_try)
assert mock_motors.stubs[stub_name].calls == expected_calls assert mock_motors.stubs[stub_name].calls == expected_calls
@@ -335,28 +287,6 @@ def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors):
assert mock_motors.stubs[stub_name].wait_called() assert mock_motors.stubs[stub_name].wait_called()
@pytest.mark.parametrize(
"id_, position",
[
(1, 1337),
(2, 42),
(3, 4016),
],
)
def test_sync_write_by_id(id_, position, mock_motors, dummy_motors):
value = {id_: position}
stub_name = mock_motors.build_sync_write_stub("Goal_Position", value)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
motors_bus.sync_write("Goal_Position", value, normalize=False)
assert mock_motors.stubs[stub_name].wait_called()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ids, positions", "ids, positions",
[ [
@@ -366,54 +296,7 @@ def test_sync_write_by_id(id_, position, mock_motors, dummy_motors):
], ],
ids=["1 motor", "2 motors", "3 motors"], ids=["1 motor", "2 motors", "3 motors"],
) # fmt: skip ) # fmt: skip
def test_sync_write_by_ids(ids, positions, mock_motors, dummy_motors): def test_sync_write(ids, positions, mock_motors, dummy_motors):
assert len(ids) == len(positions)
values = dict(zip(ids, positions, strict=True))
stub_name = mock_motors.build_sync_write_stub("Goal_Position", values)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
motors_bus.sync_write("Goal_Position", values, normalize=False)
assert mock_motors.stubs[stub_name].wait_called()
@pytest.mark.parametrize(
"id_, position",
[
(1, 1337),
(2, 42),
(3, 4016),
],
)
def test_sync_write_by_name(id_, position, mock_motors, dummy_motors):
id_value = {id_: position}
stub_name = mock_motors.build_sync_write_stub("Goal_Position", id_value)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
write_value = {f"dummy_{id_}": position}
motors_bus.sync_write("Goal_Position", write_value, normalize=False)
assert mock_motors.stubs[stub_name].wait_called()
@pytest.mark.parametrize(
"ids, positions",
[
([1], [1337]),
([1, 2], [1337, 42]),
([1, 2, 3], [1337, 42, 4016]),
],
ids=["1 motor", "2 motors", "3 motors"],
) # fmt: skip
def test_sync_write_by_names(ids, positions, mock_motors, dummy_motors):
assert len(ids) == len(positions) assert len(ids) == len(positions)
ids_values = dict(zip(ids, positions, strict=True)) ids_values = dict(zip(ids, positions, strict=True))
stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values) stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values)
@@ -438,29 +321,7 @@ def test_sync_write_by_names(ids, positions, mock_motors, dummy_motors):
("Goal_Position", 3, 42), ("Goal_Position", 3, 42),
], ],
) )
def test_write_by_id(data_name, dxl_id, value, mock_motors, dummy_motors): def test_write(data_name, dxl_id, value, mock_motors, dummy_motors):
stub_name = mock_motors.build_write_stub(data_name, dxl_id, value)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
motors_bus.write(data_name, dxl_id, value, normalize=False)
assert mock_motors.stubs[stub_name].called
@pytest.mark.parametrize(
"data_name, dxl_id, value",
[
("Torque_Enable", 1, 0),
("Torque_Enable", 1, 1),
("Goal_Position", 2, 1337),
("Goal_Position", 3, 42),
],
)
def test_write_by_name(data_name, dxl_id, value, mock_motors, dummy_motors):
stub_name = mock_motors.build_write_stub(data_name, dxl_id, value) stub_name = mock_motors.build_write_stub(data_name, dxl_id, value)
motors_bus = FeetechMotorsBus( motors_bus = FeetechMotorsBus(
port=mock_motors.port, port=mock_motors.port,