Add more segmented tests (base motor bus & feetech), add feetech protocol 1 support

This commit is contained in:
Simon Alibert
2025-04-14 11:56:53 +02:00
parent e0b292ab51
commit bdbca09cb2
6 changed files with 749 additions and 350 deletions

View File

@@ -10,27 +10,6 @@ from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch
from .mock_serial_patch import WaitableStub
# https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf
INSTRUCTION_TYPES = {
"Read": scs.INST_PING, # Read data from the Device
"Ping": scs.INST_READ, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID
"Write": scs.INST_WRITE, # Write data to the Device
"Reg_Write": scs.INST_REG_WRITE, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command
"Action": scs.INST_ACTION, # Executes a Packet that was registered beforehand using Reg Write
"Factory_Reset": 0x06, # Resets the Control Table to its initial factory default settings
"Sync_Write": scs.INST_SYNC_WRITE, # Write data to multiple devices with the same Address with the same length at once
"Sync_Read": scs.INST_SYNC_READ, # Read data from multiple devices with the same Address with the same length at once
} # fmt: skip
ERROR_TYPE = {
"Success": 0x00,
"Voltage": scs.ERRBIT_VOLTAGE,
"Angle": scs.ERRBIT_ANGLE,
"Overheat": scs.ERRBIT_OVERHEAT,
"Overele": scs.ERRBIT_OVERELE,
"Overload": scs.ERRBIT_OVERLOAD,
}
class MockFeetechPacket(abc.ABC):
@classmethod
@@ -68,15 +47,14 @@ class MockInstructionPacket(MockFeetechPacket):
"""
@classmethod
def _build(cls, scs_id: int, params: list[int], length: int, instruct_type: str) -> list[int]:
instruct_value = INSTRUCTION_TYPES[instruct_type]
def _build(cls, scs_id: int, params: list[int], length: int, instruction: int) -> list[int]:
return [
0xFF, 0xFF, # header
scs_id, # servo id
length, # length
instruct_value, # instruction type
*params, # data bytes
0x00, # placeholder for checksum
0xFF, 0xFF, # header
scs_id, # servo id
length, # length
instruction, # instruction type
*params, # data bytes
0x00, # placeholder for checksum
] # fmt: skip
@classmethod
@@ -89,7 +67,7 @@ class MockInstructionPacket(MockFeetechPacket):
No parameters required.
"""
return cls.build(scs_id=scs_id, params=[], length=2, instruct_type="Ping")
return cls.build(scs_id=scs_id, params=[], length=2, instruction=scs.INST_PING)
@classmethod
def read(
@@ -113,7 +91,7 @@ class MockInstructionPacket(MockFeetechPacket):
"""
params = [start_address, data_length]
length = 4
return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Read")
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_READ)
@classmethod
def write(
@@ -142,7 +120,7 @@ class MockInstructionPacket(MockFeetechPacket):
data = _split_into_byte_chunks(value, data_length)
params = [start_address, *data]
length = data_length + 3
return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write")
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_WRITE)
@classmethod
def sync_read(
@@ -167,7 +145,9 @@ class MockInstructionPacket(MockFeetechPacket):
"""
params = [start_address, data_length, *scs_ids]
length = len(scs_ids) + 4
return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read")
return cls.build(
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_READ
)
@classmethod
def sync_write(
@@ -205,7 +185,9 @@ class MockInstructionPacket(MockFeetechPacket):
data += [id_, *split_value]
params = [start_address, data_length, *data]
length = len(ids_values) * (1 + data_length) + 4
return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write")
return cls.build(
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_WRITE
)
class MockStatusPacket(MockFeetechPacket):
@@ -222,19 +204,18 @@ class MockStatusPacket(MockFeetechPacket):
"""
@classmethod
def _build(cls, scs_id: int, params: list[int], length: int, error: str = "Success") -> list[int]:
err_byte = ERROR_TYPE[error]
def _build(cls, scs_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
return [
0xFF, 0xFF, # header
scs_id, # servo id
length, # length
err_byte, # status
error, # status
*params, # data bytes
0x00, # placeholder for checksum
] # fmt: skip
@classmethod
def ping(cls, scs_id: int, error: str = "Success") -> bytes:
def ping(cls, scs_id: int, error: int = 0) -> bytes:
"""Builds a 'Ping' status packet.
Args:
@@ -247,7 +228,7 @@ class MockStatusPacket(MockFeetechPacket):
return cls.build(scs_id, params=[], length=2, error=error)
@classmethod
def read(cls, scs_id: int, value: int, param_length: int) -> bytes:
def read(cls, scs_id: int, value: int, param_length: int, error: int = 0) -> bytes:
"""Builds a 'Read' status packet.
Args:
@@ -260,7 +241,7 @@ class MockStatusPacket(MockFeetechPacket):
"""
params = _split_into_byte_chunks(value, param_length)
length = param_length + 2
return cls.build(scs_id, params=params, length=length)
return cls.build(scs_id, params=params, length=length, error=error)
class MockPortHandler(scs.PortHandler):
@@ -323,11 +304,11 @@ class MockMotors(MockSerial):
)
return stub_name
def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0) -> str:
def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0, error: int = 0) -> str:
ping_request = MockInstructionPacket.ping(scs_id)
return_packet = MockStatusPacket.ping(scs_id)
return_packet = MockStatusPacket.ping(scs_id, error)
ping_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Ping_{scs_id}"
stub_name = f"Ping_{scs_id}_{error}"
self.stub(
name=stub_name,
receive_bytes=ping_request,
@@ -336,13 +317,19 @@ class MockMotors(MockSerial):
return stub_name
def build_read_stub(
self, data_name: str, scs_id: int, value: int | None = None, num_invalid_try: int = 0
self,
address: int,
length: int,
scs_id: int,
value: int,
reply: bool = True,
error: int = 0,
num_invalid_try: int = 0,
) -> str:
address, length = self.ctrl_table[data_name]
read_request = MockInstructionPacket.read(scs_id, address, length)
return_packet = MockStatusPacket.read(scs_id, value, length)
return_packet = MockStatusPacket.read(scs_id, value, length, error) if reply else b""
read_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Read_{data_name}_{scs_id}"
stub_name = f"Read_{address}_{length}_{scs_id}_{value}_{error}"
self.stub(
name=stub_name,
receive_bytes=read_request,
@@ -350,15 +337,42 @@ class MockMotors(MockSerial):
)
return stub_name
def build_sync_read_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
def build_write_stub(
self,
address: int,
length: int,
scs_id: int,
value: int,
reply: bool = True,
error: int = 0,
num_invalid_try: int = 0,
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
return_packets = b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
sync_read_request = MockInstructionPacket.write(scs_id, value, address, length)
return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) if reply else b""
stub_name = f"Write_{address}_{length}_{scs_id}"
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=self._build_send_fn(return_packet, num_invalid_try),
)
return stub_name
def build_sync_read_stub(
self,
address: int,
length: int,
ids_values: dict[int, int],
reply: bool = True,
num_invalid_try: int = 0,
) -> str:
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
return_packets = (
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
if reply
else b""
)
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
stub_name = f"Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
@@ -367,11 +381,10 @@ class MockMotors(MockSerial):
return stub_name
def build_sequential_sync_read_stub(
self, data_name: str, ids_values: dict[int, list[int]] | None = None
self, address: int, length: int, ids_values: dict[int, list[int]] | None = None
) -> str:
sequence_length = len(next(iter(ids_values.values())))
assert all(len(positions) == sequence_length for positions in ids_values.values())
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
sequential_packets = []
for count in range(sequence_length):
@@ -381,7 +394,7 @@ class MockMotors(MockSerial):
sequential_packets.append(return_packets)
sync_read_response = self._build_sequential_send_fn(sequential_packets)
stub_name = f"Seq_Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
@@ -390,11 +403,10 @@ class MockMotors(MockSerial):
return stub_name
def build_sync_write_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
stub_name = f"Sync_Write_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
@@ -402,20 +414,6 @@ class MockMotors(MockSerial):
)
return stub_name
def build_write_stub(
self, data_name: str, scs_id: int, value: int, error: str = "Success", num_invalid_try: int = 0
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.write(scs_id, value, address, length)
return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error)
stub_name = f"Write_{data_name}_{scs_id}"
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=self._build_send_fn(return_packet, num_invalid_try),
)
return stub_name
@staticmethod
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes:

View File

@@ -1,3 +1,4 @@
import re
import sys
from typing import Generator
from unittest.mock import MagicMock, patch
@@ -6,7 +7,8 @@ import pytest
import scservo_sdk as scs
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.common.motors.feetech import MODEL_NUMBER_TABLE, FeetechMotorsBus
from lerobot.common.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus
from lerobot.common.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE
from lerobot.common.utils.encoding_utils import encode_sign_magnitude
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
@@ -109,8 +111,9 @@ def test_scan_port(mock_motors):
@pytest.mark.parametrize("id_", [1, 2, 3])
def test_ping(id_, mock_motors, dummy_motors):
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
addr, length = MODEL_NUMBER
ping_stub = mock_motors.build_ping_stub(id_)
mobel_nb_stub = mock_motors.build_read_stub("Model_Number", id_, expected_model_nb)
mobel_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
@@ -126,9 +129,15 @@ def test_ping(id_, mock_motors, dummy_motors):
def test_broadcast_ping(mock_motors, dummy_motors):
models = {m.id: m.model for m in dummy_motors.values()}
expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
addr, length = MODEL_NUMBER
ping_stub = mock_motors.build_broadcast_ping_stub(list(models))
mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs)
mobel_nb_stubs = []
expected_model_nbs = {}
for id_, model in models.items():
model_nb = MODEL_NUMBER_TABLE[model]
stub = mock_motors.build_read_stub(addr, length, id_, model_nb)
expected_model_nbs[id_] = model_nb
mobel_nb_stubs.append(stub)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
@@ -139,187 +148,209 @@ def test_broadcast_ping(mock_motors, dummy_motors):
assert ping_model_nbs == expected_model_nbs
assert mock_motors.stubs[ping_stub].called
assert mock_motors.stubs[mobel_nb_stub].called
def test_sync_read_none(mock_motors, dummy_motors):
expected_positions = {
"dummy_1": 1337,
"dummy_2": 42,
"dummy_3": 4016,
}
ids_values = dict(zip([1, 2, 3], expected_positions.values(), strict=True))
stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
read_positions = motors_bus.sync_read("Present_Position", normalize=False)
assert mock_motors.stubs[stub_name].called
assert read_positions == expected_positions
assert all(mock_motors.stubs[stub].called for stub in mobel_nb_stubs)
@pytest.mark.parametrize(
"id_, position",
"addr, length, id_, value",
[
(1, 1337),
(2, 42),
(3, 4016),
(0, 1, 1, 2),
(10, 2, 2, 999),
(42, 4, 3, 1337),
],
)
def test_sync_read_single_value(id_, position, mock_motors, dummy_motors):
expected_position = {f"dummy_{id_}": position}
stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position})
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
stub_name = mock_motors.build_read_stub(addr, length, id_, value)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
read_position = motors_bus.sync_read("Present_Position", f"dummy_{id_}", normalize=False)
read_value, _, _ = motors_bus._read(addr, length, id_)
assert mock_motors.stubs[stub_name].called
assert read_position == expected_position
assert read_value == value
@pytest.mark.parametrize(
"ids, positions",
[
([1], [1337]),
([1, 2], [1337, 42]),
([1, 2, 3], [1337, 42, 4016]),
],
ids=["1 motor", "2 motors", "3 motors"],
) # fmt: skip
def test_sync_read(ids, positions, mock_motors, dummy_motors):
assert len(ids) == len(positions)
names = [f"dummy_{dxl_id}" for dxl_id in ids]
expected_positions = dict(zip(names, positions, strict=True))
ids_values = dict(zip(ids, positions, strict=True))
stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values)
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__read_error(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
stub_name = mock_motors.build_read_stub(addr, length, id_, value, error=error)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
read_positions = motors_bus.sync_read("Present_Position", names, normalize=False)
assert mock_motors.stubs[stub_name].called
assert read_positions == expected_positions
@pytest.mark.parametrize(
"num_retry, num_invalid_try, pos",
[
(0, 2, 1337),
(2, 3, 42),
(3, 2, 4016),
(2, 1, 999),
],
)
def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors):
expected_position = {"dummy_1": pos}
stub_name = mock_motors.build_sync_read_stub(
"Present_Position", {1: pos}, num_invalid_try=num_invalid_try
)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
if num_retry >= num_invalid_try:
pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
assert pos_dict == expected_position
if raise_on_error:
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
motors_bus._read(addr, length, id_, raise_on_error=raise_on_error)
else:
with pytest.raises(ConnectionError):
_ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
assert mock_motors.stubs[stub_name].calls == expected_calls
@pytest.mark.parametrize(
"data_name, value",
[
("Torque_Enable", 0),
("Torque_Enable", 1),
("Goal_Position", 1337),
("Goal_Position", 42),
],
)
def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors):
ids_values = {m.id: value for m in dummy_motors.values()}
stub_name = mock_motors.build_sync_write_stub(data_name, ids_values)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
motors_bus.sync_write(data_name, value, normalize=False)
assert mock_motors.stubs[stub_name].wait_called()
@pytest.mark.parametrize(
"ids, positions",
[
([1], [1337]),
([1, 2], [1337, 42]),
([1, 2, 3], [1337, 42, 4016]),
],
ids=["1 motor", "2 motors", "3 motors"],
) # fmt: skip
def test_sync_write(ids, positions, mock_motors, dummy_motors):
assert len(ids) == len(positions)
ids_values = dict(zip(ids, positions, strict=True))
stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
write_values = {f"dummy_{id_}": pos for id_, pos in ids_values.items()}
motors_bus.sync_write("Goal_Position", write_values, normalize=False)
assert mock_motors.stubs[stub_name].wait_called()
@pytest.mark.parametrize(
"data_name, dxl_id, value",
[
("Torque_Enable", 1, 0),
("Torque_Enable", 1, 1),
("Goal_Position", 2, 1337),
("Goal_Position", 3, 42),
],
)
def test_write(data_name, dxl_id, value, mock_motors, dummy_motors):
stub_name = mock_motors.build_write_stub(data_name, dxl_id, value)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
motors_bus.write(data_name, f"dummy_{dxl_id}", value, normalize=False)
_, _, read_error = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error)
assert read_error == error
assert mock_motors.stubs[stub_name].called
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value = (10, 4, 1, 1337)
stub_name = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
motors_bus._read(addr, length, id_, raise_on_error=raise_on_error)
else:
_, read_comm, _ = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error)
assert read_comm == scs.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub_name].called
@pytest.mark.parametrize(
"addr, length, id_, value",
[
(0, 1, 1, 2),
(10, 2, 2, 999),
(42, 4, 3, 1337),
],
)
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
stub_name = mock_motors.build_write_stub(addr, length, id_, value)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
comm, error = motors_bus._write(addr, length, id_, value)
assert mock_motors.stubs[stub_name].called
assert comm == scs.COMM_SUCCESS
assert error == 0
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__write_error(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
stub_name = mock_motors.build_write_stub(addr, length, id_, value, error=error)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
if raise_on_error:
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
else:
_, write_error = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
assert write_error == error
assert mock_motors.stubs[stub_name].called
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value = (10, 4, 1, 1337)
stub_name = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
else:
write_comm, _ = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
assert write_comm == scs.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub_name].called
@pytest.mark.parametrize(
"addr, length, ids_values",
[
(0, 1, {1: 4}),
(10, 2, {1: 1337, 2: 42}),
(42, 4, {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
stub_name = mock_motors.build_sync_read_stub(addr, length, ids_values)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
read_values, _ = motors_bus._sync_read(addr, length, list(ids_values))
assert mock_motors.stubs[stub_name].called
assert read_values == ids_values
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, ids_values = (10, 4, {1: 1337})
stub_name = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
else:
_, read_comm = motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
assert read_comm == scs.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub_name].called
@pytest.mark.parametrize(
"addr, length, ids_values",
[
(0, 1, {1: 4}),
(10, 2, {1: 1337, 2: 42}),
(42, 4, {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
stub_name = mock_motors.build_sync_write_stub(addr, length, ids_values)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
comm = motors_bus._sync_write(addr, length, ids_values)
assert mock_motors.stubs[stub_name].wait_called()
assert comm == scs.COMM_SUCCESS
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
encoded_homings = {m.id: encode_sign_magnitude(m.homing_offset, 11) for m in dummy_calibration.values()}
mins = {m.id: m.range_min for m in dummy_calibration.values()}
maxes = {m.id: m.range_max for m in dummy_calibration.values()}
offsets_stub = mock_motors.build_sync_read_stub("Homing_Offset", encoded_homings)
mins_stub = mock_motors.build_sync_read_stub("Min_Position_Limit", mins)
maxes_stub = mock_motors.build_sync_read_stub("Max_Position_Limit", maxes)
offsets_stub = mock_motors.build_sync_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings
)
mins_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins)
maxes_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
@@ -340,9 +371,15 @@ def test_reset_calibration(mock_motors, dummy_motors):
write_mins_stubs = []
write_maxes_stubs = []
for motor in dummy_motors.values():
write_homing_stubs.append(mock_motors.build_write_stub("Homing_Offset", motor.id, 0))
write_mins_stubs.append(mock_motors.build_write_stub("Min_Position_Limit", motor.id, 0))
write_maxes_stubs.append(mock_motors.build_write_stub("Max_Position_Limit", motor.id, 4095))
write_homing_stubs.append(
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
)
write_mins_stubs.append(
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
)
write_maxes_stubs.append(
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
@@ -372,11 +409,15 @@ def test_set_half_turn_homings(mock_motors, dummy_motors):
2: -2005, # 42 - 2047
3: 1625, # 3672 - 2047
}
read_pos_stub = mock_motors.build_sync_read_stub("Present_Position", current_positions)
read_pos_stub = mock_motors.build_sync_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], current_positions
)
write_homing_stubs = []
for id_, homing in expected_homings.items():
encoded_homing = encode_sign_magnitude(homing, 11)
stub = mock_motors.build_write_stub("Homing_Offset", id_, encoded_homing)
stub = mock_motors.build_write_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing
)
write_homing_stubs.append(stub)
motors_bus = FeetechMotorsBus(
@@ -409,7 +450,9 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors):
"dummy_2": 3600,
"dummy_3": 4002,
}
read_pos_stub = mock_motors.build_sequential_sync_read_stub("Present_Position", positions)
read_pos_stub = mock_motors.build_sequential_sync_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions
)
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
motors_bus = FeetechMotorsBus(
port=mock_motors.port,

View File

@@ -1,9 +1,13 @@
# ruff: noqa: N802
import re
from unittest.mock import patch
import pytest
from lerobot.common.motors.motors_bus import (
Motor,
MotorNormMode,
MotorsBus,
assert_same_address,
get_address,
@@ -14,30 +18,35 @@ DUMMY_CTRL_TABLE_1 = {
"Firmware_Version": (0, 1),
"Model_Number": (1, 2),
"Present_Position": (3, 4),
"Goal_Position": (7, 2),
"Goal_Position": (11, 2),
}
DUMMY_CTRL_TABLE_2 = {
"Model_Number": (0, 2),
"Firmware_Version": (2, 1),
"Present_Position": (3, 4),
"Goal_Position": (7, 4),
"Lock": (7, 4),
"Present_Velocity": (7, 4),
"Goal_Position": (11, 4),
"Goal_Velocity": (15, 4),
"Lock": (19, 1),
}
DUMMY_MODEL_CTRL_TABLE = {
"model_1": DUMMY_CTRL_TABLE_1,
"model_2": DUMMY_CTRL_TABLE_2,
"model_3": DUMMY_CTRL_TABLE_2,
}
DUMMY_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
}
DUMMY_MODEL_BAUDRATE_TABLE = {
"model_1": DUMMY_BAUDRATE_TABLE,
"model_2": DUMMY_BAUDRATE_TABLE,
"model_3": DUMMY_BAUDRATE_TABLE,
}
DUMMY_ENCODING_TABLE = {
@@ -48,21 +57,78 @@ DUMMY_ENCODING_TABLE = {
DUMMY_MODEL_ENCODING_TABLE = {
"model_1": DUMMY_ENCODING_TABLE,
"model_2": DUMMY_ENCODING_TABLE,
"model_3": DUMMY_ENCODING_TABLE,
}
DUMMY_MODEL_NUMBER_TABLE = {
"model_1": 1234,
"model_2": 5678,
"model_3": 5799,
}
DUMMY_MODEL_RESOLUTION_TABLE = {
"model_1": 4096,
"model_2": 1024,
"model_3": 4096,
}
class DummyMotorsBus(MotorsBus):
class MockPortHandler:
def __init__(self, port_name):
self.is_open: bool = False
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool = False
self.port_name: str = port_name
self.ser = None
def openPort(self):
self.is_open = True
return self.is_open
def closePort(self):
self.is_open = False
def clearPort(self): ...
def setPortName(self, port_name):
self.port_name = port_name
def getPortName(self):
return self.port_name
def setBaudRate(self, baudrate):
self.baudrate: baudrate
def getBaudRate(self):
return self.baudrate
def getBytesAvailable(self): ...
def readPort(self, length): ...
def writePort(self, packet): ...
def setPacketTimeout(self, packet_length): ...
def setPacketTimeoutMillis(self, msec): ...
def isPacketTimeout(self): ...
def getCurrentTime(self): ...
def getTimeSinceStart(self): ...
def setupPort(self, cflag_baud): ...
def getCFlagBaud(self, baudrate): ...
class MockMotorsBus(MotorsBus):
available_baudrates = [500_000, 1_000_000]
default_timeout = 1000
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
model_number_table = {"model_1": 1234, "model_2": 5678}
model_resolution_table = {"model_1": 4096, "model_2": 1024}
model_number_table = DUMMY_MODEL_NUMBER_TABLE
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
normalized_data = ["Present_Position", "Goal_Position"]
def __init__(self, port: str, motors: dict[str, Motor]):
super().__init__(port, motors)
self.port_handler = MockPortHandler(port)
def _assert_protocol_is_compatible(self, instruction_name): ...
def configure_motors(self): ...
@@ -75,6 +141,15 @@ class DummyMotorsBus(MotorsBus):
def broadcast_ping(self, num_retry, raise_on_error): ...
@pytest.fixture
def dummy_motors() -> dict[str, Motor]:
return {
"dummy_1": Motor(1, "model_2", MotorNormMode.RANGE_M100_100),
"dummy_2": Motor(2, "model_3", MotorNormMode.RANGE_M100_100),
"dummy_3": Motor(3, "model_2", MotorNormMode.RANGE_0_100),
}
def test_get_ctrl_table():
model = "model_1"
ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
@@ -105,7 +180,7 @@ def test_assert_same_address():
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position")
def test_assert_same_address_different_addresses():
def test_assert_same_length_different_addresses():
models = ["model_1", "model_2"]
with pytest.raises(
NotImplementedError,
@@ -114,7 +189,7 @@ def test_assert_same_address_different_addresses():
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number")
def test_assert_same_address_different_bytes():
def test_assert_same_address_different_length():
models = ["model_1", "model_2"]
with pytest.raises(
NotImplementedError,
@@ -124,18 +199,267 @@ def test_assert_same_address_different_bytes():
def test__serialize_data_invalid_length():
bus = DummyMotorsBus("", {})
bus = MockMotorsBus("", {})
with pytest.raises(NotImplementedError):
bus._serialize_data(100, 3)
def test__serialize_data_negative_numbers():
bus = DummyMotorsBus("", {})
bus = MockMotorsBus("", {})
with pytest.raises(ValueError):
bus._serialize_data(-1, 1)
def test__serialize_data_large_number():
bus = DummyMotorsBus("", {})
bus = MockMotorsBus("", {})
with pytest.raises(ValueError):
bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF
@pytest.mark.parametrize(
"data_name, id_, value",
[
("Firmware_Version", 1, 14),
("Model_Number", 1, 5678),
("Present_Position", 2, 1337),
("Present_Velocity", 3, 42),
],
)
def test_read(data_name, id_, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(assert_motors_exist=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
with (
patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0)) as mock__read,
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
):
returned_value = bus.read(data_name, f"dummy_{id_}")
assert returned_value == value
mock__read.assert_called_once_with(
addr,
length,
id_,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to read '{data_name}' on {id_=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with(data_name, {id_: value})
@pytest.mark.parametrize(
"data_name, id_, value",
[
("Goal_Position", 1, 1337),
("Goal_Velocity", 2, 3682),
("Lock", 3, 1),
],
)
def test_write(data_name, id_, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(assert_motors_exist=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
with (
patch.object(MockMotorsBus, "_write", return_value=(0, 0)) as mock__write,
patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value}) as mock__encode_sign,
patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value}) as mock__unnormalize,
):
bus.write(data_name, f"dummy_{id_}", value)
mock__write.assert_called_once_with(
addr,
length,
id_,
value,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries.",
)
mock__encode_sign.assert_called_once_with(data_name, {id_: value})
if data_name in bus.normalized_data:
mock__unnormalize.assert_called_once_with(data_name, {id_: value})
@pytest.mark.parametrize(
"data_name, id_, value",
[
("Firmware_Version", 1, 14),
("Model_Number", 1, 5678),
("Present_Position", 2, 1337),
("Present_Velocity", 3, 42),
],
)
def test_sync_read_by_str(data_name, id_, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(assert_motors_exist=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids = [id_]
expected_value = {f"dummy_{id_}": value}
with (
patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0)) as mock__sync_read,
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
):
returned_dict = bus.sync_read(data_name, f"dummy_{id_}")
assert returned_dict == expected_value
mock__sync_read.assert_called_once_with(
addr,
length,
ids,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with(data_name, {id_: value})
@pytest.mark.parametrize(
"data_name, ids_values",
[
("Model_Number", {1: 5678}),
("Present_Position", {1: 1337, 2: 42}),
("Present_Velocity", {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test_sync_read_by_list(data_name, ids_values, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(assert_motors_exist=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids = list(ids_values)
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
with (
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
):
returned_dict = bus.sync_read(data_name, [f"dummy_{id_}" for id_ in ids])
assert returned_dict == expected_values
mock__sync_read.assert_called_once_with(
addr,
length,
ids,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with(data_name, ids_values)
@pytest.mark.parametrize(
"data_name, ids_values",
[
("Model_Number", {1: 5678, 2: 5799, 3: 5678}),
("Present_Position", {1: 1337, 2: 42, 3: 4016}),
("Goal_Position", {1: 4008, 2: 199, 3: 3446}),
],
ids=["Model_Number", "Present_Position", "Goal_Position"],
)
def test_sync_read_by_none(data_name, ids_values, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(assert_motors_exist=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids = list(ids_values)
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
with (
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
):
returned_dict = bus.sync_read(data_name)
assert returned_dict == expected_values
mock__sync_read.assert_called_once_with(
addr,
length,
ids,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with(data_name, ids_values)
@pytest.mark.parametrize(
"data_name, value",
[
("Goal_Position", 500),
("Goal_Velocity", 4010),
("Lock", 0),
],
)
def test_sync_write_by_single_value(data_name, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(assert_motors_exist=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids_values = {m.id: value for m in dummy_motors.values()}
with (
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
):
bus.sync_write(data_name, value)
mock__sync_write.assert_called_once_with(
addr,
length,
ids_values,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
)
mock__encode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__unnormalize.assert_called_once_with(data_name, ids_values)
@pytest.mark.parametrize(
"data_name, ids_values",
[
("Goal_Position", {1: 1337, 2: 42, 3: 4016}),
("Goal_Velocity", {1: 50, 2: 83, 3: 2777}),
("Lock", {1: 0, 2: 0, 3: 1}),
],
ids=["Goal_Position", "Goal_Velocity", "Lock"],
)
def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(assert_motors_exist=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
with (
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
):
bus.sync_write(data_name, values)
mock__sync_write.assert_called_once_with(
addr,
length,
ids_values,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
)
mock__encode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__unnormalize.assert_called_once_with(data_name, ids_values)