Merge remote-tracking branch 'origin/user/aliberts/2025_02_25_refactor_robots' into user/aliberts/2025_04_03_add_hope_jr
This commit is contained in:
@@ -140,7 +140,7 @@ class DynamixelMotorsBus(MotorsBus):
|
|||||||
def _handshake(self) -> None:
|
def _handshake(self) -> None:
|
||||||
self._assert_motors_exist()
|
self._assert_motors_exist()
|
||||||
|
|
||||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
|
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||||
model = self.motors[motor].model
|
model = self.motors[motor].model
|
||||||
search_baudrates = (
|
search_baudrates = (
|
||||||
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
|
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
|
||||||
|
|||||||
@@ -749,7 +749,10 @@ class MotorsBus(abc.ABC):
|
|||||||
# Move cursor up to overwrite the previous output
|
# Move cursor up to overwrite the previous output
|
||||||
move_cursor_up(len(motors) + 3)
|
move_cursor_up(len(motors) + 3)
|
||||||
|
|
||||||
# TODO(Steven, aliberts): add check to ensure mins and maxes are different
|
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
|
||||||
|
if same_min_max:
|
||||||
|
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
|
||||||
|
|
||||||
return mins, maxes
|
return mins, maxes
|
||||||
|
|
||||||
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:
|
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from itertools import chain
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||||
@@ -183,6 +184,12 @@ class LeKiwi(Robot):
|
|||||||
|
|
||||||
self.bus.enable_torque()
|
self.bus.enable_torque()
|
||||||
|
|
||||||
|
def setup_motors(self) -> None:
|
||||||
|
for motor in chain(reversed(self.arm_motors), reversed(self.base_motors)):
|
||||||
|
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||||
|
self.bus.setup_motor(motor)
|
||||||
|
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, Any]:
|
def get_observation(self) -> dict[str, Any]:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|||||||
137
tests/mocks/mock_motors_bus.py
Normal file
137
tests/mocks/mock_motors_bus.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# ruff: noqa: N802
|
||||||
|
|
||||||
|
from lerobot.common.motors.motors_bus import (
|
||||||
|
Motor,
|
||||||
|
MotorsBus,
|
||||||
|
)
|
||||||
|
|
||||||
|
DUMMY_CTRL_TABLE_1 = {
|
||||||
|
"Firmware_Version": (0, 1),
|
||||||
|
"Model_Number": (1, 2),
|
||||||
|
"Present_Position": (3, 4),
|
||||||
|
"Goal_Position": (11, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_CTRL_TABLE_2 = {
|
||||||
|
"Model_Number": (0, 2),
|
||||||
|
"Firmware_Version": (2, 1),
|
||||||
|
"Present_Position": (3, 4),
|
||||||
|
"Present_Velocity": (7, 4),
|
||||||
|
"Goal_Position": (11, 4),
|
||||||
|
"Goal_Velocity": (15, 4),
|
||||||
|
"Lock": (19, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_CTRL_TABLE = {
|
||||||
|
"model_1": DUMMY_CTRL_TABLE_1,
|
||||||
|
"model_2": DUMMY_CTRL_TABLE_2,
|
||||||
|
"model_3": DUMMY_CTRL_TABLE_2,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_BAUDRATE_TABLE = {
|
||||||
|
0: 1_000_000,
|
||||||
|
1: 500_000,
|
||||||
|
2: 250_000,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_BAUDRATE_TABLE = {
|
||||||
|
"model_1": DUMMY_BAUDRATE_TABLE,
|
||||||
|
"model_2": DUMMY_BAUDRATE_TABLE,
|
||||||
|
"model_3": DUMMY_BAUDRATE_TABLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_ENCODING_TABLE = {
|
||||||
|
"Present_Position": 8,
|
||||||
|
"Goal_Position": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_ENCODING_TABLE = {
|
||||||
|
"model_1": DUMMY_ENCODING_TABLE,
|
||||||
|
"model_2": DUMMY_ENCODING_TABLE,
|
||||||
|
"model_3": DUMMY_ENCODING_TABLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_NUMBER_TABLE = {
|
||||||
|
"model_1": 1234,
|
||||||
|
"model_2": 5678,
|
||||||
|
"model_3": 5799,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_RESOLUTION_TABLE = {
|
||||||
|
"model_1": 4096,
|
||||||
|
"model_2": 1024,
|
||||||
|
"model_3": 4096,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MockPortHandler:
|
||||||
|
def __init__(self, port_name):
|
||||||
|
self.is_open: bool = False
|
||||||
|
self.baudrate: int
|
||||||
|
self.packet_start_time: float
|
||||||
|
self.packet_timeout: float
|
||||||
|
self.tx_time_per_byte: float
|
||||||
|
self.is_using: bool = False
|
||||||
|
self.port_name: str = port_name
|
||||||
|
self.ser = None
|
||||||
|
|
||||||
|
def openPort(self):
|
||||||
|
self.is_open = True
|
||||||
|
return self.is_open
|
||||||
|
|
||||||
|
def closePort(self):
|
||||||
|
self.is_open = False
|
||||||
|
|
||||||
|
def clearPort(self): ...
|
||||||
|
def setPortName(self, port_name):
|
||||||
|
self.port_name = port_name
|
||||||
|
|
||||||
|
def getPortName(self):
|
||||||
|
return self.port_name
|
||||||
|
|
||||||
|
def setBaudRate(self, baudrate):
|
||||||
|
self.baudrate: baudrate
|
||||||
|
|
||||||
|
def getBaudRate(self):
|
||||||
|
return self.baudrate
|
||||||
|
|
||||||
|
def getBytesAvailable(self): ...
|
||||||
|
def readPort(self, length): ...
|
||||||
|
def writePort(self, packet): ...
|
||||||
|
def setPacketTimeout(self, packet_length): ...
|
||||||
|
def setPacketTimeoutMillis(self, msec): ...
|
||||||
|
def isPacketTimeout(self): ...
|
||||||
|
def getCurrentTime(self): ...
|
||||||
|
def getTimeSinceStart(self): ...
|
||||||
|
def setupPort(self, cflag_baud): ...
|
||||||
|
def getCFlagBaud(self, baudrate): ...
|
||||||
|
|
||||||
|
|
||||||
|
class MockMotorsBus(MotorsBus):
|
||||||
|
available_baudrates = [500_000, 1_000_000]
|
||||||
|
default_timeout = 1000
|
||||||
|
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
|
||||||
|
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
|
||||||
|
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
|
||||||
|
model_number_table = DUMMY_MODEL_NUMBER_TABLE
|
||||||
|
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
|
||||||
|
normalized_data = ["Present_Position", "Goal_Position"]
|
||||||
|
|
||||||
|
def __init__(self, port: str, motors: dict[str, Motor]):
|
||||||
|
super().__init__(port, motors)
|
||||||
|
self.port_handler = MockPortHandler(port)
|
||||||
|
|
||||||
|
def _assert_protocol_is_compatible(self, instruction_name): ...
|
||||||
|
def _handshake(self): ...
|
||||||
|
def _find_single_motor(self, motor, initial_baudrate): ...
|
||||||
|
def configure_motors(self): ...
|
||||||
|
def read_calibration(self): ...
|
||||||
|
def write_calibration(self, calibration_dict): ...
|
||||||
|
def disable_torque(self, motors, num_retry): ...
|
||||||
|
def _disable_torque(self, motor, model, num_retry): ...
|
||||||
|
def enable_torque(self, motors, num_retry): ...
|
||||||
|
def _get_half_turn_homings(self, positions): ...
|
||||||
|
def _encode_sign(self, data_name, ids_values): ...
|
||||||
|
def _decode_sign(self, data_name, ids_values): ...
|
||||||
|
def _split_into_byte_chunks(self, value, length): ...
|
||||||
|
def broadcast_ping(self, num_retry, raise_on_error): ...
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
# ruff: noqa: N802
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@@ -8,142 +6,16 @@ import pytest
|
|||||||
from lerobot.common.motors.motors_bus import (
|
from lerobot.common.motors.motors_bus import (
|
||||||
Motor,
|
Motor,
|
||||||
MotorNormMode,
|
MotorNormMode,
|
||||||
MotorsBus,
|
|
||||||
assert_same_address,
|
assert_same_address,
|
||||||
get_address,
|
get_address,
|
||||||
get_ctrl_table,
|
get_ctrl_table,
|
||||||
)
|
)
|
||||||
|
from tests.mocks.mock_motors_bus import (
|
||||||
DUMMY_CTRL_TABLE_1 = {
|
DUMMY_CTRL_TABLE_1,
|
||||||
"Firmware_Version": (0, 1),
|
DUMMY_CTRL_TABLE_2,
|
||||||
"Model_Number": (1, 2),
|
DUMMY_MODEL_CTRL_TABLE,
|
||||||
"Present_Position": (3, 4),
|
MockMotorsBus,
|
||||||
"Goal_Position": (11, 2),
|
)
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_CTRL_TABLE_2 = {
|
|
||||||
"Model_Number": (0, 2),
|
|
||||||
"Firmware_Version": (2, 1),
|
|
||||||
"Present_Position": (3, 4),
|
|
||||||
"Present_Velocity": (7, 4),
|
|
||||||
"Goal_Position": (11, 4),
|
|
||||||
"Goal_Velocity": (15, 4),
|
|
||||||
"Lock": (19, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_CTRL_TABLE = {
|
|
||||||
"model_1": DUMMY_CTRL_TABLE_1,
|
|
||||||
"model_2": DUMMY_CTRL_TABLE_2,
|
|
||||||
"model_3": DUMMY_CTRL_TABLE_2,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_BAUDRATE_TABLE = {
|
|
||||||
0: 1_000_000,
|
|
||||||
1: 500_000,
|
|
||||||
2: 250_000,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_BAUDRATE_TABLE = {
|
|
||||||
"model_1": DUMMY_BAUDRATE_TABLE,
|
|
||||||
"model_2": DUMMY_BAUDRATE_TABLE,
|
|
||||||
"model_3": DUMMY_BAUDRATE_TABLE,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_ENCODING_TABLE = {
|
|
||||||
"Present_Position": 8,
|
|
||||||
"Goal_Position": 10,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_ENCODING_TABLE = {
|
|
||||||
"model_1": DUMMY_ENCODING_TABLE,
|
|
||||||
"model_2": DUMMY_ENCODING_TABLE,
|
|
||||||
"model_3": DUMMY_ENCODING_TABLE,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_NUMBER_TABLE = {
|
|
||||||
"model_1": 1234,
|
|
||||||
"model_2": 5678,
|
|
||||||
"model_3": 5799,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_RESOLUTION_TABLE = {
|
|
||||||
"model_1": 4096,
|
|
||||||
"model_2": 1024,
|
|
||||||
"model_3": 4096,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MockPortHandler:
|
|
||||||
def __init__(self, port_name):
|
|
||||||
self.is_open: bool = False
|
|
||||||
self.baudrate: int
|
|
||||||
self.packet_start_time: float
|
|
||||||
self.packet_timeout: float
|
|
||||||
self.tx_time_per_byte: float
|
|
||||||
self.is_using: bool = False
|
|
||||||
self.port_name: str = port_name
|
|
||||||
self.ser = None
|
|
||||||
|
|
||||||
def openPort(self):
|
|
||||||
self.is_open = True
|
|
||||||
return self.is_open
|
|
||||||
|
|
||||||
def closePort(self):
|
|
||||||
self.is_open = False
|
|
||||||
|
|
||||||
def clearPort(self): ...
|
|
||||||
def setPortName(self, port_name):
|
|
||||||
self.port_name = port_name
|
|
||||||
|
|
||||||
def getPortName(self):
|
|
||||||
return self.port_name
|
|
||||||
|
|
||||||
def setBaudRate(self, baudrate):
|
|
||||||
self.baudrate: baudrate
|
|
||||||
|
|
||||||
def getBaudRate(self):
|
|
||||||
return self.baudrate
|
|
||||||
|
|
||||||
def getBytesAvailable(self): ...
|
|
||||||
def readPort(self, length): ...
|
|
||||||
def writePort(self, packet): ...
|
|
||||||
def setPacketTimeout(self, packet_length): ...
|
|
||||||
def setPacketTimeoutMillis(self, msec): ...
|
|
||||||
def isPacketTimeout(self): ...
|
|
||||||
def getCurrentTime(self): ...
|
|
||||||
def getTimeSinceStart(self): ...
|
|
||||||
def setupPort(self, cflag_baud): ...
|
|
||||||
def getCFlagBaud(self, baudrate): ...
|
|
||||||
|
|
||||||
|
|
||||||
class MockMotorsBus(MotorsBus):
|
|
||||||
available_baudrates = [500_000, 1_000_000]
|
|
||||||
default_timeout = 1000
|
|
||||||
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
|
|
||||||
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
|
|
||||||
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
|
|
||||||
model_number_table = DUMMY_MODEL_NUMBER_TABLE
|
|
||||||
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
|
|
||||||
normalized_data = ["Present_Position", "Goal_Position"]
|
|
||||||
|
|
||||||
def __init__(self, port: str, motors: dict[str, Motor]):
|
|
||||||
super().__init__(port, motors)
|
|
||||||
self.port_handler = MockPortHandler(port)
|
|
||||||
|
|
||||||
def _assert_protocol_is_compatible(self, instruction_name): ...
|
|
||||||
def _handshake(self): ...
|
|
||||||
def _find_single_motor(self, motor, initial_baudrate): ...
|
|
||||||
def configure_motors(self): ...
|
|
||||||
def read_calibration(self): ...
|
|
||||||
def write_calibration(self, calibration_dict): ...
|
|
||||||
def disable_torque(self, motors, num_retry): ...
|
|
||||||
def _disable_torque(self, motor, model, num_retry): ...
|
|
||||||
def enable_torque(self, motors, num_retry): ...
|
|
||||||
def _get_half_turn_homings(self, positions): ...
|
|
||||||
def _encode_sign(self, data_name, ids_values): ...
|
|
||||||
def _decode_sign(self, data_name, ids_values): ...
|
|
||||||
def _split_into_byte_chunks(self, value, length): ...
|
|
||||||
def broadcast_ping(self, num_retry, raise_on_error): ...
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
95
tests/robots/test_so100_follower.py
Normal file
95
tests/robots/test_so100_follower.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.common.robots.so100_follower import (
|
||||||
|
SO100Follower,
|
||||||
|
SO100FollowerConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_bus_mock() -> MagicMock:
|
||||||
|
"""Return a bus mock with just the attributes used by the robot."""
|
||||||
|
bus = MagicMock(name="FeetechBusMock")
|
||||||
|
bus.is_connected = False
|
||||||
|
|
||||||
|
def _connect():
|
||||||
|
bus.is_connected = True
|
||||||
|
|
||||||
|
def _disconnect(_disable=True):
|
||||||
|
bus.is_connected = False
|
||||||
|
|
||||||
|
bus.connect.side_effect = _connect
|
||||||
|
bus.disconnect.side_effect = _disconnect
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _dummy_cm():
|
||||||
|
yield
|
||||||
|
|
||||||
|
bus.torque_disabled.side_effect = _dummy_cm
|
||||||
|
|
||||||
|
return bus
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def follower():
|
||||||
|
bus_mock = _make_bus_mock()
|
||||||
|
|
||||||
|
def _bus_side_effect(*_args, **kwargs):
|
||||||
|
bus_mock.motors = kwargs["motors"]
|
||||||
|
motors_order: list[str] = list(bus_mock.motors)
|
||||||
|
|
||||||
|
bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)}
|
||||||
|
bus_mock.sync_write.return_value = None
|
||||||
|
bus_mock.write.return_value = None
|
||||||
|
bus_mock.disable_torque.return_value = None
|
||||||
|
bus_mock.enable_torque.return_value = None
|
||||||
|
bus_mock.is_calibrated = True
|
||||||
|
return bus_mock
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
|
||||||
|
side_effect=_bus_side_effect,
|
||||||
|
),
|
||||||
|
patch.object(SO100Follower, "configure", lambda self: None),
|
||||||
|
):
|
||||||
|
cfg = SO100FollowerConfig(port="/dev/null")
|
||||||
|
robot = SO100Follower(cfg)
|
||||||
|
yield robot
|
||||||
|
if robot.is_connected:
|
||||||
|
robot.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_disconnect(follower):
|
||||||
|
assert not follower.is_connected
|
||||||
|
|
||||||
|
follower.connect()
|
||||||
|
assert follower.is_connected
|
||||||
|
|
||||||
|
follower.disconnect()
|
||||||
|
assert not follower.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_observation(follower):
|
||||||
|
follower.connect()
|
||||||
|
obs = follower.get_observation()
|
||||||
|
|
||||||
|
expected_keys = {f"{m}.pos" for m in follower.bus.motors}
|
||||||
|
assert set(obs.keys()) == expected_keys
|
||||||
|
|
||||||
|
for idx, motor in enumerate(follower.bus.motors, 1):
|
||||||
|
assert obs[f"{motor}.pos"] == idx
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_action(follower):
|
||||||
|
follower.connect()
|
||||||
|
|
||||||
|
action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)}
|
||||||
|
returned = follower.send_action(action)
|
||||||
|
|
||||||
|
assert returned == action
|
||||||
|
|
||||||
|
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)}
|
||||||
|
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)
|
||||||
Reference in New Issue
Block a user