Add more segmented tests (base motor bus & feetech), add feetech protocol 1 support

This commit is contained in:
Simon Alibert
2025-04-14 11:56:53 +02:00
parent e0b292ab51
commit bdbca09cb2
6 changed files with 749 additions and 350 deletions

View File

@@ -21,12 +21,13 @@ from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_si
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
from .tables import (
FIRMWARE_VERSION,
FIRMWARE_MAJOR_VERSION,
MODEL_BAUDRATE_TABLE,
MODEL_CONTROL_TABLE,
MODEL_ENCODING_TABLE,
MODEL_NUMBER,
MODEL_NUMBER_TABLE,
MODEL_PROTOCOL,
MODEL_RESOLUTION,
SCAN_BAUDRATES,
)
@@ -117,9 +118,10 @@ class FeetechMotorsBus(MotorsBus):
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
):
super().__init__(port, motors, calibration)
self.protocol_version = protocol_version
self._assert_same_protocol()
import scservo_sdk as scs
self.protocol_version = protocol_version
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
@@ -131,10 +133,21 @@ class FeetechMotorsBus(MotorsBus):
self._comm_success = scs.COMM_SUCCESS
self._no_error = 0x00
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
raise ValueError(f"Some motors are incompatible with protocol_version={self.protocol_version}")
def _assert_same_protocol(self) -> None:
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
raise RuntimeError("Some motors use an incompatible protocol.")
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
if instruction_name == "sync_read" and self.protocol_version == 1:
raise NotImplementedError(
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' instead."
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' sequentially instead."
)
if instruction_name == "broadcast_ping" and self.protocol_version == 1:
raise NotImplementedError(
"'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead."
)
def configure_motors(self) -> None:
@@ -157,12 +170,12 @@ class FeetechMotorsBus(MotorsBus):
return half_turn_homings
def disable_torque(self, motors: str | list[str] | None = None) -> None:
for name in self._get_names_list(motors):
for name in self._get_motors_list(motors):
self.write("Torque_Enable", name, TorqueMode.DISABLED.value)
self.write("Lock", name, 0)
def enable_torque(self, motors: str | list[str] | None = None) -> None:
for name in self._get_names_list(motors):
for name in self._get_motors_list(motors):
self.write("Torque_Enable", name, TorqueMode.ENABLED.value)
self.write("Lock", name, 1)
@@ -286,56 +299,52 @@ class FeetechMotorsBus(MotorsBus):
rx_length = rx_length - idx
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
if self.protocol_version == 0:
for n_try in range(1 + num_retry):
ids_status, comm = self._broadcast_ping_p0()
if self._is_comm_success(comm):
break
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
logger.debug(self.packet_handler.getTxRxResult(comm))
self._assert_protocol_is_compatible("broadcast_ping")
for n_try in range(1 + num_retry):
ids_status, comm = self._broadcast_ping_p0()
if self._is_comm_success(comm):
break
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
logger.debug(self.packet_handler.getTxRxResult(comm))
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:
display_dict = {
id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()
}
logger.error(
f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}"
)
return self._get_model_number(list(ids_status), raise_on_error)
else:
return self._broadcast_ping_p1(num_retry=num_retry)
def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
comm, firmware_versions = self._sync_read(*FIRMWARE_VERSION, motor_ids)
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:
display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()}
logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}")
return self._get_model_number(list(ids_status), raise_on_error)
def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]:
firmware_versions = {}
for id_ in motor_ids:
firm_ver_major, comm, error = self._read(
*FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error
)
if not self._is_comm_success(comm) or self._is_error(error):
return
firm_ver_minor, comm, error = self._read(
*FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error
)
if not self._is_comm_success(comm) or self._is_error(error):
return
firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}"
return firmware_versions
def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
if self.protocol_version == 1:
model_numbers = {}
for id_ in motor_ids:
model_nb, comm, error = self._read(*MODEL_NUMBER, id_)
if self._is_comm_success(comm) and not self._is_error(error):
model_numbers[id_] = model_nb
elif raise_on_error:
raise Exception # FIX
else:
comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids)
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
model_numbers = {}
for id_ in motor_ids:
model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error)
if not self._is_comm_success(comm) or self._is_error(error):
return
model_numbers[id_] = model_nb
return model_numbers

View File

@@ -1,9 +1,5 @@
FIRMWARE_MAJOR_VERSION = (0, 1)
FIRMWARE_MINOR_VERSION = (1, 1)
MODEL_MAJOR_VERSION = (3, 1)
MODEL_MINOR_VERSION = (4, 1)
FIRMWARE_VERSION = (0, 2)
MODEL_NUMBER = (3, 2)
# See this link for STS3215 Memory Table:
@@ -11,12 +7,9 @@ MODEL_NUMBER = (3, 2)
# data_name: (address, size_byte)
STS_SMS_SERIES_CONTROL_TABLE = {
# EPROM
"Firmware_Version": FIRMWARE_VERSION, # read-only
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
"Model_Number": MODEL_NUMBER, # read-only
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
# "Model_Minor_Version": MODEL_MINOR_VERSION,
"ID": (5, 1),
"Baud_Rate": (6, 1),
"Return_Delay_Time": (7, 1),
@@ -68,12 +61,9 @@ STS_SMS_SERIES_CONTROL_TABLE = {
SCS_SERIES_CONTROL_TABLE = {
# EPROM
"Firmware_Version": FIRMWARE_VERSION, # read-only
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
"Model_Number": MODEL_NUMBER, # read-only
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
# "Model_Minor_Version": MODEL_MINOR_VERSION,
"ID": (5, 1),
"Baud_Rate": (6, 1),
"Return_Delay": (7, 1),
@@ -194,10 +184,19 @@ SCAN_BAUDRATES = [
1_000_000,
]
# {model: model_number} TODO
MODEL_NUMBER_TABLE = {
"sts3215": 777,
"sts3250": None,
"sts3250": 2825,
"sm8512bl": 11272,
"scs0009": 1284,
}
MODEL_PROTOCOL = {
"sts_series": 0,
"sms_series": 0,
"scs_series": 1,
"sts3215": 0,
"sts3250": 0,
"sm8512bl": 0,
"scs0009": 1,
}

View File

@@ -283,6 +283,8 @@ class MotorsBus(abc.ABC):
self._id_to_name_dict = {m.id: name for name, m in self.motors.items()}
self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()}
self._validate_motors()
def __len__(self):
return len(self.motors)
@@ -341,7 +343,7 @@ class MotorsBus(abc.ABC):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_names_list(self, motors: str | list[str] | None) -> list[str]:
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
if motors is None:
return self.names
elif isinstance(motors, str):
@@ -422,8 +424,8 @@ class MotorsBus(abc.ABC):
logger.debug(f"{self.__class__.__name__} connected.")
@classmethod
def scan_port(cls, port: str) -> dict[int, list[int]]:
bus = cls(port, {})
def scan_port(cls, port: str, *args, **kwargs) -> dict[int, list[int]]:
bus = cls(port, {}, *args, **kwargs)
try:
bus.port_handler.openPort()
except (FileNotFoundError, OSError, serial.SerialException) as e:
@@ -715,17 +717,8 @@ class MotorsBus(abc.ABC):
model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name)
value, comm, error = self._read(addr, length, id_, num_retry=num_retry)
if not self._is_comm_success(comm):
raise ConnectionError(
f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
f"{self.packet_handler.getTxRxResult(comm)}"
)
elif self._is_error(error):
raise RuntimeError(
f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
f"\n{self.packet_handler.getRxPacketError(error)}"
)
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
id_value = self._decode_sign(data_name, {id_: value})
@@ -734,7 +727,16 @@ class MotorsBus(abc.ABC):
return id_value[id_]
def _read(self, address: int, length: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]:
def _read(
self,
address: int,
length: int,
motor_id: int,
*,
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int]:
if length == 1:
read_fn = self.packet_handler.read1ByteTxRx
elif length == 2:
@@ -753,6 +755,11 @@ class MotorsBus(abc.ABC):
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
elif self._is_error(error) and raise_on_error:
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
return value, comm, error
def write(
@@ -772,20 +779,19 @@ class MotorsBus(abc.ABC):
value = self._encode_sign(data_name, {id_: value})[id_]
comm, error = self._write(addr, length, id_, value, num_retry=num_retry)
if not self._is_comm_success(comm):
raise ConnectionError(
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
f"\n{self.packet_handler.getTxRxResult(comm)}"
)
elif self._is_error(error):
raise RuntimeError(
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
f"\n{self.packet_handler.getRxPacketError(error)}"
)
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write(
self, addr: int, length: int, motor_id: int, value: int, num_retry: int = 0
self,
addr: int,
length: int,
motor_id: int,
value: int,
*,
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int]:
data = self._serialize_data(value, length)
for n_try in range(1 + num_retry):
@@ -797,6 +803,11 @@ class MotorsBus(abc.ABC):
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
elif self._is_error(error) and raise_on_error:
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
return comm, error
def sync_read(
@@ -814,7 +825,7 @@ class MotorsBus(abc.ABC):
self._assert_protocol_is_compatible("sync_read")
names = self._get_names_list(motors)
names = self._get_motors_list(motors)
ids = [self.motors[name].id for name in names]
models = [self.motors[name].model for name in names]
@@ -824,12 +835,10 @@ class MotorsBus(abc.ABC):
model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name)
comm, ids_values = self._sync_read(addr, length, ids, num_retry=num_retry)
if not self._is_comm_success(comm):
raise ConnectionError(
f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
f"{self.packet_handler.getTxRxResult(comm)}"
)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
ids_values = self._decode_sign(data_name, ids_values)
@@ -839,8 +848,15 @@ class MotorsBus(abc.ABC):
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
def _sync_read(
self, addr: int, length: int, motor_ids: list[int], num_retry: int = 0
) -> tuple[int, dict[int, int]]:
self,
addr: int,
length: int,
motor_ids: list[int],
*,
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[dict[int, int], int]:
self._setup_sync_reader(motor_ids, addr, length)
for n_try in range(1 + num_retry):
comm = self.sync_reader.txRxPacket()
@@ -851,8 +867,11 @@ class MotorsBus(abc.ABC):
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids}
return comm, values
return values, comm
def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None:
self.sync_reader.clearParam()
@@ -901,14 +920,18 @@ class MotorsBus(abc.ABC):
ids_values = self._encode_sign(data_name, ids_values)
comm = self._sync_write(addr, length, ids_values, num_retry=num_retry)
if not self._is_comm_success(comm):
raise ConnectionError(
f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
f"\n{self.packet_handler.getTxRxResult(comm)}"
)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _sync_write(self, addr: int, length: int, ids_values: dict[int, int], num_retry: int = 0) -> int:
def _sync_write(
self,
addr: int,
length: int,
ids_values: dict[int, int],
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> int:
self._setup_sync_writer(ids_values, addr, length)
for n_try in range(1 + num_retry):
comm = self.sync_writer.txPacket()
@@ -919,6 +942,9 @@ class MotorsBus(abc.ABC):
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
return comm
def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None: