Add so100_follower tests

This commit is contained in:
Simon Alibert
2025-05-19 10:58:35 +02:00
parent 10119c1a59
commit edbba48e81

View File

@@ -0,0 +1,90 @@
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from lerobot.common.robots.so100_follower import (
SO100Follower,
SO100FollowerConfig,
)
_MOTORS = SO100Follower(SO100FollowerConfig("")).bus.motors
def _make_bus_mock() -> MagicMock:
"""Return a bus mock with just the attributes used by the robot."""
bus = MagicMock(name="FeetechBusMock")
bus.is_connected = False
def _connect():
bus.is_connected = True
def _disconnect(_disable=True):
bus.is_connected = False
bus.connect.side_effect = _connect
bus.disconnect.side_effect = _disconnect
bus.motors = _MOTORS
bus.is_calibrated = True
bus.sync_read.return_value = {m: i for i, m in enumerate(_MOTORS, 1)}
bus.sync_write.return_value = None
bus.write.return_value = None
bus.disable_torque.return_value = None
bus.enable_torque.return_value = None
@contextmanager
def _dummy_cm():
yield
bus.torque_disabled.side_effect = _dummy_cm
return bus
@pytest.fixture
def follower():
with (
patch(
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
return_value=_make_bus_mock(),
),
patch.object(SO100Follower, "configure", lambda self: None),
):
cfg = SO100FollowerConfig(port="/dev/null")
robot = SO100Follower(cfg)
yield robot
if robot.is_connected:
robot.disconnect()
def test_connect_disconnect(follower):
assert not follower.is_connected
follower.connect()
assert follower.is_connected
follower.disconnect()
assert not follower.is_connected
def test_get_observation(follower):
follower.connect()
obs = follower.get_observation()
expected_keys = {f"{m}.pos" for m in _MOTORS}
assert set(obs.keys()) == expected_keys
for idx, motor in enumerate(_MOTORS, 1):
assert obs[f"{motor}.pos"] == idx
def test_send_action(follower):
follower.connect()
action = {f"{m}.pos": i * 10 for i, m in enumerate(_MOTORS, 1)}
returned = follower.send_action(action)
assert returned == action
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(_MOTORS)}
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)