Simplify motors mocks

This commit is contained in:
Simon Alibert
2025-04-03 16:43:23 +02:00
parent 4679725957
commit 0dcb2caba8
2 changed files with 30 additions and 92 deletions

View File

@@ -1,5 +1,4 @@
import abc
import random
from typing import Callable
import scservo_sdk as scs
@@ -248,40 +247,19 @@ class MockStatusPacket(MockFeetechPacket):
return cls.build(scs_id, params=[], length=2, error=error)
@classmethod
def present_position(cls, scs_id: int, pos: int | None = None, min_max_range: tuple = (0, 4095)) -> bytes:
"""Builds a 'Present_Position' status packet.
def read(cls, scs_id: int, value: int, param_length: int) -> bytes:
"""Builds a 'Read' status packet.
Args:
scs_id (int): List of the servos ids.
pos (int | None, optional): Desired 'Present_Position' to be returned in the packet. If None, it
will use a random value in the min_max_range. Defaults to None.
min_max_range (tuple, optional): Min/max range to generate the position values used for when 'pos'
is None. Note that the bounds are included in the range. Defaults to (0, 4095).
scs_id (int): ID of the servo responding.
value (int): Desired value to be returned in the packet.
param_length (int): The address length as reported in the control table.
Returns:
bytes: The raw 'Present_Position' status packet ready to be sent through serial.
bytes: The raw 'Sync Read' status packet ready to be sent through serial.
"""
pos = random.randint(*min_max_range) if pos is None else pos
params = [scs.SCS_LOBYTE(pos), scs.SCS_HIBYTE(pos)]
length = 4
return cls.build(scs_id, params=params, length=length)
@classmethod
def model_number(cls, scs_id: int, model_nb: int | None = None) -> bytes:
"""Builds a 'Present_Position' status packet.
Args:
scs_id (int): List of the servos ids.
pos (int | None, optional): Desired 'Present_Position' to be returned in the packet. If None, it
will use a random value in the min_max_range. Defaults to None.
min_max_range (tuple, optional): Min/max range to generate the position values used for when 'pos'
is None. Note that the bounds are included in the range. Defaults to (0, 4095).
Returns:
bytes: The raw 'Present_Position' status packet ready to be sent through serial.
"""
params = [scs.SCS_LOBYTE(model_nb), scs.SCS_HIBYTE(model_nb)]
length = 4
params = FeetechMotorsBus._split_int_to_bytes(value, param_length)
length = param_length + 2
return cls.build(scs_id, params=params, length=length)
@@ -368,7 +346,7 @@ class MockMotors(MockSerial):
raise NotImplementedError
address, length = self.ctrl_table[data_name]
read_request = MockInstructionPacket.read(scs_id, address, length)
return_packet = MockStatusPacket.model_number(scs_id, value)
return_packet = MockStatusPacket.read(scs_id, value, length)
read_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Read_{data_name}_{scs_id}"
self.stub(
@@ -381,23 +359,9 @@ class MockMotors(MockSerial):
def build_sync_read_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
) -> str:
"""
'data_name' supported:
- Present_Position
- Model_Number
"""
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
if data_name == "Present_Position":
return_packets = b"".join(
MockStatusPacket.present_position(id_, pos) for id_, pos in ids_values.items()
)
elif data_name == "Model_Number":
return_packets = b"".join(
MockStatusPacket.model_number(id_, model_nb) for id_, model_nb in ids_values.items()
)
else:
raise NotImplementedError
return_packets = b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
stub_name = f"Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
@@ -411,22 +375,14 @@ class MockMotors(MockSerial):
def build_sequential_sync_read_stub(
self, data_name: str, ids_values: dict[int, list[int]] | None = None
) -> str:
"""
'data_name' supported:
- Present_Position
"""
sequence_length = len(next(iter(ids_values.values())))
assert all(len(positions) == sequence_length for positions in ids_values.values())
if data_name != "Present_Position":
raise NotImplementedError
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
sequential_packets = []
for count in range(sequence_length):
return_packets = b"".join(
MockStatusPacket.present_position(id_, positions[count])
for id_, positions in ids_values.items()
MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items()
)
sequential_packets.append(return_packets)