Merge remote-tracking branch 'origin/user/aliberts/2025_02_25_refactor_robots' into user/aliberts/2025_04_03_add_hope_jr

This commit is contained in:
Simon Alibert
2025-05-19 11:24:57 +02:00
6 changed files with 250 additions and 136 deletions

View File

@@ -140,7 +140,7 @@ class DynamixelMotorsBus(MotorsBus):
def _handshake(self) -> None:
self._assert_motors_exist()
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
model = self.motors[motor].model
search_baudrates = (
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]

View File

@@ -749,7 +749,10 @@ class MotorsBus(abc.ABC):
# Move cursor up to overwrite the previous output
move_cursor_up(len(motors) + 3)
# TODO(Steven, aliberts): add check to ensure mins and maxes are different
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
if same_min_max:
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
return mins, maxes
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:

View File

@@ -16,6 +16,7 @@
import logging
import time
from itertools import chain
from typing import Any
from lerobot.common.cameras.utils import make_cameras_from_configs
@@ -183,6 +184,12 @@ class LeKiwi(Robot):
self.bus.enable_torque()
def setup_motors(self) -> None:
for motor in chain(reversed(self.arm_motors), reversed(self.base_motors)):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")

View File

@@ -0,0 +1,137 @@
# ruff: noqa: N802
from lerobot.common.motors.motors_bus import (
Motor,
MotorsBus,
)
DUMMY_CTRL_TABLE_1 = {
"Firmware_Version": (0, 1),
"Model_Number": (1, 2),
"Present_Position": (3, 4),
"Goal_Position": (11, 2),
}
DUMMY_CTRL_TABLE_2 = {
"Model_Number": (0, 2),
"Firmware_Version": (2, 1),
"Present_Position": (3, 4),
"Present_Velocity": (7, 4),
"Goal_Position": (11, 4),
"Goal_Velocity": (15, 4),
"Lock": (19, 1),
}
DUMMY_MODEL_CTRL_TABLE = {
"model_1": DUMMY_CTRL_TABLE_1,
"model_2": DUMMY_CTRL_TABLE_2,
"model_3": DUMMY_CTRL_TABLE_2,
}
DUMMY_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
}
DUMMY_MODEL_BAUDRATE_TABLE = {
"model_1": DUMMY_BAUDRATE_TABLE,
"model_2": DUMMY_BAUDRATE_TABLE,
"model_3": DUMMY_BAUDRATE_TABLE,
}
DUMMY_ENCODING_TABLE = {
"Present_Position": 8,
"Goal_Position": 10,
}
DUMMY_MODEL_ENCODING_TABLE = {
"model_1": DUMMY_ENCODING_TABLE,
"model_2": DUMMY_ENCODING_TABLE,
"model_3": DUMMY_ENCODING_TABLE,
}
DUMMY_MODEL_NUMBER_TABLE = {
"model_1": 1234,
"model_2": 5678,
"model_3": 5799,
}
DUMMY_MODEL_RESOLUTION_TABLE = {
"model_1": 4096,
"model_2": 1024,
"model_3": 4096,
}
class MockPortHandler:
def __init__(self, port_name):
self.is_open: bool = False
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool = False
self.port_name: str = port_name
self.ser = None
def openPort(self):
self.is_open = True
return self.is_open
def closePort(self):
self.is_open = False
def clearPort(self): ...
def setPortName(self, port_name):
self.port_name = port_name
def getPortName(self):
return self.port_name
def setBaudRate(self, baudrate):
self.baudrate: baudrate
def getBaudRate(self):
return self.baudrate
def getBytesAvailable(self): ...
def readPort(self, length): ...
def writePort(self, packet): ...
def setPacketTimeout(self, packet_length): ...
def setPacketTimeoutMillis(self, msec): ...
def isPacketTimeout(self): ...
def getCurrentTime(self): ...
def getTimeSinceStart(self): ...
def setupPort(self, cflag_baud): ...
def getCFlagBaud(self, baudrate): ...
class MockMotorsBus(MotorsBus):
available_baudrates = [500_000, 1_000_000]
default_timeout = 1000
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
model_number_table = DUMMY_MODEL_NUMBER_TABLE
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
normalized_data = ["Present_Position", "Goal_Position"]
def __init__(self, port: str, motors: dict[str, Motor]):
super().__init__(port, motors)
self.port_handler = MockPortHandler(port)
def _assert_protocol_is_compatible(self, instruction_name): ...
def _handshake(self): ...
def _find_single_motor(self, motor, initial_baudrate): ...
def configure_motors(self): ...
def read_calibration(self): ...
def write_calibration(self, calibration_dict): ...
def disable_torque(self, motors, num_retry): ...
def _disable_torque(self, motor, model, num_retry): ...
def enable_torque(self, motors, num_retry): ...
def _get_half_turn_homings(self, positions): ...
def _encode_sign(self, data_name, ids_values): ...
def _decode_sign(self, data_name, ids_values): ...
def _split_into_byte_chunks(self, value, length): ...
def broadcast_ping(self, num_retry, raise_on_error): ...

View File

@@ -1,5 +1,3 @@
# ruff: noqa: N802
import re
from unittest.mock import patch
@@ -8,142 +6,16 @@ import pytest
from lerobot.common.motors.motors_bus import (
Motor,
MotorNormMode,
MotorsBus,
assert_same_address,
get_address,
get_ctrl_table,
)
DUMMY_CTRL_TABLE_1 = {
"Firmware_Version": (0, 1),
"Model_Number": (1, 2),
"Present_Position": (3, 4),
"Goal_Position": (11, 2),
}
DUMMY_CTRL_TABLE_2 = {
"Model_Number": (0, 2),
"Firmware_Version": (2, 1),
"Present_Position": (3, 4),
"Present_Velocity": (7, 4),
"Goal_Position": (11, 4),
"Goal_Velocity": (15, 4),
"Lock": (19, 1),
}
DUMMY_MODEL_CTRL_TABLE = {
"model_1": DUMMY_CTRL_TABLE_1,
"model_2": DUMMY_CTRL_TABLE_2,
"model_3": DUMMY_CTRL_TABLE_2,
}
DUMMY_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
}
DUMMY_MODEL_BAUDRATE_TABLE = {
"model_1": DUMMY_BAUDRATE_TABLE,
"model_2": DUMMY_BAUDRATE_TABLE,
"model_3": DUMMY_BAUDRATE_TABLE,
}
DUMMY_ENCODING_TABLE = {
"Present_Position": 8,
"Goal_Position": 10,
}
DUMMY_MODEL_ENCODING_TABLE = {
"model_1": DUMMY_ENCODING_TABLE,
"model_2": DUMMY_ENCODING_TABLE,
"model_3": DUMMY_ENCODING_TABLE,
}
DUMMY_MODEL_NUMBER_TABLE = {
"model_1": 1234,
"model_2": 5678,
"model_3": 5799,
}
DUMMY_MODEL_RESOLUTION_TABLE = {
"model_1": 4096,
"model_2": 1024,
"model_3": 4096,
}
class MockPortHandler:
def __init__(self, port_name):
self.is_open: bool = False
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool = False
self.port_name: str = port_name
self.ser = None
def openPort(self):
self.is_open = True
return self.is_open
def closePort(self):
self.is_open = False
def clearPort(self): ...
def setPortName(self, port_name):
self.port_name = port_name
def getPortName(self):
return self.port_name
def setBaudRate(self, baudrate):
self.baudrate: baudrate
def getBaudRate(self):
return self.baudrate
def getBytesAvailable(self): ...
def readPort(self, length): ...
def writePort(self, packet): ...
def setPacketTimeout(self, packet_length): ...
def setPacketTimeoutMillis(self, msec): ...
def isPacketTimeout(self): ...
def getCurrentTime(self): ...
def getTimeSinceStart(self): ...
def setupPort(self, cflag_baud): ...
def getCFlagBaud(self, baudrate): ...
class MockMotorsBus(MotorsBus):
available_baudrates = [500_000, 1_000_000]
default_timeout = 1000
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
model_number_table = DUMMY_MODEL_NUMBER_TABLE
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
normalized_data = ["Present_Position", "Goal_Position"]
def __init__(self, port: str, motors: dict[str, Motor]):
super().__init__(port, motors)
self.port_handler = MockPortHandler(port)
def _assert_protocol_is_compatible(self, instruction_name): ...
def _handshake(self): ...
def _find_single_motor(self, motor, initial_baudrate): ...
def configure_motors(self): ...
def read_calibration(self): ...
def write_calibration(self, calibration_dict): ...
def disable_torque(self, motors, num_retry): ...
def _disable_torque(self, motor, model, num_retry): ...
def enable_torque(self, motors, num_retry): ...
def _get_half_turn_homings(self, positions): ...
def _encode_sign(self, data_name, ids_values): ...
def _decode_sign(self, data_name, ids_values): ...
def _split_into_byte_chunks(self, value, length): ...
def broadcast_ping(self, num_retry, raise_on_error): ...
from tests.mocks.mock_motors_bus import (
DUMMY_CTRL_TABLE_1,
DUMMY_CTRL_TABLE_2,
DUMMY_MODEL_CTRL_TABLE,
MockMotorsBus,
)
@pytest.fixture

View File

@@ -0,0 +1,95 @@
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from lerobot.common.robots.so100_follower import (
SO100Follower,
SO100FollowerConfig,
)
def _make_bus_mock() -> MagicMock:
"""Return a bus mock with just the attributes used by the robot."""
bus = MagicMock(name="FeetechBusMock")
bus.is_connected = False
def _connect():
bus.is_connected = True
def _disconnect(_disable=True):
bus.is_connected = False
bus.connect.side_effect = _connect
bus.disconnect.side_effect = _disconnect
@contextmanager
def _dummy_cm():
yield
bus.torque_disabled.side_effect = _dummy_cm
return bus
@pytest.fixture
def follower():
bus_mock = _make_bus_mock()
def _bus_side_effect(*_args, **kwargs):
bus_mock.motors = kwargs["motors"]
motors_order: list[str] = list(bus_mock.motors)
bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)}
bus_mock.sync_write.return_value = None
bus_mock.write.return_value = None
bus_mock.disable_torque.return_value = None
bus_mock.enable_torque.return_value = None
bus_mock.is_calibrated = True
return bus_mock
with (
patch(
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
side_effect=_bus_side_effect,
),
patch.object(SO100Follower, "configure", lambda self: None),
):
cfg = SO100FollowerConfig(port="/dev/null")
robot = SO100Follower(cfg)
yield robot
if robot.is_connected:
robot.disconnect()
def test_connect_disconnect(follower):
assert not follower.is_connected
follower.connect()
assert follower.is_connected
follower.disconnect()
assert not follower.is_connected
def test_get_observation(follower):
follower.connect()
obs = follower.get_observation()
expected_keys = {f"{m}.pos" for m in follower.bus.motors}
assert set(obs.keys()) == expected_keys
for idx, motor in enumerate(follower.bus.motors, 1):
assert obs[f"{motor}.pos"] == idx
def test_send_action(follower):
follower.connect()
action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)}
returned = follower.send_action(action)
assert returned == action
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)}
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)