Add more segmented tests (base motor bus & feetech), add feetech protocol 1 support
This commit is contained in:
@@ -21,12 +21,13 @@ from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_si
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
|
||||
from .tables import (
|
||||
FIRMWARE_VERSION,
|
||||
FIRMWARE_MAJOR_VERSION,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
MODEL_CONTROL_TABLE,
|
||||
MODEL_ENCODING_TABLE,
|
||||
MODEL_NUMBER,
|
||||
MODEL_NUMBER_TABLE,
|
||||
MODEL_PROTOCOL,
|
||||
MODEL_RESOLUTION,
|
||||
SCAN_BAUDRATES,
|
||||
)
|
||||
@@ -117,9 +118,10 @@ class FeetechMotorsBus(MotorsBus):
|
||||
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
|
||||
):
|
||||
super().__init__(port, motors, calibration)
|
||||
self.protocol_version = protocol_version
|
||||
self._assert_same_protocol()
|
||||
import scservo_sdk as scs
|
||||
|
||||
self.protocol_version = protocol_version
|
||||
self.port_handler = scs.PortHandler(self.port)
|
||||
# HACK: monkeypatch
|
||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
|
||||
@@ -131,10 +133,21 @@ class FeetechMotorsBus(MotorsBus):
|
||||
self._comm_success = scs.COMM_SUCCESS
|
||||
self._no_error = 0x00
|
||||
|
||||
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
|
||||
raise ValueError(f"Some motors are incompatible with protocol_version={self.protocol_version}")
|
||||
|
||||
def _assert_same_protocol(self) -> None:
|
||||
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
|
||||
raise RuntimeError("Some motors use an incompatible protocol.")
|
||||
|
||||
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
|
||||
if instruction_name == "sync_read" and self.protocol_version == 1:
|
||||
raise NotImplementedError(
|
||||
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' instead."
|
||||
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' sequentially instead."
|
||||
)
|
||||
if instruction_name == "broadcast_ping" and self.protocol_version == 1:
|
||||
raise NotImplementedError(
|
||||
"'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead."
|
||||
)
|
||||
|
||||
def configure_motors(self) -> None:
|
||||
@@ -157,12 +170,12 @@ class FeetechMotorsBus(MotorsBus):
|
||||
return half_turn_homings
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None) -> None:
|
||||
for name in self._get_names_list(motors):
|
||||
for name in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", name, TorqueMode.DISABLED.value)
|
||||
self.write("Lock", name, 0)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None) -> None:
|
||||
for name in self._get_names_list(motors):
|
||||
for name in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", name, TorqueMode.ENABLED.value)
|
||||
self.write("Lock", name, 1)
|
||||
|
||||
@@ -286,56 +299,52 @@ class FeetechMotorsBus(MotorsBus):
|
||||
rx_length = rx_length - idx
|
||||
|
||||
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
|
||||
if self.protocol_version == 0:
|
||||
for n_try in range(1 + num_retry):
|
||||
ids_status, comm = self._broadcast_ping_p0()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
|
||||
logger.debug(self.packet_handler.getTxRxResult(comm))
|
||||
self._assert_protocol_is_compatible("broadcast_ping")
|
||||
for n_try in range(1 + num_retry):
|
||||
ids_status, comm = self._broadcast_ping_p0()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
|
||||
logger.debug(self.packet_handler.getTxRxResult(comm))
|
||||
|
||||
if not self._is_comm_success(comm):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
return
|
||||
|
||||
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
||||
if ids_errors:
|
||||
display_dict = {
|
||||
id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()
|
||||
}
|
||||
logger.error(
|
||||
f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}"
|
||||
)
|
||||
|
||||
return self._get_model_number(list(ids_status), raise_on_error)
|
||||
else:
|
||||
return self._broadcast_ping_p1(num_retry=num_retry)
|
||||
|
||||
def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
|
||||
comm, firmware_versions = self._sync_read(*FIRMWARE_VERSION, motor_ids)
|
||||
if not self._is_comm_success(comm):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
return
|
||||
|
||||
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
||||
if ids_errors:
|
||||
display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()}
|
||||
logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}")
|
||||
|
||||
return self._get_model_number(list(ids_status), raise_on_error)
|
||||
|
||||
def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]:
|
||||
firmware_versions = {}
|
||||
for id_ in motor_ids:
|
||||
firm_ver_major, comm, error = self._read(
|
||||
*FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error
|
||||
)
|
||||
if not self._is_comm_success(comm) or self._is_error(error):
|
||||
return
|
||||
|
||||
firm_ver_minor, comm, error = self._read(
|
||||
*FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error
|
||||
)
|
||||
if not self._is_comm_success(comm) or self._is_error(error):
|
||||
return
|
||||
|
||||
firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}"
|
||||
|
||||
return firmware_versions
|
||||
|
||||
def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
|
||||
if self.protocol_version == 1:
|
||||
model_numbers = {}
|
||||
for id_ in motor_ids:
|
||||
model_nb, comm, error = self._read(*MODEL_NUMBER, id_)
|
||||
if self._is_comm_success(comm) and not self._is_error(error):
|
||||
model_numbers[id_] = model_nb
|
||||
elif raise_on_error:
|
||||
raise Exception # FIX
|
||||
|
||||
else:
|
||||
comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids)
|
||||
if not self._is_comm_success(comm):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
model_numbers = {}
|
||||
for id_ in motor_ids:
|
||||
model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error)
|
||||
if not self._is_comm_success(comm) or self._is_error(error):
|
||||
return
|
||||
|
||||
model_numbers[id_] = model_nb
|
||||
|
||||
return model_numbers
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
FIRMWARE_MAJOR_VERSION = (0, 1)
|
||||
FIRMWARE_MINOR_VERSION = (1, 1)
|
||||
MODEL_MAJOR_VERSION = (3, 1)
|
||||
MODEL_MINOR_VERSION = (4, 1)
|
||||
|
||||
FIRMWARE_VERSION = (0, 2)
|
||||
MODEL_NUMBER = (3, 2)
|
||||
|
||||
# See this link for STS3215 Memory Table:
|
||||
@@ -11,12 +7,9 @@ MODEL_NUMBER = (3, 2)
|
||||
# data_name: (address, size_byte)
|
||||
STS_SMS_SERIES_CONTROL_TABLE = {
|
||||
# EPROM
|
||||
"Firmware_Version": FIRMWARE_VERSION, # read-only
|
||||
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
||||
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
||||
"Model_Number": MODEL_NUMBER, # read-only
|
||||
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
||||
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
||||
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
|
||||
# "Model_Minor_Version": MODEL_MINOR_VERSION,
|
||||
"ID": (5, 1),
|
||||
"Baud_Rate": (6, 1),
|
||||
"Return_Delay_Time": (7, 1),
|
||||
@@ -68,12 +61,9 @@ STS_SMS_SERIES_CONTROL_TABLE = {
|
||||
|
||||
SCS_SERIES_CONTROL_TABLE = {
|
||||
# EPROM
|
||||
"Firmware_Version": FIRMWARE_VERSION, # read-only
|
||||
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
||||
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
||||
"Model_Number": MODEL_NUMBER, # read-only
|
||||
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
||||
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
||||
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
|
||||
# "Model_Minor_Version": MODEL_MINOR_VERSION,
|
||||
"ID": (5, 1),
|
||||
"Baud_Rate": (6, 1),
|
||||
"Return_Delay": (7, 1),
|
||||
@@ -194,10 +184,19 @@ SCAN_BAUDRATES = [
|
||||
1_000_000,
|
||||
]
|
||||
|
||||
# {model: model_number} TODO
|
||||
MODEL_NUMBER_TABLE = {
|
||||
"sts3215": 777,
|
||||
"sts3250": None,
|
||||
"sts3250": 2825,
|
||||
"sm8512bl": 11272,
|
||||
"scs0009": 1284,
|
||||
}
|
||||
|
||||
MODEL_PROTOCOL = {
|
||||
"sts_series": 0,
|
||||
"sms_series": 0,
|
||||
"scs_series": 1,
|
||||
"sts3215": 0,
|
||||
"sts3250": 0,
|
||||
"sm8512bl": 0,
|
||||
"scs0009": 1,
|
||||
}
|
||||
|
||||
@@ -283,6 +283,8 @@ class MotorsBus(abc.ABC):
|
||||
self._id_to_name_dict = {m.id: name for name, m in self.motors.items()}
|
||||
self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()}
|
||||
|
||||
self._validate_motors()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.motors)
|
||||
|
||||
@@ -341,7 +343,7 @@ class MotorsBus(abc.ABC):
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_names_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
if motors is None:
|
||||
return self.names
|
||||
elif isinstance(motors, str):
|
||||
@@ -422,8 +424,8 @@ class MotorsBus(abc.ABC):
|
||||
logger.debug(f"{self.__class__.__name__} connected.")
|
||||
|
||||
@classmethod
|
||||
def scan_port(cls, port: str) -> dict[int, list[int]]:
|
||||
bus = cls(port, {})
|
||||
def scan_port(cls, port: str, *args, **kwargs) -> dict[int, list[int]]:
|
||||
bus = cls(port, {}, *args, **kwargs)
|
||||
try:
|
||||
bus.port_handler.openPort()
|
||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
||||
@@ -715,17 +717,8 @@ class MotorsBus(abc.ABC):
|
||||
model = self.motors[motor].model
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
value, comm, error = self._read(addr, length, id_, num_retry=num_retry)
|
||||
if not self._is_comm_success(comm):
|
||||
raise ConnectionError(
|
||||
f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
elif self._is_error(error):
|
||||
raise RuntimeError(
|
||||
f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
||||
f"\n{self.packet_handler.getRxPacketError(error)}"
|
||||
)
|
||||
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
||||
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
id_value = self._decode_sign(data_name, {id_: value})
|
||||
|
||||
@@ -734,7 +727,16 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
return id_value[id_]
|
||||
|
||||
def _read(self, address: int, length: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]:
|
||||
def _read(
|
||||
self,
|
||||
address: int,
|
||||
length: int,
|
||||
motor_id: int,
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> tuple[int, int]:
|
||||
if length == 1:
|
||||
read_fn = self.packet_handler.read1ByteTxRx
|
||||
elif length == 2:
|
||||
@@ -753,6 +755,11 @@ class MotorsBus(abc.ABC):
|
||||
+ self.packet_handler.getTxRxResult(comm)
|
||||
)
|
||||
|
||||
if not self._is_comm_success(comm) and raise_on_error:
|
||||
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||
elif self._is_error(error) and raise_on_error:
|
||||
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
|
||||
|
||||
return value, comm, error
|
||||
|
||||
def write(
|
||||
@@ -772,20 +779,19 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
value = self._encode_sign(data_name, {id_: value})[id_]
|
||||
|
||||
comm, error = self._write(addr, length, id_, value, num_retry=num_retry)
|
||||
if not self._is_comm_success(comm):
|
||||
raise ConnectionError(
|
||||
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
||||
f"\n{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
elif self._is_error(error):
|
||||
raise RuntimeError(
|
||||
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
||||
f"\n{self.packet_handler.getRxPacketError(error)}"
|
||||
)
|
||||
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
||||
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
def _write(
|
||||
self, addr: int, length: int, motor_id: int, value: int, num_retry: int = 0
|
||||
self,
|
||||
addr: int,
|
||||
length: int,
|
||||
motor_id: int,
|
||||
value: int,
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> tuple[int, int]:
|
||||
data = self._serialize_data(value, length)
|
||||
for n_try in range(1 + num_retry):
|
||||
@@ -797,6 +803,11 @@ class MotorsBus(abc.ABC):
|
||||
+ self.packet_handler.getTxRxResult(comm)
|
||||
)
|
||||
|
||||
if not self._is_comm_success(comm) and raise_on_error:
|
||||
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||
elif self._is_error(error) and raise_on_error:
|
||||
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
|
||||
|
||||
return comm, error
|
||||
|
||||
def sync_read(
|
||||
@@ -814,7 +825,7 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
self._assert_protocol_is_compatible("sync_read")
|
||||
|
||||
names = self._get_names_list(motors)
|
||||
names = self._get_motors_list(motors)
|
||||
ids = [self.motors[name].id for name in names]
|
||||
models = [self.motors[name].model for name in names]
|
||||
|
||||
@@ -824,12 +835,10 @@ class MotorsBus(abc.ABC):
|
||||
model = next(iter(models))
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
comm, ids_values = self._sync_read(addr, length, ids, num_retry=num_retry)
|
||||
if not self._is_comm_success(comm):
|
||||
raise ConnectionError(
|
||||
f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
|
||||
ids_values, _ = self._sync_read(
|
||||
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||
)
|
||||
|
||||
ids_values = self._decode_sign(data_name, ids_values)
|
||||
|
||||
@@ -839,8 +848,15 @@ class MotorsBus(abc.ABC):
|
||||
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
|
||||
|
||||
def _sync_read(
|
||||
self, addr: int, length: int, motor_ids: list[int], num_retry: int = 0
|
||||
) -> tuple[int, dict[int, int]]:
|
||||
self,
|
||||
addr: int,
|
||||
length: int,
|
||||
motor_ids: list[int],
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> tuple[dict[int, int], int]:
|
||||
self._setup_sync_reader(motor_ids, addr, length)
|
||||
for n_try in range(1 + num_retry):
|
||||
comm = self.sync_reader.txRxPacket()
|
||||
@@ -851,8 +867,11 @@ class MotorsBus(abc.ABC):
|
||||
+ self.packet_handler.getTxRxResult(comm)
|
||||
)
|
||||
|
||||
if not self._is_comm_success(comm) and raise_on_error:
|
||||
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||
|
||||
values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids}
|
||||
return comm, values
|
||||
return values, comm
|
||||
|
||||
def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None:
|
||||
self.sync_reader.clearParam()
|
||||
@@ -901,14 +920,18 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
ids_values = self._encode_sign(data_name, ids_values)
|
||||
|
||||
comm = self._sync_write(addr, length, ids_values, num_retry=num_retry)
|
||||
if not self._is_comm_success(comm):
|
||||
raise ConnectionError(
|
||||
f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
||||
f"\n{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
||||
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
def _sync_write(self, addr: int, length: int, ids_values: dict[int, int], num_retry: int = 0) -> int:
|
||||
def _sync_write(
|
||||
self,
|
||||
addr: int,
|
||||
length: int,
|
||||
ids_values: dict[int, int],
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> int:
|
||||
self._setup_sync_writer(ids_values, addr, length)
|
||||
for n_try in range(1 + num_retry):
|
||||
comm = self.sync_writer.txPacket()
|
||||
@@ -919,6 +942,9 @@ class MotorsBus(abc.ABC):
|
||||
+ self.packet_handler.getTxRxResult(comm)
|
||||
)
|
||||
|
||||
if not self._is_comm_success(comm) and raise_on_error:
|
||||
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||
|
||||
return comm
|
||||
|
||||
def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None:
|
||||
|
||||
Reference in New Issue
Block a user