Fix test
This commit is contained in:
@@ -8,8 +8,6 @@ from lerobot.common.robots.so100_follower import (
|
||||
SO100FollowerConfig,
|
||||
)
|
||||
|
||||
_MOTORS = SO100Follower(SO100FollowerConfig("")).bus.motors
|
||||
|
||||
|
||||
def _make_bus_mock() -> MagicMock:
|
||||
"""Return a bus mock with just the attributes used by the robot."""
|
||||
@@ -24,13 +22,6 @@ def _make_bus_mock() -> MagicMock:
|
||||
|
||||
bus.connect.side_effect = _connect
|
||||
bus.disconnect.side_effect = _disconnect
|
||||
bus.motors = _MOTORS
|
||||
bus.is_calibrated = True
|
||||
bus.sync_read.return_value = {m: i for i, m in enumerate(_MOTORS, 1)}
|
||||
bus.sync_write.return_value = None
|
||||
bus.write.return_value = None
|
||||
bus.disable_torque.return_value = None
|
||||
bus.enable_torque.return_value = None
|
||||
|
||||
@contextmanager
|
||||
def _dummy_cm():
|
||||
@@ -43,10 +34,24 @@ def _make_bus_mock() -> MagicMock:
|
||||
|
||||
@pytest.fixture
|
||||
def follower():
|
||||
bus_mock = _make_bus_mock()
|
||||
|
||||
def _bus_side_effect(*_args, **kwargs):
|
||||
bus_mock.motors = kwargs["motors"]
|
||||
motors_order: list[str] = list(bus_mock.motors)
|
||||
|
||||
bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)}
|
||||
bus_mock.sync_write.return_value = None
|
||||
bus_mock.write.return_value = None
|
||||
bus_mock.disable_torque.return_value = None
|
||||
bus_mock.enable_torque.return_value = None
|
||||
bus_mock.is_calibrated = True
|
||||
return bus_mock
|
||||
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
|
||||
return_value=_make_bus_mock(),
|
||||
side_effect=_bus_side_effect,
|
||||
),
|
||||
patch.object(SO100Follower, "configure", lambda self: None),
|
||||
):
|
||||
@@ -71,20 +76,20 @@ def test_get_observation(follower):
|
||||
follower.connect()
|
||||
obs = follower.get_observation()
|
||||
|
||||
expected_keys = {f"{m}.pos" for m in _MOTORS}
|
||||
expected_keys = {f"{m}.pos" for m in follower.bus.motors}
|
||||
assert set(obs.keys()) == expected_keys
|
||||
|
||||
for idx, motor in enumerate(_MOTORS, 1):
|
||||
for idx, motor in enumerate(follower.bus.motors, 1):
|
||||
assert obs[f"{motor}.pos"] == idx
|
||||
|
||||
|
||||
def test_send_action(follower):
|
||||
follower.connect()
|
||||
|
||||
action = {f"{m}.pos": i * 10 for i, m in enumerate(_MOTORS, 1)}
|
||||
action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)}
|
||||
returned = follower.send_action(action)
|
||||
|
||||
assert returned == action
|
||||
|
||||
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(_MOTORS)}
|
||||
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)}
|
||||
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)
|
||||
|
||||
Reference in New Issue
Block a user