From b71ac342142c8c99bab9b4662d22ea7c12f9815e Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 25 Mar 2025 12:11:56 +0100 Subject: [PATCH] Add test_motors_bus --- lerobot/common/motors/motors_bus.py | 21 ++++--- tests/motors/test_motors_bus.py | 87 +++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 11 deletions(-) create mode 100644 tests/motors/test_motors_bus.py diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 37b1fd77..d5265d8c 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -43,19 +43,18 @@ logger = logging.getLogger(__name__) def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]: - try: - return model_ctrl_table[model] - except KeyError: - raise KeyError(f"Control table for {model=} not found.") from None + ctrl_table = model_ctrl_table.get(model) + if ctrl_table is None: + raise KeyError(f"Control table for {model=} not found.") + return ctrl_table def get_address(model_ctrl_table: dict[str, dict], model: str, data_name: str) -> tuple[int, int]: ctrl_table = get_ctrl_table(model_ctrl_table, model) - try: - addr, bytes = ctrl_table[data_name] - return addr, bytes - except KeyError: - raise KeyError(f"Address for '{data_name}' not found in {model} control table.") from None + addr_bytes = ctrl_table.get(data_name) + if addr_bytes is None: + raise KeyError(f"Address for '{data_name}' not found in {model} control table.") + return addr_bytes def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None: @@ -69,13 +68,13 @@ def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[st if len(set(all_addr)) != 1: raise NotImplementedError( f"At least two motor models use a different address for `data_name`='{data_name}'" - f"({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer." + f"({list(zip(motor_models, all_addr, strict=False))})." ) if len(set(all_bytes)) != 1: raise NotImplementedError( f"At least two motor models use a different bytes representation for `data_name`='{data_name}'" - f"({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer." + f"({list(zip(motor_models, all_bytes, strict=False))})." ) diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py new file mode 100644 index 00000000..7463ae8c --- /dev/null +++ b/tests/motors/test_motors_bus.py @@ -0,0 +1,87 @@ +import re + +import pytest + +from lerobot.common.motors.motors_bus import assert_same_address, get_address, get_ctrl_table + +# TODO(aliberts) +# class DummyMotorsBus(MotorsBus): +# def __init__(self, port: str, motors: dict[str, Motor]): +# super().__init__(port, motors) + + +@pytest.fixture +def ctrl_table_1() -> dict: + return { + "Firmware_Version": (0, 1), + "Model_Number": (1, 2), + "Present_Position": (3, 4), + "Goal_Position": (7, 2), + } + + +@pytest.fixture +def ctrl_table_2() -> dict: + return { + "Model_Number": (0, 2), + "Firmware_Version": (2, 1), + "Present_Position": (3, 4), + "Goal_Position": (7, 4), + "Lock": (7, 4), + } + + +@pytest.fixture +def model_ctrl_table(ctrl_table_1, ctrl_table_2) -> dict: + return { + "model_1": ctrl_table_1, + "model_2": ctrl_table_2, + } + + +def test_get_ctrl_table(model_ctrl_table, ctrl_table_1): + model = "model_1" + ctrl_table = get_ctrl_table(model_ctrl_table, model) + assert ctrl_table == ctrl_table_1 + + +def test_get_ctrl_table_error(model_ctrl_table): + model = "model_99" + with pytest.raises(KeyError, match=f"Control table for {model=} not found."): + get_ctrl_table(model_ctrl_table, model) + + +def test_get_address(model_ctrl_table): + addr, n_bytes = get_address(model_ctrl_table, "model_1", "Firmware_Version") + assert addr == 0 + assert n_bytes == 1 + + +def test_get_address_error(model_ctrl_table): + model = "model_1" + data_name = "Lock" + with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."): + get_address(model_ctrl_table, "model_1", data_name) + + +def test_assert_same_address(model_ctrl_table): + models = ["model_1", "model_2"] + assert_same_address(model_ctrl_table, models, "Present_Position") + + +def test_assert_same_address_different_addresses(model_ctrl_table): + models = ["model_1", "model_2"] + with pytest.raises( + NotImplementedError, + match=re.escape("At least two motor models use a different address"), + ): + assert_same_address(model_ctrl_table, models, "Model_Number") + + +def test_assert_same_address_different_bytes(model_ctrl_table): + models = ["model_1", "model_2"] + with pytest.raises( + NotImplementedError, + match=re.escape("At least two motor models use a different bytes representation"), + ): + assert_same_address(model_ctrl_table, models, "Goal_Position")