This commit is contained in:
Simon Alibert
2025-05-19 11:24:10 +02:00
parent edbba48e81
commit 05dfa26c54

View File

@@ -8,8 +8,6 @@ from lerobot.common.robots.so100_follower import (
SO100FollowerConfig,
)
_MOTORS = SO100Follower(SO100FollowerConfig("")).bus.motors
def _make_bus_mock() -> MagicMock:
"""Return a bus mock with just the attributes used by the robot."""
@@ -24,13 +22,6 @@ def _make_bus_mock() -> MagicMock:
bus.connect.side_effect = _connect
bus.disconnect.side_effect = _disconnect
bus.motors = _MOTORS
bus.is_calibrated = True
bus.sync_read.return_value = {m: i for i, m in enumerate(_MOTORS, 1)}
bus.sync_write.return_value = None
bus.write.return_value = None
bus.disable_torque.return_value = None
bus.enable_torque.return_value = None
@contextmanager
def _dummy_cm():
@@ -43,10 +34,24 @@ def _make_bus_mock() -> MagicMock:
@pytest.fixture
def follower():
bus_mock = _make_bus_mock()
def _bus_side_effect(*_args, **kwargs):
bus_mock.motors = kwargs["motors"]
motors_order: list[str] = list(bus_mock.motors)
bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)}
bus_mock.sync_write.return_value = None
bus_mock.write.return_value = None
bus_mock.disable_torque.return_value = None
bus_mock.enable_torque.return_value = None
bus_mock.is_calibrated = True
return bus_mock
with (
patch(
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
return_value=_make_bus_mock(),
side_effect=_bus_side_effect,
),
patch.object(SO100Follower, "configure", lambda self: None),
):
@@ -71,20 +76,20 @@ def test_get_observation(follower):
follower.connect()
obs = follower.get_observation()
expected_keys = {f"{m}.pos" for m in _MOTORS}
expected_keys = {f"{m}.pos" for m in follower.bus.motors}
assert set(obs.keys()) == expected_keys
for idx, motor in enumerate(_MOTORS, 1):
for idx, motor in enumerate(follower.bus.motors, 1):
assert obs[f"{motor}.pos"] == idx
def test_send_action(follower):
follower.connect()
action = {f"{m}.pos": i * 10 for i, m in enumerate(_MOTORS, 1)}
action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)}
returned = follower.send_action(action)
assert returned == action
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(_MOTORS)}
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)}
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)