forked from tangger/lerobot
Merge branch 'user/aliberts/2025_02_25_refactor_robots' into refactor/camera_implementations_and_tests_improvements
This commit is contained in:
@@ -1,68 +0,0 @@
|
||||
{
|
||||
"homing_offset": [
|
||||
2048,
|
||||
3072,
|
||||
3072,
|
||||
-1024,
|
||||
-1024,
|
||||
2048,
|
||||
-2048,
|
||||
2048,
|
||||
-2048
|
||||
],
|
||||
"drive_mode": [
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0
|
||||
],
|
||||
"start_pos": [
|
||||
2015,
|
||||
3058,
|
||||
3061,
|
||||
1071,
|
||||
1071,
|
||||
2035,
|
||||
2152,
|
||||
2029,
|
||||
2499
|
||||
],
|
||||
"end_pos": [
|
||||
-1008,
|
||||
-1963,
|
||||
-1966,
|
||||
2141,
|
||||
2143,
|
||||
-971,
|
||||
3043,
|
||||
-1077,
|
||||
3144
|
||||
],
|
||||
"calib_mode": [
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"LINEAR"
|
||||
],
|
||||
"motor_names": [
|
||||
"waist",
|
||||
"shoulder",
|
||||
"shoulder_shadow",
|
||||
"elbow",
|
||||
"elbow_shadow",
|
||||
"forearm_roll",
|
||||
"wrist_angle",
|
||||
"wrist_rotate",
|
||||
"gripper"
|
||||
]
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
{
|
||||
"homing_offset": [
|
||||
2048,
|
||||
3072,
|
||||
3072,
|
||||
-1024,
|
||||
-1024,
|
||||
2048,
|
||||
-2048,
|
||||
2048,
|
||||
-1024
|
||||
],
|
||||
"drive_mode": [
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0
|
||||
],
|
||||
"start_pos": [
|
||||
2035,
|
||||
3024,
|
||||
3019,
|
||||
979,
|
||||
981,
|
||||
1982,
|
||||
2166,
|
||||
2124,
|
||||
1968
|
||||
],
|
||||
"end_pos": [
|
||||
-990,
|
||||
-2017,
|
||||
-2015,
|
||||
2078,
|
||||
2076,
|
||||
-1030,
|
||||
3117,
|
||||
-1016,
|
||||
2556
|
||||
],
|
||||
"calib_mode": [
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"LINEAR"
|
||||
],
|
||||
"motor_names": [
|
||||
"waist",
|
||||
"shoulder",
|
||||
"shoulder_shadow",
|
||||
"elbow",
|
||||
"elbow_shadow",
|
||||
"forearm_roll",
|
||||
"wrist_angle",
|
||||
"wrist_rotate",
|
||||
"gripper"
|
||||
]
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
{
|
||||
"homing_offset": [
|
||||
2048,
|
||||
3072,
|
||||
3072,
|
||||
-1024,
|
||||
-1024,
|
||||
2048,
|
||||
-2048,
|
||||
2048,
|
||||
-2048
|
||||
],
|
||||
"drive_mode": [
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0
|
||||
],
|
||||
"start_pos": [
|
||||
2056,
|
||||
2895,
|
||||
2896,
|
||||
1191,
|
||||
1190,
|
||||
2018,
|
||||
2051,
|
||||
2056,
|
||||
2509
|
||||
],
|
||||
"end_pos": [
|
||||
-1040,
|
||||
-2004,
|
||||
-2006,
|
||||
2126,
|
||||
2127,
|
||||
-1010,
|
||||
3050,
|
||||
-1117,
|
||||
3143
|
||||
],
|
||||
"calib_mode": [
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"LINEAR"
|
||||
],
|
||||
"motor_names": [
|
||||
"waist",
|
||||
"shoulder",
|
||||
"shoulder_shadow",
|
||||
"elbow",
|
||||
"elbow_shadow",
|
||||
"forearm_roll",
|
||||
"wrist_angle",
|
||||
"wrist_rotate",
|
||||
"gripper"
|
||||
]
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
{
|
||||
"homing_offset": [
|
||||
2048,
|
||||
3072,
|
||||
3072,
|
||||
-1024,
|
||||
-1024,
|
||||
2048,
|
||||
-2048,
|
||||
2048,
|
||||
-2048
|
||||
],
|
||||
"drive_mode": [
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0
|
||||
],
|
||||
"start_pos": [
|
||||
2068,
|
||||
3034,
|
||||
3030,
|
||||
1038,
|
||||
1041,
|
||||
1991,
|
||||
1948,
|
||||
2090,
|
||||
1985
|
||||
],
|
||||
"end_pos": [
|
||||
-1025,
|
||||
-2014,
|
||||
-2015,
|
||||
2058,
|
||||
2060,
|
||||
-955,
|
||||
3091,
|
||||
-940,
|
||||
2576
|
||||
],
|
||||
"calib_mode": [
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"DEGREE",
|
||||
"LINEAR"
|
||||
],
|
||||
"motor_names": [
|
||||
"waist",
|
||||
"shoulder",
|
||||
"shoulder_shadow",
|
||||
"elbow",
|
||||
"elbow_shadow",
|
||||
"forearm_roll",
|
||||
"wrist_angle",
|
||||
"wrist_rotate",
|
||||
"gripper"
|
||||
]
|
||||
}
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -11,7 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Dev scripts
|
||||
.dev
|
||||
|
||||
# Logging
|
||||
logs
|
||||
tmp
|
||||
@@ -91,10 +94,8 @@ coverage.xml
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Ignore .cache except calibration
|
||||
# Ignore .cache
|
||||
.cache/*
|
||||
!.cache/calibration/
|
||||
!.cache/calibration/**
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
|
||||
@@ -32,7 +32,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def gym_kwargs(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -749,7 +749,10 @@ class MotorsBus(abc.ABC):
|
||||
# Move cursor up to overwrite the previous output
|
||||
move_cursor_up(len(motors) + 3)
|
||||
|
||||
# TODO(Steven, aliberts): add check to ensure mins and maxes are different
|
||||
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
|
||||
if same_min_max:
|
||||
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
|
||||
|
||||
return mins, maxes
|
||||
|
||||
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import logging
|
||||
import time
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
@@ -183,6 +184,12 @@ class LeKiwi(Robot):
|
||||
|
||||
self.bus.enable_torque()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
for motor in chain(reversed(self.arm_motors), reversed(self.base_motors)):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -49,15 +49,18 @@ class Robot(abc.ABC):
|
||||
return f"{self.id} {self.__class__.__name__}"
|
||||
|
||||
# TODO(aliberts): create a proper Feature class for this that links with datasets
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_features(self) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_features(self) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
pass
|
||||
|
||||
@@ -66,7 +69,8 @@ class Robot(abc.ABC):
|
||||
"""Connects to the robot."""
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_calibrated(self) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
@@ -47,15 +47,18 @@ class Teleoperator(abc.ABC):
|
||||
def __str__(self) -> str:
|
||||
return f"{self.id} {self.__class__.__name__}"
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_features(self) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def feedback_features(self) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
pass
|
||||
|
||||
@@ -64,7 +67,8 @@ class Teleoperator(abc.ABC):
|
||||
"""Connects to the teleoperator."""
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_calibrated(self) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
@@ -78,15 +78,18 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
137
tests/mocks/mock_motors_bus.py
Normal file
137
tests/mocks/mock_motors_bus.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# ruff: noqa: N802
|
||||
|
||||
from lerobot.common.motors.motors_bus import (
|
||||
Motor,
|
||||
MotorsBus,
|
||||
)
|
||||
|
||||
DUMMY_CTRL_TABLE_1 = {
|
||||
"Firmware_Version": (0, 1),
|
||||
"Model_Number": (1, 2),
|
||||
"Present_Position": (3, 4),
|
||||
"Goal_Position": (11, 2),
|
||||
}
|
||||
|
||||
DUMMY_CTRL_TABLE_2 = {
|
||||
"Model_Number": (0, 2),
|
||||
"Firmware_Version": (2, 1),
|
||||
"Present_Position": (3, 4),
|
||||
"Present_Velocity": (7, 4),
|
||||
"Goal_Position": (11, 4),
|
||||
"Goal_Velocity": (15, 4),
|
||||
"Lock": (19, 1),
|
||||
}
|
||||
|
||||
DUMMY_MODEL_CTRL_TABLE = {
|
||||
"model_1": DUMMY_CTRL_TABLE_1,
|
||||
"model_2": DUMMY_CTRL_TABLE_2,
|
||||
"model_3": DUMMY_CTRL_TABLE_2,
|
||||
}
|
||||
|
||||
DUMMY_BAUDRATE_TABLE = {
|
||||
0: 1_000_000,
|
||||
1: 500_000,
|
||||
2: 250_000,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_BAUDRATE_TABLE = {
|
||||
"model_1": DUMMY_BAUDRATE_TABLE,
|
||||
"model_2": DUMMY_BAUDRATE_TABLE,
|
||||
"model_3": DUMMY_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
DUMMY_ENCODING_TABLE = {
|
||||
"Present_Position": 8,
|
||||
"Goal_Position": 10,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_ENCODING_TABLE = {
|
||||
"model_1": DUMMY_ENCODING_TABLE,
|
||||
"model_2": DUMMY_ENCODING_TABLE,
|
||||
"model_3": DUMMY_ENCODING_TABLE,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_NUMBER_TABLE = {
|
||||
"model_1": 1234,
|
||||
"model_2": 5678,
|
||||
"model_3": 5799,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_RESOLUTION_TABLE = {
|
||||
"model_1": 4096,
|
||||
"model_2": 1024,
|
||||
"model_3": 4096,
|
||||
}
|
||||
|
||||
|
||||
class MockPortHandler:
|
||||
def __init__(self, port_name):
|
||||
self.is_open: bool = False
|
||||
self.baudrate: int
|
||||
self.packet_start_time: float
|
||||
self.packet_timeout: float
|
||||
self.tx_time_per_byte: float
|
||||
self.is_using: bool = False
|
||||
self.port_name: str = port_name
|
||||
self.ser = None
|
||||
|
||||
def openPort(self):
|
||||
self.is_open = True
|
||||
return self.is_open
|
||||
|
||||
def closePort(self):
|
||||
self.is_open = False
|
||||
|
||||
def clearPort(self): ...
|
||||
def setPortName(self, port_name):
|
||||
self.port_name = port_name
|
||||
|
||||
def getPortName(self):
|
||||
return self.port_name
|
||||
|
||||
def setBaudRate(self, baudrate):
|
||||
self.baudrate: baudrate
|
||||
|
||||
def getBaudRate(self):
|
||||
return self.baudrate
|
||||
|
||||
def getBytesAvailable(self): ...
|
||||
def readPort(self, length): ...
|
||||
def writePort(self, packet): ...
|
||||
def setPacketTimeout(self, packet_length): ...
|
||||
def setPacketTimeoutMillis(self, msec): ...
|
||||
def isPacketTimeout(self): ...
|
||||
def getCurrentTime(self): ...
|
||||
def getTimeSinceStart(self): ...
|
||||
def setupPort(self, cflag_baud): ...
|
||||
def getCFlagBaud(self, baudrate): ...
|
||||
|
||||
|
||||
class MockMotorsBus(MotorsBus):
|
||||
available_baudrates = [500_000, 1_000_000]
|
||||
default_timeout = 1000
|
||||
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
|
||||
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
|
||||
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
|
||||
model_number_table = DUMMY_MODEL_NUMBER_TABLE
|
||||
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
|
||||
normalized_data = ["Present_Position", "Goal_Position"]
|
||||
|
||||
def __init__(self, port: str, motors: dict[str, Motor]):
|
||||
super().__init__(port, motors)
|
||||
self.port_handler = MockPortHandler(port)
|
||||
|
||||
def _assert_protocol_is_compatible(self, instruction_name): ...
|
||||
def _handshake(self): ...
|
||||
def _find_single_motor(self, motor, initial_baudrate): ...
|
||||
def configure_motors(self): ...
|
||||
def read_calibration(self): ...
|
||||
def write_calibration(self, calibration_dict): ...
|
||||
def disable_torque(self, motors, num_retry): ...
|
||||
def _disable_torque(self, motor, model, num_retry): ...
|
||||
def enable_torque(self, motors, num_retry): ...
|
||||
def _get_half_turn_homings(self, positions): ...
|
||||
def _encode_sign(self, data_name, ids_values): ...
|
||||
def _decode_sign(self, data_name, ids_values): ...
|
||||
def _split_into_byte_chunks(self, value, length): ...
|
||||
def broadcast_ping(self, num_retry, raise_on_error): ...
|
||||
@@ -1,5 +1,3 @@
|
||||
# ruff: noqa: N802
|
||||
|
||||
import re
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -8,142 +6,16 @@ import pytest
|
||||
from lerobot.common.motors.motors_bus import (
|
||||
Motor,
|
||||
MotorNormMode,
|
||||
MotorsBus,
|
||||
assert_same_address,
|
||||
get_address,
|
||||
get_ctrl_table,
|
||||
)
|
||||
|
||||
DUMMY_CTRL_TABLE_1 = {
|
||||
"Firmware_Version": (0, 1),
|
||||
"Model_Number": (1, 2),
|
||||
"Present_Position": (3, 4),
|
||||
"Goal_Position": (11, 2),
|
||||
}
|
||||
|
||||
DUMMY_CTRL_TABLE_2 = {
|
||||
"Model_Number": (0, 2),
|
||||
"Firmware_Version": (2, 1),
|
||||
"Present_Position": (3, 4),
|
||||
"Present_Velocity": (7, 4),
|
||||
"Goal_Position": (11, 4),
|
||||
"Goal_Velocity": (15, 4),
|
||||
"Lock": (19, 1),
|
||||
}
|
||||
|
||||
DUMMY_MODEL_CTRL_TABLE = {
|
||||
"model_1": DUMMY_CTRL_TABLE_1,
|
||||
"model_2": DUMMY_CTRL_TABLE_2,
|
||||
"model_3": DUMMY_CTRL_TABLE_2,
|
||||
}
|
||||
|
||||
DUMMY_BAUDRATE_TABLE = {
|
||||
0: 1_000_000,
|
||||
1: 500_000,
|
||||
2: 250_000,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_BAUDRATE_TABLE = {
|
||||
"model_1": DUMMY_BAUDRATE_TABLE,
|
||||
"model_2": DUMMY_BAUDRATE_TABLE,
|
||||
"model_3": DUMMY_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
DUMMY_ENCODING_TABLE = {
|
||||
"Present_Position": 8,
|
||||
"Goal_Position": 10,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_ENCODING_TABLE = {
|
||||
"model_1": DUMMY_ENCODING_TABLE,
|
||||
"model_2": DUMMY_ENCODING_TABLE,
|
||||
"model_3": DUMMY_ENCODING_TABLE,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_NUMBER_TABLE = {
|
||||
"model_1": 1234,
|
||||
"model_2": 5678,
|
||||
"model_3": 5799,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_RESOLUTION_TABLE = {
|
||||
"model_1": 4096,
|
||||
"model_2": 1024,
|
||||
"model_3": 4096,
|
||||
}
|
||||
|
||||
|
||||
class MockPortHandler:
|
||||
def __init__(self, port_name):
|
||||
self.is_open: bool = False
|
||||
self.baudrate: int
|
||||
self.packet_start_time: float
|
||||
self.packet_timeout: float
|
||||
self.tx_time_per_byte: float
|
||||
self.is_using: bool = False
|
||||
self.port_name: str = port_name
|
||||
self.ser = None
|
||||
|
||||
def openPort(self):
|
||||
self.is_open = True
|
||||
return self.is_open
|
||||
|
||||
def closePort(self):
|
||||
self.is_open = False
|
||||
|
||||
def clearPort(self): ...
|
||||
def setPortName(self, port_name):
|
||||
self.port_name = port_name
|
||||
|
||||
def getPortName(self):
|
||||
return self.port_name
|
||||
|
||||
def setBaudRate(self, baudrate):
|
||||
self.baudrate: baudrate
|
||||
|
||||
def getBaudRate(self):
|
||||
return self.baudrate
|
||||
|
||||
def getBytesAvailable(self): ...
|
||||
def readPort(self, length): ...
|
||||
def writePort(self, packet): ...
|
||||
def setPacketTimeout(self, packet_length): ...
|
||||
def setPacketTimeoutMillis(self, msec): ...
|
||||
def isPacketTimeout(self): ...
|
||||
def getCurrentTime(self): ...
|
||||
def getTimeSinceStart(self): ...
|
||||
def setupPort(self, cflag_baud): ...
|
||||
def getCFlagBaud(self, baudrate): ...
|
||||
|
||||
|
||||
class MockMotorsBus(MotorsBus):
|
||||
available_baudrates = [500_000, 1_000_000]
|
||||
default_timeout = 1000
|
||||
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
|
||||
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
|
||||
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
|
||||
model_number_table = DUMMY_MODEL_NUMBER_TABLE
|
||||
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
|
||||
normalized_data = ["Present_Position", "Goal_Position"]
|
||||
|
||||
def __init__(self, port: str, motors: dict[str, Motor]):
|
||||
super().__init__(port, motors)
|
||||
self.port_handler = MockPortHandler(port)
|
||||
|
||||
def _assert_protocol_is_compatible(self, instruction_name): ...
|
||||
def _handshake(self): ...
|
||||
def _find_single_motor(self, motor, initial_baudrate): ...
|
||||
def configure_motors(self): ...
|
||||
def read_calibration(self): ...
|
||||
def write_calibration(self, calibration_dict): ...
|
||||
def disable_torque(self, motors, num_retry): ...
|
||||
def _disable_torque(self, motor, model, num_retry): ...
|
||||
def enable_torque(self, motors, num_retry): ...
|
||||
def _get_half_turn_homings(self, positions): ...
|
||||
def _encode_sign(self, data_name, ids_values): ...
|
||||
def _decode_sign(self, data_name, ids_values): ...
|
||||
def _split_into_byte_chunks(self, value, length): ...
|
||||
def broadcast_ping(self, num_retry, raise_on_error): ...
|
||||
from tests.mocks.mock_motors_bus import (
|
||||
DUMMY_CTRL_TABLE_1,
|
||||
DUMMY_CTRL_TABLE_2,
|
||||
DUMMY_MODEL_CTRL_TABLE,
|
||||
MockMotorsBus,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
95
tests/robots/test_so100_follower.py
Normal file
95
tests/robots/test_so100_follower.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.robots.so100_follower import (
|
||||
SO100Follower,
|
||||
SO100FollowerConfig,
|
||||
)
|
||||
|
||||
|
||||
def _make_bus_mock() -> MagicMock:
|
||||
"""Return a bus mock with just the attributes used by the robot."""
|
||||
bus = MagicMock(name="FeetechBusMock")
|
||||
bus.is_connected = False
|
||||
|
||||
def _connect():
|
||||
bus.is_connected = True
|
||||
|
||||
def _disconnect(_disable=True):
|
||||
bus.is_connected = False
|
||||
|
||||
bus.connect.side_effect = _connect
|
||||
bus.disconnect.side_effect = _disconnect
|
||||
|
||||
@contextmanager
|
||||
def _dummy_cm():
|
||||
yield
|
||||
|
||||
bus.torque_disabled.side_effect = _dummy_cm
|
||||
|
||||
return bus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def follower():
|
||||
bus_mock = _make_bus_mock()
|
||||
|
||||
def _bus_side_effect(*_args, **kwargs):
|
||||
bus_mock.motors = kwargs["motors"]
|
||||
motors_order: list[str] = list(bus_mock.motors)
|
||||
|
||||
bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)}
|
||||
bus_mock.sync_write.return_value = None
|
||||
bus_mock.write.return_value = None
|
||||
bus_mock.disable_torque.return_value = None
|
||||
bus_mock.enable_torque.return_value = None
|
||||
bus_mock.is_calibrated = True
|
||||
return bus_mock
|
||||
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
|
||||
side_effect=_bus_side_effect,
|
||||
),
|
||||
patch.object(SO100Follower, "configure", lambda self: None),
|
||||
):
|
||||
cfg = SO100FollowerConfig(port="/dev/null")
|
||||
robot = SO100Follower(cfg)
|
||||
yield robot
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_connect_disconnect(follower):
|
||||
assert not follower.is_connected
|
||||
|
||||
follower.connect()
|
||||
assert follower.is_connected
|
||||
|
||||
follower.disconnect()
|
||||
assert not follower.is_connected
|
||||
|
||||
|
||||
def test_get_observation(follower):
|
||||
follower.connect()
|
||||
obs = follower.get_observation()
|
||||
|
||||
expected_keys = {f"{m}.pos" for m in follower.bus.motors}
|
||||
assert set(obs.keys()) == expected_keys
|
||||
|
||||
for idx, motor in enumerate(follower.bus.motors, 1):
|
||||
assert obs[f"{motor}.pos"] == idx
|
||||
|
||||
|
||||
def test_send_action(follower):
|
||||
follower.connect()
|
||||
|
||||
action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)}
|
||||
returned = follower.send_action(action)
|
||||
|
||||
assert returned == action
|
||||
|
||||
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)}
|
||||
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)
|
||||
Reference in New Issue
Block a user