From 05dfa26c541542818932536c860214de0775dee0 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 19 May 2025 11:24:10 +0200 Subject: [PATCH] Fix test --- tests/robots/test_so100_follower.py | 33 +++++++++++++++++------------ 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/tests/robots/test_so100_follower.py b/tests/robots/test_so100_follower.py index 75c8b993..81d9d6a9 100644 --- a/tests/robots/test_so100_follower.py +++ b/tests/robots/test_so100_follower.py @@ -8,8 +8,6 @@ from lerobot.common.robots.so100_follower import ( SO100FollowerConfig, ) -_MOTORS = SO100Follower(SO100FollowerConfig("")).bus.motors - def _make_bus_mock() -> MagicMock: """Return a bus mock with just the attributes used by the robot.""" @@ -24,13 +22,6 @@ def _make_bus_mock() -> MagicMock: bus.connect.side_effect = _connect bus.disconnect.side_effect = _disconnect - bus.motors = _MOTORS - bus.is_calibrated = True - bus.sync_read.return_value = {m: i for i, m in enumerate(_MOTORS, 1)} - bus.sync_write.return_value = None - bus.write.return_value = None - bus.disable_torque.return_value = None - bus.enable_torque.return_value = None @contextmanager def _dummy_cm(): @@ -43,10 +34,24 @@ def _make_bus_mock() -> MagicMock: @pytest.fixture def follower(): + bus_mock = _make_bus_mock() + + def _bus_side_effect(*_args, **kwargs): + bus_mock.motors = kwargs["motors"] + motors_order: list[str] = list(bus_mock.motors) + + bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)} + bus_mock.sync_write.return_value = None + bus_mock.write.return_value = None + bus_mock.disable_torque.return_value = None + bus_mock.enable_torque.return_value = None + bus_mock.is_calibrated = True + return bus_mock + with ( patch( "lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus", - return_value=_make_bus_mock(), + side_effect=_bus_side_effect, ), patch.object(SO100Follower, "configure", lambda self: None), ): @@ -71,20 +76,20 @@ def test_get_observation(follower): follower.connect() obs = follower.get_observation() - expected_keys = {f"{m}.pos" for m in _MOTORS} + expected_keys = {f"{m}.pos" for m in follower.bus.motors} assert set(obs.keys()) == expected_keys - for idx, motor in enumerate(_MOTORS, 1): + for idx, motor in enumerate(follower.bus.motors, 1): assert obs[f"{motor}.pos"] == idx def test_send_action(follower): follower.connect() - action = {f"{m}.pos": i * 10 for i, m in enumerate(_MOTORS, 1)} + action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)} returned = follower.send_action(action) assert returned == action - goal_pos = {m: (i + 1) * 10 for i, m in enumerate(_MOTORS)} + goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)} follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)