Fix test
This commit is contained in:
@@ -8,8 +8,6 @@ from lerobot.common.robots.so100_follower import (
|
|||||||
SO100FollowerConfig,
|
SO100FollowerConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
_MOTORS = SO100Follower(SO100FollowerConfig("")).bus.motors
|
|
||||||
|
|
||||||
|
|
||||||
def _make_bus_mock() -> MagicMock:
|
def _make_bus_mock() -> MagicMock:
|
||||||
"""Return a bus mock with just the attributes used by the robot."""
|
"""Return a bus mock with just the attributes used by the robot."""
|
||||||
@@ -24,13 +22,6 @@ def _make_bus_mock() -> MagicMock:
|
|||||||
|
|
||||||
bus.connect.side_effect = _connect
|
bus.connect.side_effect = _connect
|
||||||
bus.disconnect.side_effect = _disconnect
|
bus.disconnect.side_effect = _disconnect
|
||||||
bus.motors = _MOTORS
|
|
||||||
bus.is_calibrated = True
|
|
||||||
bus.sync_read.return_value = {m: i for i, m in enumerate(_MOTORS, 1)}
|
|
||||||
bus.sync_write.return_value = None
|
|
||||||
bus.write.return_value = None
|
|
||||||
bus.disable_torque.return_value = None
|
|
||||||
bus.enable_torque.return_value = None
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _dummy_cm():
|
def _dummy_cm():
|
||||||
@@ -43,10 +34,24 @@ def _make_bus_mock() -> MagicMock:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def follower():
|
def follower():
|
||||||
|
bus_mock = _make_bus_mock()
|
||||||
|
|
||||||
|
def _bus_side_effect(*_args, **kwargs):
|
||||||
|
bus_mock.motors = kwargs["motors"]
|
||||||
|
motors_order: list[str] = list(bus_mock.motors)
|
||||||
|
|
||||||
|
bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)}
|
||||||
|
bus_mock.sync_write.return_value = None
|
||||||
|
bus_mock.write.return_value = None
|
||||||
|
bus_mock.disable_torque.return_value = None
|
||||||
|
bus_mock.enable_torque.return_value = None
|
||||||
|
bus_mock.is_calibrated = True
|
||||||
|
return bus_mock
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
|
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
|
||||||
return_value=_make_bus_mock(),
|
side_effect=_bus_side_effect,
|
||||||
),
|
),
|
||||||
patch.object(SO100Follower, "configure", lambda self: None),
|
patch.object(SO100Follower, "configure", lambda self: None),
|
||||||
):
|
):
|
||||||
@@ -71,20 +76,20 @@ def test_get_observation(follower):
|
|||||||
follower.connect()
|
follower.connect()
|
||||||
obs = follower.get_observation()
|
obs = follower.get_observation()
|
||||||
|
|
||||||
expected_keys = {f"{m}.pos" for m in _MOTORS}
|
expected_keys = {f"{m}.pos" for m in follower.bus.motors}
|
||||||
assert set(obs.keys()) == expected_keys
|
assert set(obs.keys()) == expected_keys
|
||||||
|
|
||||||
for idx, motor in enumerate(_MOTORS, 1):
|
for idx, motor in enumerate(follower.bus.motors, 1):
|
||||||
assert obs[f"{motor}.pos"] == idx
|
assert obs[f"{motor}.pos"] == idx
|
||||||
|
|
||||||
|
|
||||||
def test_send_action(follower):
|
def test_send_action(follower):
|
||||||
follower.connect()
|
follower.connect()
|
||||||
|
|
||||||
action = {f"{m}.pos": i * 10 for i, m in enumerate(_MOTORS, 1)}
|
action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)}
|
||||||
returned = follower.send_action(action)
|
returned = follower.send_action(action)
|
||||||
|
|
||||||
assert returned == action
|
assert returned == action
|
||||||
|
|
||||||
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(_MOTORS)}
|
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)}
|
||||||
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)
|
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)
|
||||||
|
|||||||
Reference in New Issue
Block a user