Add support for feetech scs series + various fixes
This commit is contained in:
@@ -21,19 +21,22 @@ from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_si
|
|||||||
|
|
||||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
|
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
|
||||||
from .tables import (
|
from .tables import (
|
||||||
AVAILABLE_BAUDRATES,
|
FIRMWARE_VERSION,
|
||||||
ENCODINGS,
|
|
||||||
MODEL_BAUDRATE_TABLE,
|
MODEL_BAUDRATE_TABLE,
|
||||||
MODEL_CONTROL_TABLE,
|
MODEL_CONTROL_TABLE,
|
||||||
|
MODEL_ENCODING_TABLE,
|
||||||
MODEL_NUMBER,
|
MODEL_NUMBER,
|
||||||
|
MODEL_NUMBER_TABLE,
|
||||||
MODEL_RESOLUTION,
|
MODEL_RESOLUTION,
|
||||||
NORMALIZATION_REQUIRED,
|
SCAN_BAUDRATES,
|
||||||
)
|
)
|
||||||
|
|
||||||
PROTOCOL_VERSION = 0
|
PROTOCOL_VERSION = 0
|
||||||
BAUDRATE = 1_000_000
|
BAUDRATE = 1_000_000
|
||||||
DEFAULT_TIMEOUT_MS = 1000
|
DEFAULT_TIMEOUT_MS = 1000
|
||||||
|
|
||||||
|
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -80,16 +83,14 @@ class FeetechMotorsBus(MotorsBus):
|
|||||||
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
|
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
|
available_baudrates = deepcopy(SCAN_BAUDRATES)
|
||||||
default_timeout = DEFAULT_TIMEOUT_MS
|
default_timeout = DEFAULT_TIMEOUT_MS
|
||||||
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
|
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
|
||||||
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||||
model_number_table = deepcopy(MODEL_NUMBER)
|
model_encoding_table = deepcopy(MODEL_ENCODING_TABLE)
|
||||||
|
model_number_table = deepcopy(MODEL_NUMBER_TABLE)
|
||||||
model_resolution_table = deepcopy(MODEL_RESOLUTION)
|
model_resolution_table = deepcopy(MODEL_RESOLUTION)
|
||||||
normalization_required = deepcopy(NORMALIZATION_REQUIRED)
|
normalized_data = deepcopy(NORMALIZED_DATA)
|
||||||
|
|
||||||
# Feetech specific
|
|
||||||
encodings = deepcopy(ENCODINGS)
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -140,13 +141,25 @@ class FeetechMotorsBus(MotorsBus):
|
|||||||
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value)
|
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value)
|
||||||
self.write("Lock", motor, 1)
|
self.write("Lock", motor, 1)
|
||||||
|
|
||||||
def _encode_value(self, value: int, data_name: str | None = None, n_bytes: int | None = None) -> int:
|
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||||
sign_bit = self.encodings.get(data_name)
|
for id_ in ids_values:
|
||||||
return encode_sign_magnitude(value, sign_bit) if sign_bit is not None else value
|
model = self._id_to_model(id_)
|
||||||
|
encoding_table = self.model_encoding_table.get(model)
|
||||||
|
if encoding_table and data_name in encoding_table:
|
||||||
|
sign_bit = encoding_table[data_name]
|
||||||
|
ids_values[id_] = encode_sign_magnitude(ids_values[id_], sign_bit)
|
||||||
|
|
||||||
def _decode_value(self, value: int, data_name: str | None = None, n_bytes: int | None = None) -> int:
|
return ids_values
|
||||||
sign_bit = self.encodings.get(data_name)
|
|
||||||
return decode_sign_magnitude(value, sign_bit) if sign_bit is not None else value
|
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||||
|
for id_ in ids_values:
|
||||||
|
model = self._id_to_model(id_)
|
||||||
|
encoding_table = self.model_encoding_table.get(model)
|
||||||
|
if encoding_table and data_name in encoding_table:
|
||||||
|
sign_bit = encoding_table[data_name]
|
||||||
|
ids_values[id_] = decode_sign_magnitude(ids_values[id_], sign_bit)
|
||||||
|
|
||||||
|
return ids_values
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]:
|
def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]:
|
||||||
@@ -220,15 +233,15 @@ class FeetechMotorsBus(MotorsBus):
|
|||||||
return data_list, scs.COMM_RX_CORRUPT
|
return data_list, scs.COMM_RX_CORRUPT
|
||||||
|
|
||||||
# find packet header
|
# find packet header
|
||||||
for id_ in range(0, (rx_length - 1)):
|
for idx in range(0, (rx_length - 1)):
|
||||||
if (rxpacket[id_] == 0xFF) and (rxpacket[id_ + 1] == 0xFF):
|
if (rxpacket[idx] == 0xFF) and (rxpacket[idx + 1] == 0xFF):
|
||||||
break
|
break
|
||||||
|
|
||||||
if id_ == 0: # found at the beginning of the packet
|
if idx == 0: # found at the beginning of the packet
|
||||||
# calculate checksum
|
# calculate checksum
|
||||||
checksum = 0
|
checksum = 0
|
||||||
for id_ in range(2, status_length - 1): # except header & checksum
|
for idx in range(2, status_length - 1): # except header & checksum
|
||||||
checksum += rxpacket[id_]
|
checksum += rxpacket[idx]
|
||||||
|
|
||||||
checksum = scs.SCS_LOBYTE(~checksum)
|
checksum = scs.SCS_LOBYTE(~checksum)
|
||||||
if rxpacket[status_length - 1] == checksum:
|
if rxpacket[status_length - 1] == checksum:
|
||||||
@@ -247,34 +260,71 @@ class FeetechMotorsBus(MotorsBus):
|
|||||||
rx_length = rx_length - 2
|
rx_length = rx_length - 2
|
||||||
else:
|
else:
|
||||||
# remove unnecessary packets
|
# remove unnecessary packets
|
||||||
del rxpacket[0:id_]
|
del rxpacket[0:idx]
|
||||||
rx_length = rx_length - id_
|
rx_length = rx_length - idx
|
||||||
|
|
||||||
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
|
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
|
||||||
for n_try in range(1 + num_retry):
|
for n_try in range(1 + num_retry):
|
||||||
ids_status, comm = self._broadcast_ping()
|
ids_status, comm = self._broadcast_ping()
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
logger.debug(f"Broadcast failed on port '{self.port}' ({n_try=})")
|
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
|
||||||
logger.debug(self.packet_handler.getRxPacketError(comm))
|
logger.debug(self.packet_handler.getTxRxResult(comm))
|
||||||
|
|
||||||
if not self._is_comm_success(comm):
|
if not self._is_comm_success(comm):
|
||||||
if raise_on_error:
|
if raise_on_error:
|
||||||
raise ConnectionError(self.packet_handler.getRxPacketError(comm))
|
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||||
|
return
|
||||||
return ids_status if ids_status else None
|
|
||||||
|
|
||||||
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
||||||
if ids_errors:
|
if ids_errors:
|
||||||
display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()}
|
display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()}
|
||||||
logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}")
|
logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}")
|
||||||
comm, model_numbers = self._sync_read(
|
|
||||||
"Model_Number", list(ids_status), model="scs_series", num_retry=num_retry
|
return self._get_model_number(list(ids_status), raise_on_error)
|
||||||
)
|
|
||||||
|
def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
|
||||||
|
# comm, major = self._sync_read(*FIRMWARE_MAJOR_VERSION, motor_ids)
|
||||||
|
# if not self._is_comm_success(comm):
|
||||||
|
# if raise_on_error:
|
||||||
|
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||||
|
# return
|
||||||
|
|
||||||
|
# comm, minor = self._sync_read(*FIRMWARE_MINOR_VERSION, motor_ids)
|
||||||
|
# if not self._is_comm_success(comm):
|
||||||
|
# if raise_on_error:
|
||||||
|
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||||
|
# return
|
||||||
|
|
||||||
|
# return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids}
|
||||||
|
|
||||||
|
comm, firmware_versions = self._sync_read(*FIRMWARE_VERSION, motor_ids)
|
||||||
if not self._is_comm_success(comm):
|
if not self._is_comm_success(comm):
|
||||||
if raise_on_error:
|
if raise_on_error:
|
||||||
raise ConnectionError(self.packet_handler.getRxPacketError(comm))
|
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||||
|
return
|
||||||
|
|
||||||
return model_numbers if model_numbers else None
|
return firmware_versions
|
||||||
|
|
||||||
|
def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
|
||||||
|
# comm, major = self._sync_read(*MODEL_MAJOR_VERSION, motor_ids)
|
||||||
|
# if not self._is_comm_success(comm):
|
||||||
|
# if raise_on_error:
|
||||||
|
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||||
|
# return
|
||||||
|
|
||||||
|
# comm, minor = self._sync_read(*MODEL_MINOR_VERSION, motor_ids)
|
||||||
|
# if not self._is_comm_success(comm):
|
||||||
|
# if raise_on_error:
|
||||||
|
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||||
|
# return
|
||||||
|
|
||||||
|
# return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids}
|
||||||
|
|
||||||
|
comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids)
|
||||||
|
if not self._is_comm_success(comm):
|
||||||
|
if raise_on_error:
|
||||||
|
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||||
|
return
|
||||||
|
|
||||||
return model_numbers
|
return model_numbers
|
||||||
|
|||||||
@@ -1,10 +1,22 @@
|
|||||||
|
FIRMWARE_MAJOR_VERSION = (0, 1)
|
||||||
|
FIRMWARE_MINOR_VERSION = (1, 1)
|
||||||
|
MODEL_MAJOR_VERSION = (3, 1)
|
||||||
|
MODEL_MINOR_VERSION = (4, 1)
|
||||||
|
|
||||||
|
FIRMWARE_VERSION = (0, 2)
|
||||||
|
MODEL_NUMBER = (3, 2)
|
||||||
|
|
||||||
# See this link for STS3215 Memory Table:
|
# See this link for STS3215 Memory Table:
|
||||||
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
|
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
|
||||||
# data_name: (address, size_byte)
|
# data_name: (address, size_byte)
|
||||||
SCS_SERIES_CONTROL_TABLE = {
|
STS_SMS_SERIES_CONTROL_TABLE = {
|
||||||
# EPROM
|
# EPROM
|
||||||
"Firmware_Version": (0, 2),
|
"Firmware_Version": FIRMWARE_VERSION, # read-only
|
||||||
"Model_Number": (3, 2),
|
"Model_Number": MODEL_NUMBER, # read-only
|
||||||
|
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
||||||
|
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
||||||
|
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
|
||||||
|
# "Model_Minor_Version": MODEL_MINOR_VERSION,
|
||||||
"ID": (5, 1),
|
"ID": (5, 1),
|
||||||
"Baud_Rate": (6, 1),
|
"Baud_Rate": (6, 1),
|
||||||
"Return_Delay_Time": (7, 1),
|
"Return_Delay_Time": (7, 1),
|
||||||
@@ -42,18 +54,75 @@ SCS_SERIES_CONTROL_TABLE = {
|
|||||||
"Goal_Speed": (46, 2),
|
"Goal_Speed": (46, 2),
|
||||||
"Torque_Limit": (48, 2),
|
"Torque_Limit": (48, 2),
|
||||||
"Lock": (55, 1),
|
"Lock": (55, 1),
|
||||||
"Present_Position": (56, 2),
|
"Present_Position": (56, 2), # read-only
|
||||||
"Present_Speed": (58, 2),
|
"Present_Speed": (58, 2), # read-only
|
||||||
"Present_Load": (60, 2),
|
"Present_Load": (60, 2), # read-only
|
||||||
"Present_Voltage": (62, 1),
|
"Present_Voltage": (62, 1), # read-only
|
||||||
"Present_Temperature": (63, 1),
|
"Present_Temperature": (63, 1), # read-only
|
||||||
"Status": (65, 1),
|
"Status": (65, 1), # read-only
|
||||||
"Moving": (66, 1),
|
"Moving": (66, 1), # read-only
|
||||||
"Present_Current": (69, 2),
|
"Present_Current": (69, 2), # read-only
|
||||||
# Not in the Memory Table
|
# Not in the Memory Table
|
||||||
"Maximum_Acceleration": (85, 2),
|
"Maximum_Acceleration": (85, 2),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SCS_SERIES_CONTROL_TABLE = {
|
||||||
|
# EPROM
|
||||||
|
"Firmware_Version": FIRMWARE_VERSION, # read-only
|
||||||
|
"Model_Number": MODEL_NUMBER, # read-only
|
||||||
|
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
||||||
|
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
||||||
|
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
|
||||||
|
# "Model_Minor_Version": MODEL_MINOR_VERSION,
|
||||||
|
"ID": (5, 1),
|
||||||
|
"Baud_Rate": (6, 1),
|
||||||
|
"Return_Delay": (7, 1),
|
||||||
|
"Response_Status_Level": (8, 1),
|
||||||
|
"Min_Position_Limit": (9, 2),
|
||||||
|
"Max_Position_Limit": (11, 2),
|
||||||
|
"Max_Temperature_Limit": (13, 1),
|
||||||
|
"Max_Voltage_Limit": (14, 1),
|
||||||
|
"Min_Voltage_Limit": (15, 1),
|
||||||
|
"Max_Torque_Limit": (16, 2),
|
||||||
|
"Phase": (18, 1),
|
||||||
|
"Unloading_Condition": (19, 1),
|
||||||
|
"LED_Alarm_Condition": (20, 1),
|
||||||
|
"P_Coefficient": (21, 1),
|
||||||
|
"D_Coefficient": (22, 1),
|
||||||
|
"I_Coefficient": (23, 1),
|
||||||
|
"Minimum_Startup_Force": (24, 2),
|
||||||
|
"CW_Dead_Zone": (26, 1),
|
||||||
|
"CCW_Dead_Zone": (27, 1),
|
||||||
|
"Protective_Torque": (37, 1),
|
||||||
|
"Protection_Time": (38, 1),
|
||||||
|
# SRAM
|
||||||
|
"Torque_Enable": (40, 1),
|
||||||
|
"Acceleration": (41, 1),
|
||||||
|
"Goal_Position": (42, 2),
|
||||||
|
"Running_Time": (44, 2),
|
||||||
|
"Goal_Speed": (46, 2),
|
||||||
|
"Lock": (48, 1),
|
||||||
|
"Present_Position": (56, 2), # read-only
|
||||||
|
"Present_Speed": (58, 2), # read-only
|
||||||
|
"Present_Load": (60, 2), # read-only
|
||||||
|
"Present_Voltage": (62, 1), # read-only
|
||||||
|
"Present_Temperature": (63, 1), # read-only
|
||||||
|
"Sync_Write_Flag": (64, 1), # read-only
|
||||||
|
"Status": (65, 1), # read-only
|
||||||
|
"Moving": (66, 1), # read-only
|
||||||
|
}
|
||||||
|
|
||||||
|
STS_SMS_SERIES_BAUDRATE_TABLE = {
|
||||||
|
0: 1_000_000,
|
||||||
|
1: 500_000,
|
||||||
|
2: 250_000,
|
||||||
|
3: 128_000,
|
||||||
|
4: 115_200,
|
||||||
|
5: 57_600,
|
||||||
|
6: 38_400,
|
||||||
|
7: 19_200,
|
||||||
|
}
|
||||||
|
|
||||||
SCS_SERIES_BAUDRATE_TABLE = {
|
SCS_SERIES_BAUDRATE_TABLE = {
|
||||||
0: 1_000_000,
|
0: 1_000_000,
|
||||||
1: 500_000,
|
1: 500_000,
|
||||||
@@ -66,34 +135,52 @@ SCS_SERIES_BAUDRATE_TABLE = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
MODEL_CONTROL_TABLE = {
|
MODEL_CONTROL_TABLE = {
|
||||||
|
"sts_series": STS_SMS_SERIES_CONTROL_TABLE,
|
||||||
"scs_series": SCS_SERIES_CONTROL_TABLE,
|
"scs_series": SCS_SERIES_CONTROL_TABLE,
|
||||||
"sts3215": SCS_SERIES_CONTROL_TABLE,
|
"sms_series": STS_SMS_SERIES_CONTROL_TABLE,
|
||||||
|
"sts3215": STS_SMS_SERIES_CONTROL_TABLE,
|
||||||
|
"sts3250": STS_SMS_SERIES_CONTROL_TABLE,
|
||||||
|
"scs0009": SCS_SERIES_CONTROL_TABLE,
|
||||||
|
"sm8512bl": STS_SMS_SERIES_CONTROL_TABLE,
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_RESOLUTION = {
|
MODEL_RESOLUTION = {
|
||||||
"scs_series": 4096,
|
"sts_series": 4096,
|
||||||
|
"sms_series": 4096,
|
||||||
|
"scs_series": 1024,
|
||||||
"sts3215": 4096,
|
"sts3215": 4096,
|
||||||
}
|
"sts3250": 4096,
|
||||||
|
"sm8512bl": 4096,
|
||||||
# {model: model_number}
|
"scs0009": 1024,
|
||||||
MODEL_NUMBER = {
|
|
||||||
"sts3215": 777,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_BAUDRATE_TABLE = {
|
MODEL_BAUDRATE_TABLE = {
|
||||||
|
"sts_series": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||||
|
"sms_series": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||||
"scs_series": SCS_SERIES_BAUDRATE_TABLE,
|
"scs_series": SCS_SERIES_BAUDRATE_TABLE,
|
||||||
"sts3215": SCS_SERIES_BAUDRATE_TABLE,
|
"sm8512bl": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||||
|
"sts3215": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||||
|
"sts3250": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||||
|
"scs0009": SCS_SERIES_BAUDRATE_TABLE,
|
||||||
}
|
}
|
||||||
|
|
||||||
NORMALIZATION_REQUIRED = ["Goal_Position", "Present_Position"]
|
|
||||||
|
|
||||||
# Sign-Magnitude encoding bits
|
# Sign-Magnitude encoding bits
|
||||||
ENCODINGS = {
|
STS_SMS_SERIES_ENCODINGS_TABLE = {
|
||||||
"Homing_Offset": 11,
|
"Homing_Offset": 11,
|
||||||
"Goal_Speed": 15,
|
"Goal_Speed": 15,
|
||||||
}
|
}
|
||||||
|
|
||||||
AVAILABLE_BAUDRATES = [
|
MODEL_ENCODING_TABLE = {
|
||||||
|
"sts_series": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||||
|
"sms_series": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||||
|
"scs_series": {},
|
||||||
|
"sts3215": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||||
|
"sts3250": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||||
|
"sm8512bl": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||||
|
"scs0009": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
SCAN_BAUDRATES = [
|
||||||
4_800,
|
4_800,
|
||||||
9_600,
|
9_600,
|
||||||
14_400,
|
14_400,
|
||||||
@@ -106,3 +193,11 @@ AVAILABLE_BAUDRATES = [
|
|||||||
500_000,
|
500_000,
|
||||||
1_000_000,
|
1_000_000,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# {model: model_number} TODO
|
||||||
|
MODEL_NUMBER_TABLE = {
|
||||||
|
"sts3215": 777,
|
||||||
|
"sts3250": None,
|
||||||
|
"sm8512bl": None,
|
||||||
|
"scs0009": None,
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from dataclasses import dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Protocol, TypeAlias, overload
|
from typing import Protocol, TypeAlias
|
||||||
|
|
||||||
import serial
|
import serial
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
@@ -257,9 +257,10 @@ class MotorsBus(abc.ABC):
|
|||||||
default_timeout: int
|
default_timeout: int
|
||||||
model_baudrate_table: dict[str, dict]
|
model_baudrate_table: dict[str, dict]
|
||||||
model_ctrl_table: dict[str, dict]
|
model_ctrl_table: dict[str, dict]
|
||||||
|
model_encoding_table: dict[str, dict]
|
||||||
model_number_table: dict[str, int]
|
model_number_table: dict[str, int]
|
||||||
model_resolution_table: dict[str, int]
|
model_resolution_table: dict[str, int]
|
||||||
normalization_required: list[str]
|
normalized_data: list[str]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -340,6 +341,24 @@ class MotorsBus(abc.ABC):
|
|||||||
else:
|
else:
|
||||||
raise TypeError(f"'{motor}' should be int, str.")
|
raise TypeError(f"'{motor}' should be int, str.")
|
||||||
|
|
||||||
|
def _get_names_list(self, motors: str | list[str] | None) -> list[str]:
|
||||||
|
if motors is None:
|
||||||
|
return self.names
|
||||||
|
elif isinstance(motors, str):
|
||||||
|
return [motors]
|
||||||
|
elif isinstance(motors, list):
|
||||||
|
return motors.copy()
|
||||||
|
else:
|
||||||
|
raise TypeError(motors)
|
||||||
|
|
||||||
|
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
|
||||||
|
if isinstance(values, (int, float)):
|
||||||
|
return {id_: values for id_ in self.ids}
|
||||||
|
elif isinstance(values, dict):
|
||||||
|
return {self.motors[motor].id: val for motor, val in values.items()}
|
||||||
|
else:
|
||||||
|
raise TypeError(f"'values' is expected to be a single value or a dict. Got {values}")
|
||||||
|
|
||||||
def _validate_motors(self) -> None:
|
def _validate_motors(self) -> None:
|
||||||
if len(self.ids) != len(set(self.ids)):
|
if len(self.ids) != len(set(self.ids)):
|
||||||
raise ValueError(f"Some motors have the same id!\n{self}")
|
raise ValueError(f"Some motors have the same id!\n{self}")
|
||||||
@@ -632,15 +651,11 @@ class MotorsBus(abc.ABC):
|
|||||||
return unnormalized_values
|
return unnormalized_values
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _encode_value(
|
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||||
self, value: int, data_name: str | None = None, n_bytes: int | None = None
|
|
||||||
) -> dict[int, int]:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _decode_value(
|
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||||
self, value: int, data_name: str | None = None, n_bytes: int | None = None
|
|
||||||
) -> dict[int, int]:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -691,7 +706,7 @@ class MotorsBus(abc.ABC):
|
|||||||
|
|
||||||
if not self._is_comm_success(comm):
|
if not self._is_comm_success(comm):
|
||||||
if raise_on_error:
|
if raise_on_error:
|
||||||
raise ConnectionError(self.packet_handler.getRxPacketError(comm))
|
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
if self._is_error(error):
|
if self._is_error(error):
|
||||||
@@ -708,87 +723,60 @@ class MotorsBus(abc.ABC):
|
|||||||
) -> dict[int, list[int, str]] | None:
|
) -> dict[int, list[int, str]] | None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload
|
|
||||||
def sync_read(
|
|
||||||
self, data_name: str, motors: None = ..., *, normalize: bool = ..., num_retry: int = ...
|
|
||||||
) -> dict[str, Value]: ...
|
|
||||||
@overload
|
|
||||||
def sync_read(
|
def sync_read(
|
||||||
self,
|
self,
|
||||||
data_name: str,
|
data_name: str,
|
||||||
motors: NameOrID | list[NameOrID],
|
motors: str | list[str] | None = None,
|
||||||
*,
|
|
||||||
normalize: bool = ...,
|
|
||||||
num_retry: int = ...,
|
|
||||||
) -> dict[NameOrID, Value]: ...
|
|
||||||
def sync_read(
|
|
||||||
self,
|
|
||||||
data_name: str,
|
|
||||||
motors: NameOrID | list[NameOrID] | None = None,
|
|
||||||
*,
|
*,
|
||||||
normalize: bool = True,
|
normalize: bool = True,
|
||||||
num_retry: int = 0,
|
num_retry: int = 0,
|
||||||
) -> dict[NameOrID, Value]:
|
) -> dict[str, Value]:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise DeviceNotConnectedError(
|
raise DeviceNotConnectedError(
|
||||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||||
)
|
)
|
||||||
|
|
||||||
id_key_map: dict[int, NameOrID] = {}
|
names = self._get_names_list(motors)
|
||||||
if motors is None:
|
ids = [self.motors[name].id for name in names]
|
||||||
id_key_map = {m.id: name for name, m in self.motors.items()}
|
models = [self.motors[name].model for name in names]
|
||||||
elif isinstance(motors, (str, int)):
|
|
||||||
id_key_map = {self._get_motor_id(motors): motors}
|
|
||||||
elif isinstance(motors, list):
|
|
||||||
id_key_map = {self._get_motor_id(m): m for m in motors}
|
|
||||||
else:
|
|
||||||
raise TypeError(motors)
|
|
||||||
|
|
||||||
motor_ids = list(id_key_map)
|
if self._has_different_ctrl_tables:
|
||||||
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
|
|
||||||
comm, ids_values = self._sync_read(data_name, motor_ids, num_retry=num_retry)
|
model = next(iter(models))
|
||||||
|
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
|
||||||
|
|
||||||
|
comm, ids_values = self._sync_read(addr, n_bytes, ids, num_retry=num_retry)
|
||||||
if not self._is_comm_success(comm):
|
if not self._is_comm_success(comm):
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"Failed to sync read '{data_name}' on {motor_ids=} after {num_retry + 1} tries."
|
f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
|
||||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if normalize and data_name in self.normalization_required:
|
ids_values = self._decode_sign(data_name, ids_values)
|
||||||
|
|
||||||
|
if normalize and data_name in self.normalized_data:
|
||||||
ids_values = self._normalize(data_name, ids_values)
|
ids_values = self._normalize(data_name, ids_values)
|
||||||
|
|
||||||
return {id_key_map[id_]: val for id_, val in ids_values.items()}
|
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
|
||||||
|
|
||||||
def _sync_read(
|
def _sync_read(
|
||||||
self, data_name: str, motor_ids: list[str], model: str | None = None, num_retry: int = 0
|
self, addr: int, n_bytes: int, motor_ids: list[int], num_retry: int = 0
|
||||||
) -> tuple[int, dict[int, int]]:
|
) -> tuple[int, dict[int, int]]:
|
||||||
if self._has_different_ctrl_tables:
|
|
||||||
models = [self._id_to_model(id_) for id_ in motor_ids]
|
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
|
||||||
|
|
||||||
model = self._id_to_model(next(iter(motor_ids))) if model is None else model
|
|
||||||
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
|
|
||||||
self._setup_sync_reader(motor_ids, addr, n_bytes)
|
self._setup_sync_reader(motor_ids, addr, n_bytes)
|
||||||
|
|
||||||
# FIXME(aliberts, pkooij): We should probably not have to do this.
|
|
||||||
# Let's try to see if we can do with better comm status handling instead.
|
|
||||||
# self.port_handler.ser.reset_output_buffer()
|
|
||||||
# self.port_handler.ser.reset_input_buffer()
|
|
||||||
|
|
||||||
for n_try in range(1 + num_retry):
|
for n_try in range(1 + num_retry):
|
||||||
comm = self.sync_reader.txRxPacket()
|
comm = self.sync_reader.txRxPacket()
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
logger.debug(f"Failed to sync read '{data_name}' ({addr=} {n_bytes=}) on {motor_ids=} ({n_try=})")
|
logger.debug(
|
||||||
logger.debug(self.packet_handler.getRxPacketError(comm))
|
f"Failed to sync read @{addr=} ({n_bytes=}) on {motor_ids=} ({n_try=}): "
|
||||||
|
+ self.packet_handler.getTxRxResult(comm)
|
||||||
values = {}
|
)
|
||||||
for id_ in motor_ids:
|
|
||||||
val = self.sync_reader.getData(id_, addr, n_bytes)
|
|
||||||
values[id_] = self._decode_value(val, data_name, n_bytes)
|
|
||||||
|
|
||||||
|
values = {id_: self.sync_reader.getData(id_, addr, n_bytes) for id_ in motor_ids}
|
||||||
return comm, values
|
return comm, values
|
||||||
|
|
||||||
def _setup_sync_reader(self, motor_ids: list[str], addr: int, n_bytes: int) -> None:
|
def _setup_sync_reader(self, motor_ids: list[int], addr: int, n_bytes: int) -> None:
|
||||||
self.sync_reader.clearParam()
|
self.sync_reader.clearParam()
|
||||||
self.sync_reader.start_address = addr
|
self.sync_reader.start_address = addr
|
||||||
self.sync_reader.data_length = n_bytes
|
self.sync_reader.data_length = n_bytes
|
||||||
@@ -799,7 +787,7 @@ class MotorsBus(abc.ABC):
|
|||||||
# Would have to handle the logic of checking if a packet has been sent previously though but doable.
|
# Would have to handle the logic of checking if a packet has been sent previously though but doable.
|
||||||
# This could be at the cost of increase latency between the moment the data is produced by the motors and
|
# This could be at the cost of increase latency between the moment the data is produced by the motors and
|
||||||
# the moment it is used by a policy.
|
# the moment it is used by a policy.
|
||||||
# def _async_read(self, motor_ids: list[str], address: int, n_bytes: int):
|
# def _async_read(self, motor_ids: list[int], address: int, n_bytes: int):
|
||||||
# if self.sync_reader.start_address != address or self.sync_reader.data_length != n_bytes or ...:
|
# if self.sync_reader.start_address != address or self.sync_reader.data_length != n_bytes or ...:
|
||||||
# self._setup_sync_reader(motor_ids, address, n_bytes)
|
# self._setup_sync_reader(motor_ids, address, n_bytes)
|
||||||
# else:
|
# else:
|
||||||
@@ -812,7 +800,7 @@ class MotorsBus(abc.ABC):
|
|||||||
def sync_write(
|
def sync_write(
|
||||||
self,
|
self,
|
||||||
data_name: str,
|
data_name: str,
|
||||||
values: Value | dict[NameOrID, Value],
|
values: Value | dict[str, Value],
|
||||||
*,
|
*,
|
||||||
normalize: bool = True,
|
normalize: bool = True,
|
||||||
num_retry: int = 0,
|
num_retry: int = 0,
|
||||||
@@ -822,41 +810,36 @@ class MotorsBus(abc.ABC):
|
|||||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(values, int):
|
ids_values = self._get_ids_values_dict(values)
|
||||||
ids_values = {id_: values for id_ in self.ids}
|
models = [self._id_to_model(id_) for id_ in ids_values]
|
||||||
elif isinstance(values, dict):
|
if self._has_different_ctrl_tables:
|
||||||
ids_values = {self._get_motor_id(motor): val for motor, val in values.items()}
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
else:
|
|
||||||
raise TypeError(f"'values' is expected to be a single value or a dict. Got {values}")
|
|
||||||
|
|
||||||
if normalize and data_name in self.normalization_required and self.calibration is not None:
|
model = next(iter(models))
|
||||||
|
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
|
||||||
|
|
||||||
|
if normalize and data_name in self.normalized_data:
|
||||||
ids_values = self._unnormalize(data_name, ids_values)
|
ids_values = self._unnormalize(data_name, ids_values)
|
||||||
|
|
||||||
comm = self._sync_write(data_name, ids_values, num_retry=num_retry)
|
ids_values = self._encode_sign(data_name, ids_values)
|
||||||
|
|
||||||
|
comm = self._sync_write(addr, n_bytes, ids_values, num_retry=num_retry)
|
||||||
if not self._is_comm_success(comm):
|
if not self._is_comm_success(comm):
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
||||||
f"\n{self.packet_handler.getTxRxResult(comm)}"
|
f"\n{self.packet_handler.getTxRxResult(comm)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _sync_write(self, data_name: str, ids_values: dict[int, int], num_retry: int = 0) -> int:
|
def _sync_write(self, addr: int, n_bytes: int, ids_values: dict[int, int], num_retry: int = 0) -> int:
|
||||||
if self._has_different_ctrl_tables:
|
|
||||||
models = [self._id_to_model(id_) for id_ in ids_values]
|
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
|
||||||
|
|
||||||
model = self._id_to_model(next(iter(ids_values)))
|
|
||||||
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
|
|
||||||
ids_values = {id_: self._encode_value(value, data_name, n_bytes) for id_, value in ids_values.items()}
|
|
||||||
self._setup_sync_writer(ids_values, addr, n_bytes)
|
self._setup_sync_writer(ids_values, addr, n_bytes)
|
||||||
|
|
||||||
for n_try in range(1 + num_retry):
|
for n_try in range(1 + num_retry):
|
||||||
comm = self.sync_writer.txPacket()
|
comm = self.sync_writer.txPacket()
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Failed to sync write '{data_name}' ({addr=} {n_bytes=}) with {ids_values=} ({n_try=})"
|
f"Failed to sync write @{addr=} ({n_bytes=}) with {ids_values=} ({n_try=}): "
|
||||||
|
+ self.packet_handler.getTxRxResult(comm)
|
||||||
)
|
)
|
||||||
logger.debug(self.packet_handler.getRxPacketError(comm))
|
|
||||||
|
|
||||||
return comm
|
return comm
|
||||||
|
|
||||||
@@ -869,20 +852,23 @@ class MotorsBus(abc.ABC):
|
|||||||
self.sync_writer.addParam(id_, data)
|
self.sync_writer.addParam(id_, data)
|
||||||
|
|
||||||
def write(
|
def write(
|
||||||
self, data_name: str, motor: NameOrID, value: Value, *, normalize: bool = True, num_retry: int = 0
|
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise DeviceNotConnectedError(
|
raise DeviceNotConnectedError(
|
||||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||||
)
|
)
|
||||||
|
|
||||||
id_ = self._get_motor_id(motor)
|
id_ = self.motors[motor].id
|
||||||
|
model = self.motors[motor].model
|
||||||
|
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
|
||||||
|
|
||||||
if normalize and data_name in self.normalization_required and self.calibration is not None:
|
if normalize and data_name in self.normalized_data:
|
||||||
id_value = self._unnormalize(data_name, {id_: value})
|
value = self._unnormalize(data_name, {id_: value})[id_]
|
||||||
value = id_value[id_]
|
|
||||||
|
|
||||||
comm, error = self._write(data_name, id_, value, num_retry=num_retry)
|
value = self._encode_sign(data_name, {id_: value})[id_]
|
||||||
|
|
||||||
|
comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry)
|
||||||
if not self._is_comm_success(comm):
|
if not self._is_comm_success(comm):
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
||||||
@@ -894,20 +880,18 @@ class MotorsBus(abc.ABC):
|
|||||||
f"\n{self.packet_handler.getRxPacketError(error)}"
|
f"\n{self.packet_handler.getRxPacketError(error)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _write(self, data_name: str, motor_id: int, value: int, num_retry: int = 0) -> tuple[int, int]:
|
def _write(
|
||||||
model = self._id_to_model(motor_id)
|
self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0
|
||||||
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
|
) -> tuple[int, int]:
|
||||||
value = self._encode_value(value, data_name, n_bytes)
|
|
||||||
data = self._split_int_to_bytes(value, n_bytes)
|
data = self._split_int_to_bytes(value, n_bytes)
|
||||||
|
|
||||||
for n_try in range(1 + num_retry):
|
for n_try in range(1 + num_retry):
|
||||||
comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data)
|
comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data)
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Failed to write '{data_name}' ({addr=} {n_bytes=}) on {motor_id=} with '{value}' ({n_try=})"
|
f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): "
|
||||||
|
+ self.packet_handler.getTxRxResult(comm)
|
||||||
)
|
)
|
||||||
logger.debug(self.packet_handler.getRxPacketError(comm))
|
|
||||||
|
|
||||||
return comm, error
|
return comm, error
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import scservo_sdk as scs
|
|||||||
import serial
|
import serial
|
||||||
from mock_serial import MockSerial
|
from mock_serial import MockSerial
|
||||||
|
|
||||||
from lerobot.common.motors.feetech import SCS_SERIES_CONTROL_TABLE, FeetechMotorsBus
|
from lerobot.common.motors.feetech import STS_SMS_SERIES_CONTROL_TABLE, FeetechMotorsBus
|
||||||
from lerobot.common.motors.feetech.feetech import patch_setPacketTimeout
|
from lerobot.common.motors.feetech.feetech import patch_setPacketTimeout
|
||||||
|
|
||||||
from .mock_serial_patch import WaitableStub
|
from .mock_serial_patch import WaitableStub
|
||||||
@@ -297,7 +297,7 @@ class MockMotors(MockSerial):
|
|||||||
instruction packets. It is meant to test MotorsBus classes.
|
instruction packets. It is meant to test MotorsBus classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ctrl_table = SCS_SERIES_CONTROL_TABLE
|
ctrl_table = STS_SMS_SERIES_CONTROL_TABLE
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -338,12 +338,6 @@ class MockMotors(MockSerial):
|
|||||||
def build_read_stub(
|
def build_read_stub(
|
||||||
self, data_name: str, scs_id: int, value: int | None = None, num_invalid_try: int = 0
|
self, data_name: str, scs_id: int, value: int | None = None, num_invalid_try: int = 0
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
|
||||||
'data_name' supported:
|
|
||||||
- Model_Number
|
|
||||||
"""
|
|
||||||
if data_name != "Model_Number":
|
|
||||||
raise NotImplementedError
|
|
||||||
address, length = self.ctrl_table[data_name]
|
address, length = self.ctrl_table[data_name]
|
||||||
read_request = MockInstructionPacket.read(scs_id, address, length)
|
read_request = MockInstructionPacket.read(scs_id, address, length)
|
||||||
return_packet = MockStatusPacket.read(scs_id, value, length)
|
return_packet = MockStatusPacket.read(scs_id, value, length)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import pytest
|
|||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
from lerobot.common.motors.feetech import MODEL_NUMBER, FeetechMotorsBus
|
from lerobot.common.motors.feetech import MODEL_NUMBER_TABLE, FeetechMotorsBus
|
||||||
from lerobot.common.utils.encoding_utils import encode_sign_magnitude
|
from lerobot.common.utils.encoding_utils import encode_sign_magnitude
|
||||||
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
|
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ def test_scan_port(mock_motors):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("id_", [1, 2, 3])
|
@pytest.mark.parametrize("id_", [1, 2, 3])
|
||||||
def test_ping(id_, mock_motors, dummy_motors):
|
def test_ping(id_, mock_motors, dummy_motors):
|
||||||
expected_model_nb = MODEL_NUMBER[dummy_motors[f"dummy_{id_}"].model]
|
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
|
||||||
ping_stub = mock_motors.build_ping_stub(id_)
|
ping_stub = mock_motors.build_ping_stub(id_)
|
||||||
mobel_nb_stub = mock_motors.build_read_stub("Model_Number", id_, expected_model_nb)
|
mobel_nb_stub = mock_motors.build_read_stub("Model_Number", id_, expected_model_nb)
|
||||||
motors_bus = FeetechMotorsBus(
|
motors_bus = FeetechMotorsBus(
|
||||||
@@ -147,7 +147,7 @@ def test_ping(id_, mock_motors, dummy_motors):
|
|||||||
|
|
||||||
def test_broadcast_ping(mock_motors, dummy_motors):
|
def test_broadcast_ping(mock_motors, dummy_motors):
|
||||||
models = {m.id: m.model for m in dummy_motors.values()}
|
models = {m.id: m.model for m in dummy_motors.values()}
|
||||||
expected_model_nbs = {id_: MODEL_NUMBER[model] for id_, model in models.items()}
|
expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
|
||||||
ping_stub = mock_motors.build_broadcast_ping_stub(list(models))
|
ping_stub = mock_motors.build_broadcast_ping_stub(list(models))
|
||||||
mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs)
|
mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs)
|
||||||
motors_bus = FeetechMotorsBus(
|
motors_bus = FeetechMotorsBus(
|
||||||
@@ -191,55 +191,7 @@ def test_sync_read_none(mock_motors, dummy_motors):
|
|||||||
(3, 4016),
|
(3, 4016),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sync_read_by_id(id_, position, mock_motors, dummy_motors):
|
def test_sync_read_single_value(id_, position, mock_motors, dummy_motors):
|
||||||
expected_position = {id_: position}
|
|
||||||
stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_position)
|
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
read_position = motors_bus.sync_read("Present_Position", id_, normalize=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].called
|
|
||||||
assert read_position == expected_position
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"ids, positions",
|
|
||||||
[
|
|
||||||
([1], [1337]),
|
|
||||||
([1, 2], [1337, 42]),
|
|
||||||
([1, 2, 3], [1337, 42, 4016]),
|
|
||||||
],
|
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
|
||||||
) # fmt: skip
|
|
||||||
def test_sync_read_by_ids(ids, positions, mock_motors, dummy_motors):
|
|
||||||
assert len(ids) == len(positions)
|
|
||||||
expected_positions = dict(zip(ids, positions, strict=True))
|
|
||||||
stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_positions)
|
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
read_positions = motors_bus.sync_read("Present_Position", ids, normalize=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].called
|
|
||||||
assert read_positions == expected_positions
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"id_, position",
|
|
||||||
[
|
|
||||||
(1, 1337),
|
|
||||||
(2, 42),
|
|
||||||
(3, 4016),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sync_read_by_name(id_, position, mock_motors, dummy_motors):
|
|
||||||
expected_position = {f"dummy_{id_}": position}
|
expected_position = {f"dummy_{id_}": position}
|
||||||
stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position})
|
stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position})
|
||||||
motors_bus = FeetechMotorsBus(
|
motors_bus = FeetechMotorsBus(
|
||||||
@@ -263,7 +215,7 @@ def test_sync_read_by_name(id_, position, mock_motors, dummy_motors):
|
|||||||
],
|
],
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
ids=["1 motor", "2 motors", "3 motors"],
|
||||||
) # fmt: skip
|
) # fmt: skip
|
||||||
def test_sync_read_by_names(ids, positions, mock_motors, dummy_motors):
|
def test_sync_read(ids, positions, mock_motors, dummy_motors):
|
||||||
assert len(ids) == len(positions)
|
assert len(ids) == len(positions)
|
||||||
names = [f"dummy_{dxl_id}" for dxl_id in ids]
|
names = [f"dummy_{dxl_id}" for dxl_id in ids]
|
||||||
expected_positions = dict(zip(names, positions, strict=True))
|
expected_positions = dict(zip(names, positions, strict=True))
|
||||||
@@ -291,9 +243,9 @@ def test_sync_read_by_names(ids, positions, mock_motors, dummy_motors):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors):
|
def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors):
|
||||||
expected_position = {1: pos}
|
expected_position = {"dummy_1": pos}
|
||||||
stub_name = mock_motors.build_sync_read_stub(
|
stub_name = mock_motors.build_sync_read_stub(
|
||||||
"Present_Position", expected_position, num_invalid_try=num_invalid_try
|
"Present_Position", {1: pos}, num_invalid_try=num_invalid_try
|
||||||
)
|
)
|
||||||
motors_bus = FeetechMotorsBus(
|
motors_bus = FeetechMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
@@ -302,11 +254,11 @@ def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy
|
|||||||
motors_bus.connect(assert_motors_exist=False)
|
motors_bus.connect(assert_motors_exist=False)
|
||||||
|
|
||||||
if num_retry >= num_invalid_try:
|
if num_retry >= num_invalid_try:
|
||||||
pos_dict = motors_bus.sync_read("Present_Position", 1, normalize=False, num_retry=num_retry)
|
pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
|
||||||
assert pos_dict == {1: pos}
|
assert pos_dict == expected_position
|
||||||
else:
|
else:
|
||||||
with pytest.raises(ConnectionError):
|
with pytest.raises(ConnectionError):
|
||||||
_ = motors_bus.sync_read("Present_Position", 1, normalize=False, num_retry=num_retry)
|
_ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
|
||||||
|
|
||||||
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
|
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
|
||||||
assert mock_motors.stubs[stub_name].calls == expected_calls
|
assert mock_motors.stubs[stub_name].calls == expected_calls
|
||||||
@@ -335,28 +287,6 @@ def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors):
|
|||||||
assert mock_motors.stubs[stub_name].wait_called()
|
assert mock_motors.stubs[stub_name].wait_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"id_, position",
|
|
||||||
[
|
|
||||||
(1, 1337),
|
|
||||||
(2, 42),
|
|
||||||
(3, 4016),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sync_write_by_id(id_, position, mock_motors, dummy_motors):
|
|
||||||
value = {id_: position}
|
|
||||||
stub_name = mock_motors.build_sync_write_stub("Goal_Position", value)
|
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
motors_bus.sync_write("Goal_Position", value, normalize=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].wait_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ids, positions",
|
"ids, positions",
|
||||||
[
|
[
|
||||||
@@ -366,54 +296,7 @@ def test_sync_write_by_id(id_, position, mock_motors, dummy_motors):
|
|||||||
],
|
],
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
ids=["1 motor", "2 motors", "3 motors"],
|
||||||
) # fmt: skip
|
) # fmt: skip
|
||||||
def test_sync_write_by_ids(ids, positions, mock_motors, dummy_motors):
|
def test_sync_write(ids, positions, mock_motors, dummy_motors):
|
||||||
assert len(ids) == len(positions)
|
|
||||||
values = dict(zip(ids, positions, strict=True))
|
|
||||||
stub_name = mock_motors.build_sync_write_stub("Goal_Position", values)
|
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
motors_bus.sync_write("Goal_Position", values, normalize=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].wait_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"id_, position",
|
|
||||||
[
|
|
||||||
(1, 1337),
|
|
||||||
(2, 42),
|
|
||||||
(3, 4016),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sync_write_by_name(id_, position, mock_motors, dummy_motors):
|
|
||||||
id_value = {id_: position}
|
|
||||||
stub_name = mock_motors.build_sync_write_stub("Goal_Position", id_value)
|
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
write_value = {f"dummy_{id_}": position}
|
|
||||||
motors_bus.sync_write("Goal_Position", write_value, normalize=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].wait_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"ids, positions",
|
|
||||||
[
|
|
||||||
([1], [1337]),
|
|
||||||
([1, 2], [1337, 42]),
|
|
||||||
([1, 2, 3], [1337, 42, 4016]),
|
|
||||||
],
|
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
|
||||||
) # fmt: skip
|
|
||||||
def test_sync_write_by_names(ids, positions, mock_motors, dummy_motors):
|
|
||||||
assert len(ids) == len(positions)
|
assert len(ids) == len(positions)
|
||||||
ids_values = dict(zip(ids, positions, strict=True))
|
ids_values = dict(zip(ids, positions, strict=True))
|
||||||
stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values)
|
stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values)
|
||||||
@@ -438,29 +321,7 @@ def test_sync_write_by_names(ids, positions, mock_motors, dummy_motors):
|
|||||||
("Goal_Position", 3, 42),
|
("Goal_Position", 3, 42),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_write_by_id(data_name, dxl_id, value, mock_motors, dummy_motors):
|
def test_write(data_name, dxl_id, value, mock_motors, dummy_motors):
|
||||||
stub_name = mock_motors.build_write_stub(data_name, dxl_id, value)
|
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
motors_bus.write(data_name, dxl_id, value, normalize=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"data_name, dxl_id, value",
|
|
||||||
[
|
|
||||||
("Torque_Enable", 1, 0),
|
|
||||||
("Torque_Enable", 1, 1),
|
|
||||||
("Goal_Position", 2, 1337),
|
|
||||||
("Goal_Position", 3, 42),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_write_by_name(data_name, dxl_id, value, mock_motors, dummy_motors):
|
|
||||||
stub_name = mock_motors.build_write_stub(data_name, dxl_id, value)
|
stub_name = mock_motors.build_write_stub(data_name, dxl_id, value)
|
||||||
motors_bus = FeetechMotorsBus(
|
motors_bus = FeetechMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
|
|||||||
Reference in New Issue
Block a user