forked from tangger/lerobot
Add test_motors_bus
This commit is contained in:
@@ -43,19 +43,18 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
|
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
|
||||||
try:
|
ctrl_table = model_ctrl_table.get(model)
|
||||||
return model_ctrl_table[model]
|
if ctrl_table is None:
|
||||||
except KeyError:
|
raise KeyError(f"Control table for {model=} not found.")
|
||||||
raise KeyError(f"Control table for {model=} not found.") from None
|
return ctrl_table
|
||||||
|
|
||||||
|
|
||||||
def get_address(model_ctrl_table: dict[str, dict], model: str, data_name: str) -> tuple[int, int]:
|
def get_address(model_ctrl_table: dict[str, dict], model: str, data_name: str) -> tuple[int, int]:
|
||||||
ctrl_table = get_ctrl_table(model_ctrl_table, model)
|
ctrl_table = get_ctrl_table(model_ctrl_table, model)
|
||||||
try:
|
addr_bytes = ctrl_table.get(data_name)
|
||||||
addr, bytes = ctrl_table[data_name]
|
if addr_bytes is None:
|
||||||
return addr, bytes
|
raise KeyError(f"Address for '{data_name}' not found in {model} control table.")
|
||||||
except KeyError:
|
return addr_bytes
|
||||||
raise KeyError(f"Address for '{data_name}' not found in {model} control table.") from None
|
|
||||||
|
|
||||||
|
|
||||||
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None:
|
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None:
|
||||||
@@ -69,13 +68,13 @@ def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[st
|
|||||||
if len(set(all_addr)) != 1:
|
if len(set(all_addr)) != 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"At least two motor models use a different address for `data_name`='{data_name}'"
|
f"At least two motor models use a different address for `data_name`='{data_name}'"
|
||||||
f"({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
|
f"({list(zip(motor_models, all_addr, strict=False))})."
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(set(all_bytes)) != 1:
|
if len(set(all_bytes)) != 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"At least two motor models use a different bytes representation for `data_name`='{data_name}'"
|
f"At least two motor models use a different bytes representation for `data_name`='{data_name}'"
|
||||||
f"({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
|
f"({list(zip(motor_models, all_bytes, strict=False))})."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
87
tests/motors/test_motors_bus.py
Normal file
87
tests/motors/test_motors_bus.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.common.motors.motors_bus import assert_same_address, get_address, get_ctrl_table
|
||||||
|
|
||||||
|
# TODO(aliberts)
|
||||||
|
# class DummyMotorsBus(MotorsBus):
|
||||||
|
# def __init__(self, port: str, motors: dict[str, Motor]):
|
||||||
|
# super().__init__(port, motors)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ctrl_table_1() -> dict:
|
||||||
|
return {
|
||||||
|
"Firmware_Version": (0, 1),
|
||||||
|
"Model_Number": (1, 2),
|
||||||
|
"Present_Position": (3, 4),
|
||||||
|
"Goal_Position": (7, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ctrl_table_2() -> dict:
|
||||||
|
return {
|
||||||
|
"Model_Number": (0, 2),
|
||||||
|
"Firmware_Version": (2, 1),
|
||||||
|
"Present_Position": (3, 4),
|
||||||
|
"Goal_Position": (7, 4),
|
||||||
|
"Lock": (7, 4),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_ctrl_table(ctrl_table_1, ctrl_table_2) -> dict:
|
||||||
|
return {
|
||||||
|
"model_1": ctrl_table_1,
|
||||||
|
"model_2": ctrl_table_2,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ctrl_table(model_ctrl_table, ctrl_table_1):
|
||||||
|
model = "model_1"
|
||||||
|
ctrl_table = get_ctrl_table(model_ctrl_table, model)
|
||||||
|
assert ctrl_table == ctrl_table_1
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ctrl_table_error(model_ctrl_table):
|
||||||
|
model = "model_99"
|
||||||
|
with pytest.raises(KeyError, match=f"Control table for {model=} not found."):
|
||||||
|
get_ctrl_table(model_ctrl_table, model)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_address(model_ctrl_table):
|
||||||
|
addr, n_bytes = get_address(model_ctrl_table, "model_1", "Firmware_Version")
|
||||||
|
assert addr == 0
|
||||||
|
assert n_bytes == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_address_error(model_ctrl_table):
|
||||||
|
model = "model_1"
|
||||||
|
data_name = "Lock"
|
||||||
|
with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."):
|
||||||
|
get_address(model_ctrl_table, "model_1", data_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_assert_same_address(model_ctrl_table):
|
||||||
|
models = ["model_1", "model_2"]
|
||||||
|
assert_same_address(model_ctrl_table, models, "Present_Position")
|
||||||
|
|
||||||
|
|
||||||
|
def test_assert_same_address_different_addresses(model_ctrl_table):
|
||||||
|
models = ["model_1", "model_2"]
|
||||||
|
with pytest.raises(
|
||||||
|
NotImplementedError,
|
||||||
|
match=re.escape("At least two motor models use a different address"),
|
||||||
|
):
|
||||||
|
assert_same_address(model_ctrl_table, models, "Model_Number")
|
||||||
|
|
||||||
|
|
||||||
|
def test_assert_same_address_different_bytes(model_ctrl_table):
|
||||||
|
models = ["model_1", "model_2"]
|
||||||
|
with pytest.raises(
|
||||||
|
NotImplementedError,
|
||||||
|
match=re.escape("At least two motor models use a different bytes representation"),
|
||||||
|
):
|
||||||
|
assert_same_address(model_ctrl_table, models, "Goal_Position")
|
||||||
Reference in New Issue
Block a user