Merge branch 'user/aliberts/2025_02_25_refactor_robots' into refactor/camera_implementations_and_tests_improvements
This commit is contained in:
@@ -1,68 +0,0 @@
|
|||||||
{
|
|
||||||
"homing_offset": [
|
|
||||||
2048,
|
|
||||||
3072,
|
|
||||||
3072,
|
|
||||||
-1024,
|
|
||||||
-1024,
|
|
||||||
2048,
|
|
||||||
-2048,
|
|
||||||
2048,
|
|
||||||
-2048
|
|
||||||
],
|
|
||||||
"drive_mode": [
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"start_pos": [
|
|
||||||
2015,
|
|
||||||
3058,
|
|
||||||
3061,
|
|
||||||
1071,
|
|
||||||
1071,
|
|
||||||
2035,
|
|
||||||
2152,
|
|
||||||
2029,
|
|
||||||
2499
|
|
||||||
],
|
|
||||||
"end_pos": [
|
|
||||||
-1008,
|
|
||||||
-1963,
|
|
||||||
-1966,
|
|
||||||
2141,
|
|
||||||
2143,
|
|
||||||
-971,
|
|
||||||
3043,
|
|
||||||
-1077,
|
|
||||||
3144
|
|
||||||
],
|
|
||||||
"calib_mode": [
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"LINEAR"
|
|
||||||
],
|
|
||||||
"motor_names": [
|
|
||||||
"waist",
|
|
||||||
"shoulder",
|
|
||||||
"shoulder_shadow",
|
|
||||||
"elbow",
|
|
||||||
"elbow_shadow",
|
|
||||||
"forearm_roll",
|
|
||||||
"wrist_angle",
|
|
||||||
"wrist_rotate",
|
|
||||||
"gripper"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
{
|
|
||||||
"homing_offset": [
|
|
||||||
2048,
|
|
||||||
3072,
|
|
||||||
3072,
|
|
||||||
-1024,
|
|
||||||
-1024,
|
|
||||||
2048,
|
|
||||||
-2048,
|
|
||||||
2048,
|
|
||||||
-1024
|
|
||||||
],
|
|
||||||
"drive_mode": [
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"start_pos": [
|
|
||||||
2035,
|
|
||||||
3024,
|
|
||||||
3019,
|
|
||||||
979,
|
|
||||||
981,
|
|
||||||
1982,
|
|
||||||
2166,
|
|
||||||
2124,
|
|
||||||
1968
|
|
||||||
],
|
|
||||||
"end_pos": [
|
|
||||||
-990,
|
|
||||||
-2017,
|
|
||||||
-2015,
|
|
||||||
2078,
|
|
||||||
2076,
|
|
||||||
-1030,
|
|
||||||
3117,
|
|
||||||
-1016,
|
|
||||||
2556
|
|
||||||
],
|
|
||||||
"calib_mode": [
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"LINEAR"
|
|
||||||
],
|
|
||||||
"motor_names": [
|
|
||||||
"waist",
|
|
||||||
"shoulder",
|
|
||||||
"shoulder_shadow",
|
|
||||||
"elbow",
|
|
||||||
"elbow_shadow",
|
|
||||||
"forearm_roll",
|
|
||||||
"wrist_angle",
|
|
||||||
"wrist_rotate",
|
|
||||||
"gripper"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
{
|
|
||||||
"homing_offset": [
|
|
||||||
2048,
|
|
||||||
3072,
|
|
||||||
3072,
|
|
||||||
-1024,
|
|
||||||
-1024,
|
|
||||||
2048,
|
|
||||||
-2048,
|
|
||||||
2048,
|
|
||||||
-2048
|
|
||||||
],
|
|
||||||
"drive_mode": [
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"start_pos": [
|
|
||||||
2056,
|
|
||||||
2895,
|
|
||||||
2896,
|
|
||||||
1191,
|
|
||||||
1190,
|
|
||||||
2018,
|
|
||||||
2051,
|
|
||||||
2056,
|
|
||||||
2509
|
|
||||||
],
|
|
||||||
"end_pos": [
|
|
||||||
-1040,
|
|
||||||
-2004,
|
|
||||||
-2006,
|
|
||||||
2126,
|
|
||||||
2127,
|
|
||||||
-1010,
|
|
||||||
3050,
|
|
||||||
-1117,
|
|
||||||
3143
|
|
||||||
],
|
|
||||||
"calib_mode": [
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"LINEAR"
|
|
||||||
],
|
|
||||||
"motor_names": [
|
|
||||||
"waist",
|
|
||||||
"shoulder",
|
|
||||||
"shoulder_shadow",
|
|
||||||
"elbow",
|
|
||||||
"elbow_shadow",
|
|
||||||
"forearm_roll",
|
|
||||||
"wrist_angle",
|
|
||||||
"wrist_rotate",
|
|
||||||
"gripper"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
{
|
|
||||||
"homing_offset": [
|
|
||||||
2048,
|
|
||||||
3072,
|
|
||||||
3072,
|
|
||||||
-1024,
|
|
||||||
-1024,
|
|
||||||
2048,
|
|
||||||
-2048,
|
|
||||||
2048,
|
|
||||||
-2048
|
|
||||||
],
|
|
||||||
"drive_mode": [
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"start_pos": [
|
|
||||||
2068,
|
|
||||||
3034,
|
|
||||||
3030,
|
|
||||||
1038,
|
|
||||||
1041,
|
|
||||||
1991,
|
|
||||||
1948,
|
|
||||||
2090,
|
|
||||||
1985
|
|
||||||
],
|
|
||||||
"end_pos": [
|
|
||||||
-1025,
|
|
||||||
-2014,
|
|
||||||
-2015,
|
|
||||||
2058,
|
|
||||||
2060,
|
|
||||||
-955,
|
|
||||||
3091,
|
|
||||||
-940,
|
|
||||||
2576
|
|
||||||
],
|
|
||||||
"calib_mode": [
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"DEGREE",
|
|
||||||
"LINEAR"
|
|
||||||
],
|
|
||||||
"motor_names": [
|
|
||||||
"waist",
|
|
||||||
"shoulder",
|
|
||||||
"shoulder_shadow",
|
|
||||||
"elbow",
|
|
||||||
"elbow_shadow",
|
|
||||||
"forearm_roll",
|
|
||||||
"wrist_angle",
|
|
||||||
"wrist_rotate",
|
|
||||||
"gripper"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -11,7 +11,10 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Dev scripts
|
||||||
.dev
|
.dev
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
logs
|
logs
|
||||||
tmp
|
tmp
|
||||||
@@ -91,10 +94,8 @@ coverage.xml
|
|||||||
.hypothesis/
|
.hypothesis/
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
|
|
||||||
# Ignore .cache except calibration
|
# Ignore .cache
|
||||||
.cache/*
|
.cache/*
|
||||||
!.cache/calibration/
|
|
||||||
!.cache/calibration/**
|
|
||||||
|
|
||||||
# Translations
|
# Translations
|
||||||
*.mo
|
*.mo
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return self.get_choice_name(self.__class__)
|
return self.get_choice_name(self.__class__)
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@@ -749,7 +749,10 @@ class MotorsBus(abc.ABC):
|
|||||||
# Move cursor up to overwrite the previous output
|
# Move cursor up to overwrite the previous output
|
||||||
move_cursor_up(len(motors) + 3)
|
move_cursor_up(len(motors) + 3)
|
||||||
|
|
||||||
# TODO(Steven, aliberts): add check to ensure mins and maxes are different
|
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
|
||||||
|
if same_min_max:
|
||||||
|
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
|
||||||
|
|
||||||
return mins, maxes
|
return mins, maxes
|
||||||
|
|
||||||
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:
|
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from itertools import chain
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||||
@@ -183,6 +184,12 @@ class LeKiwi(Robot):
|
|||||||
|
|
||||||
self.bus.enable_torque()
|
self.bus.enable_torque()
|
||||||
|
|
||||||
|
def setup_motors(self) -> None:
|
||||||
|
for motor in chain(reversed(self.arm_motors), reversed(self.base_motors)):
|
||||||
|
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||||
|
self.bus.setup_motor(motor)
|
||||||
|
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, Any]:
|
def get_observation(self) -> dict[str, Any]:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|||||||
@@ -49,15 +49,18 @@ class Robot(abc.ABC):
|
|||||||
return f"{self.id} {self.__class__.__name__}"
|
return f"{self.id} {self.__class__.__name__}"
|
||||||
|
|
||||||
# TODO(aliberts): create a proper Feature class for this that links with datasets
|
# TODO(aliberts): create a proper Feature class for this that links with datasets
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def observation_features(self) -> dict:
|
def observation_features(self) -> dict:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def action_features(self) -> dict:
|
def action_features(self) -> dict:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -66,7 +69,8 @@ class Robot(abc.ABC):
|
|||||||
"""Connects to the robot."""
|
"""Connects to the robot."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def is_calibrated(self) -> bool:
|
def is_calibrated(self) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -47,15 +47,18 @@ class Teleoperator(abc.ABC):
|
|||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"{self.id} {self.__class__.__name__}"
|
return f"{self.id} {self.__class__.__name__}"
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def action_features(self) -> dict:
|
def action_features(self) -> dict:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def feedback_features(self) -> dict:
|
def feedback_features(self) -> dict:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -64,7 +67,8 @@ class Teleoperator(abc.ABC):
|
|||||||
"""Connects to the teleoperator."""
|
"""Connects to the teleoperator."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def is_calibrated(self) -> bool:
|
def is_calibrated(self) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -78,15 +78,18 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return self.get_choice_name(self.__class__)
|
return self.get_choice_name(self.__class__)
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def observation_delta_indices(self) -> list | None:
|
def observation_delta_indices(self) -> list | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def action_delta_indices(self) -> list | None:
|
def action_delta_indices(self) -> list | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def reward_delta_indices(self) -> list | None:
|
def reward_delta_indices(self) -> list | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
137
tests/mocks/mock_motors_bus.py
Normal file
137
tests/mocks/mock_motors_bus.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# ruff: noqa: N802
|
||||||
|
|
||||||
|
from lerobot.common.motors.motors_bus import (
|
||||||
|
Motor,
|
||||||
|
MotorsBus,
|
||||||
|
)
|
||||||
|
|
||||||
|
DUMMY_CTRL_TABLE_1 = {
|
||||||
|
"Firmware_Version": (0, 1),
|
||||||
|
"Model_Number": (1, 2),
|
||||||
|
"Present_Position": (3, 4),
|
||||||
|
"Goal_Position": (11, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_CTRL_TABLE_2 = {
|
||||||
|
"Model_Number": (0, 2),
|
||||||
|
"Firmware_Version": (2, 1),
|
||||||
|
"Present_Position": (3, 4),
|
||||||
|
"Present_Velocity": (7, 4),
|
||||||
|
"Goal_Position": (11, 4),
|
||||||
|
"Goal_Velocity": (15, 4),
|
||||||
|
"Lock": (19, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_CTRL_TABLE = {
|
||||||
|
"model_1": DUMMY_CTRL_TABLE_1,
|
||||||
|
"model_2": DUMMY_CTRL_TABLE_2,
|
||||||
|
"model_3": DUMMY_CTRL_TABLE_2,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_BAUDRATE_TABLE = {
|
||||||
|
0: 1_000_000,
|
||||||
|
1: 500_000,
|
||||||
|
2: 250_000,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_BAUDRATE_TABLE = {
|
||||||
|
"model_1": DUMMY_BAUDRATE_TABLE,
|
||||||
|
"model_2": DUMMY_BAUDRATE_TABLE,
|
||||||
|
"model_3": DUMMY_BAUDRATE_TABLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_ENCODING_TABLE = {
|
||||||
|
"Present_Position": 8,
|
||||||
|
"Goal_Position": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_ENCODING_TABLE = {
|
||||||
|
"model_1": DUMMY_ENCODING_TABLE,
|
||||||
|
"model_2": DUMMY_ENCODING_TABLE,
|
||||||
|
"model_3": DUMMY_ENCODING_TABLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_NUMBER_TABLE = {
|
||||||
|
"model_1": 1234,
|
||||||
|
"model_2": 5678,
|
||||||
|
"model_3": 5799,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_RESOLUTION_TABLE = {
|
||||||
|
"model_1": 4096,
|
||||||
|
"model_2": 1024,
|
||||||
|
"model_3": 4096,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MockPortHandler:
|
||||||
|
def __init__(self, port_name):
|
||||||
|
self.is_open: bool = False
|
||||||
|
self.baudrate: int
|
||||||
|
self.packet_start_time: float
|
||||||
|
self.packet_timeout: float
|
||||||
|
self.tx_time_per_byte: float
|
||||||
|
self.is_using: bool = False
|
||||||
|
self.port_name: str = port_name
|
||||||
|
self.ser = None
|
||||||
|
|
||||||
|
def openPort(self):
|
||||||
|
self.is_open = True
|
||||||
|
return self.is_open
|
||||||
|
|
||||||
|
def closePort(self):
|
||||||
|
self.is_open = False
|
||||||
|
|
||||||
|
def clearPort(self): ...
|
||||||
|
def setPortName(self, port_name):
|
||||||
|
self.port_name = port_name
|
||||||
|
|
||||||
|
def getPortName(self):
|
||||||
|
return self.port_name
|
||||||
|
|
||||||
|
def setBaudRate(self, baudrate):
|
||||||
|
self.baudrate: baudrate
|
||||||
|
|
||||||
|
def getBaudRate(self):
|
||||||
|
return self.baudrate
|
||||||
|
|
||||||
|
def getBytesAvailable(self): ...
|
||||||
|
def readPort(self, length): ...
|
||||||
|
def writePort(self, packet): ...
|
||||||
|
def setPacketTimeout(self, packet_length): ...
|
||||||
|
def setPacketTimeoutMillis(self, msec): ...
|
||||||
|
def isPacketTimeout(self): ...
|
||||||
|
def getCurrentTime(self): ...
|
||||||
|
def getTimeSinceStart(self): ...
|
||||||
|
def setupPort(self, cflag_baud): ...
|
||||||
|
def getCFlagBaud(self, baudrate): ...
|
||||||
|
|
||||||
|
|
||||||
|
class MockMotorsBus(MotorsBus):
|
||||||
|
available_baudrates = [500_000, 1_000_000]
|
||||||
|
default_timeout = 1000
|
||||||
|
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
|
||||||
|
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
|
||||||
|
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
|
||||||
|
model_number_table = DUMMY_MODEL_NUMBER_TABLE
|
||||||
|
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
|
||||||
|
normalized_data = ["Present_Position", "Goal_Position"]
|
||||||
|
|
||||||
|
def __init__(self, port: str, motors: dict[str, Motor]):
|
||||||
|
super().__init__(port, motors)
|
||||||
|
self.port_handler = MockPortHandler(port)
|
||||||
|
|
||||||
|
def _assert_protocol_is_compatible(self, instruction_name): ...
|
||||||
|
def _handshake(self): ...
|
||||||
|
def _find_single_motor(self, motor, initial_baudrate): ...
|
||||||
|
def configure_motors(self): ...
|
||||||
|
def read_calibration(self): ...
|
||||||
|
def write_calibration(self, calibration_dict): ...
|
||||||
|
def disable_torque(self, motors, num_retry): ...
|
||||||
|
def _disable_torque(self, motor, model, num_retry): ...
|
||||||
|
def enable_torque(self, motors, num_retry): ...
|
||||||
|
def _get_half_turn_homings(self, positions): ...
|
||||||
|
def _encode_sign(self, data_name, ids_values): ...
|
||||||
|
def _decode_sign(self, data_name, ids_values): ...
|
||||||
|
def _split_into_byte_chunks(self, value, length): ...
|
||||||
|
def broadcast_ping(self, num_retry, raise_on_error): ...
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
# ruff: noqa: N802
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@@ -8,142 +6,16 @@ import pytest
|
|||||||
from lerobot.common.motors.motors_bus import (
|
from lerobot.common.motors.motors_bus import (
|
||||||
Motor,
|
Motor,
|
||||||
MotorNormMode,
|
MotorNormMode,
|
||||||
MotorsBus,
|
|
||||||
assert_same_address,
|
assert_same_address,
|
||||||
get_address,
|
get_address,
|
||||||
get_ctrl_table,
|
get_ctrl_table,
|
||||||
)
|
)
|
||||||
|
from tests.mocks.mock_motors_bus import (
|
||||||
DUMMY_CTRL_TABLE_1 = {
|
DUMMY_CTRL_TABLE_1,
|
||||||
"Firmware_Version": (0, 1),
|
DUMMY_CTRL_TABLE_2,
|
||||||
"Model_Number": (1, 2),
|
DUMMY_MODEL_CTRL_TABLE,
|
||||||
"Present_Position": (3, 4),
|
MockMotorsBus,
|
||||||
"Goal_Position": (11, 2),
|
)
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_CTRL_TABLE_2 = {
|
|
||||||
"Model_Number": (0, 2),
|
|
||||||
"Firmware_Version": (2, 1),
|
|
||||||
"Present_Position": (3, 4),
|
|
||||||
"Present_Velocity": (7, 4),
|
|
||||||
"Goal_Position": (11, 4),
|
|
||||||
"Goal_Velocity": (15, 4),
|
|
||||||
"Lock": (19, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_CTRL_TABLE = {
|
|
||||||
"model_1": DUMMY_CTRL_TABLE_1,
|
|
||||||
"model_2": DUMMY_CTRL_TABLE_2,
|
|
||||||
"model_3": DUMMY_CTRL_TABLE_2,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_BAUDRATE_TABLE = {
|
|
||||||
0: 1_000_000,
|
|
||||||
1: 500_000,
|
|
||||||
2: 250_000,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_BAUDRATE_TABLE = {
|
|
||||||
"model_1": DUMMY_BAUDRATE_TABLE,
|
|
||||||
"model_2": DUMMY_BAUDRATE_TABLE,
|
|
||||||
"model_3": DUMMY_BAUDRATE_TABLE,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_ENCODING_TABLE = {
|
|
||||||
"Present_Position": 8,
|
|
||||||
"Goal_Position": 10,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_ENCODING_TABLE = {
|
|
||||||
"model_1": DUMMY_ENCODING_TABLE,
|
|
||||||
"model_2": DUMMY_ENCODING_TABLE,
|
|
||||||
"model_3": DUMMY_ENCODING_TABLE,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_NUMBER_TABLE = {
|
|
||||||
"model_1": 1234,
|
|
||||||
"model_2": 5678,
|
|
||||||
"model_3": 5799,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_RESOLUTION_TABLE = {
|
|
||||||
"model_1": 4096,
|
|
||||||
"model_2": 1024,
|
|
||||||
"model_3": 4096,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MockPortHandler:
|
|
||||||
def __init__(self, port_name):
|
|
||||||
self.is_open: bool = False
|
|
||||||
self.baudrate: int
|
|
||||||
self.packet_start_time: float
|
|
||||||
self.packet_timeout: float
|
|
||||||
self.tx_time_per_byte: float
|
|
||||||
self.is_using: bool = False
|
|
||||||
self.port_name: str = port_name
|
|
||||||
self.ser = None
|
|
||||||
|
|
||||||
def openPort(self):
|
|
||||||
self.is_open = True
|
|
||||||
return self.is_open
|
|
||||||
|
|
||||||
def closePort(self):
|
|
||||||
self.is_open = False
|
|
||||||
|
|
||||||
def clearPort(self): ...
|
|
||||||
def setPortName(self, port_name):
|
|
||||||
self.port_name = port_name
|
|
||||||
|
|
||||||
def getPortName(self):
|
|
||||||
return self.port_name
|
|
||||||
|
|
||||||
def setBaudRate(self, baudrate):
|
|
||||||
self.baudrate: baudrate
|
|
||||||
|
|
||||||
def getBaudRate(self):
|
|
||||||
return self.baudrate
|
|
||||||
|
|
||||||
def getBytesAvailable(self): ...
|
|
||||||
def readPort(self, length): ...
|
|
||||||
def writePort(self, packet): ...
|
|
||||||
def setPacketTimeout(self, packet_length): ...
|
|
||||||
def setPacketTimeoutMillis(self, msec): ...
|
|
||||||
def isPacketTimeout(self): ...
|
|
||||||
def getCurrentTime(self): ...
|
|
||||||
def getTimeSinceStart(self): ...
|
|
||||||
def setupPort(self, cflag_baud): ...
|
|
||||||
def getCFlagBaud(self, baudrate): ...
|
|
||||||
|
|
||||||
|
|
||||||
class MockMotorsBus(MotorsBus):
|
|
||||||
available_baudrates = [500_000, 1_000_000]
|
|
||||||
default_timeout = 1000
|
|
||||||
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
|
|
||||||
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
|
|
||||||
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
|
|
||||||
model_number_table = DUMMY_MODEL_NUMBER_TABLE
|
|
||||||
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
|
|
||||||
normalized_data = ["Present_Position", "Goal_Position"]
|
|
||||||
|
|
||||||
def __init__(self, port: str, motors: dict[str, Motor]):
|
|
||||||
super().__init__(port, motors)
|
|
||||||
self.port_handler = MockPortHandler(port)
|
|
||||||
|
|
||||||
def _assert_protocol_is_compatible(self, instruction_name): ...
|
|
||||||
def _handshake(self): ...
|
|
||||||
def _find_single_motor(self, motor, initial_baudrate): ...
|
|
||||||
def configure_motors(self): ...
|
|
||||||
def read_calibration(self): ...
|
|
||||||
def write_calibration(self, calibration_dict): ...
|
|
||||||
def disable_torque(self, motors, num_retry): ...
|
|
||||||
def _disable_torque(self, motor, model, num_retry): ...
|
|
||||||
def enable_torque(self, motors, num_retry): ...
|
|
||||||
def _get_half_turn_homings(self, positions): ...
|
|
||||||
def _encode_sign(self, data_name, ids_values): ...
|
|
||||||
def _decode_sign(self, data_name, ids_values): ...
|
|
||||||
def _split_into_byte_chunks(self, value, length): ...
|
|
||||||
def broadcast_ping(self, num_retry, raise_on_error): ...
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
95
tests/robots/test_so100_follower.py
Normal file
95
tests/robots/test_so100_follower.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.common.robots.so100_follower import (
|
||||||
|
SO100Follower,
|
||||||
|
SO100FollowerConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_bus_mock() -> MagicMock:
|
||||||
|
"""Return a bus mock with just the attributes used by the robot."""
|
||||||
|
bus = MagicMock(name="FeetechBusMock")
|
||||||
|
bus.is_connected = False
|
||||||
|
|
||||||
|
def _connect():
|
||||||
|
bus.is_connected = True
|
||||||
|
|
||||||
|
def _disconnect(_disable=True):
|
||||||
|
bus.is_connected = False
|
||||||
|
|
||||||
|
bus.connect.side_effect = _connect
|
||||||
|
bus.disconnect.side_effect = _disconnect
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _dummy_cm():
|
||||||
|
yield
|
||||||
|
|
||||||
|
bus.torque_disabled.side_effect = _dummy_cm
|
||||||
|
|
||||||
|
return bus
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def follower():
|
||||||
|
bus_mock = _make_bus_mock()
|
||||||
|
|
||||||
|
def _bus_side_effect(*_args, **kwargs):
|
||||||
|
bus_mock.motors = kwargs["motors"]
|
||||||
|
motors_order: list[str] = list(bus_mock.motors)
|
||||||
|
|
||||||
|
bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)}
|
||||||
|
bus_mock.sync_write.return_value = None
|
||||||
|
bus_mock.write.return_value = None
|
||||||
|
bus_mock.disable_torque.return_value = None
|
||||||
|
bus_mock.enable_torque.return_value = None
|
||||||
|
bus_mock.is_calibrated = True
|
||||||
|
return bus_mock
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
|
||||||
|
side_effect=_bus_side_effect,
|
||||||
|
),
|
||||||
|
patch.object(SO100Follower, "configure", lambda self: None),
|
||||||
|
):
|
||||||
|
cfg = SO100FollowerConfig(port="/dev/null")
|
||||||
|
robot = SO100Follower(cfg)
|
||||||
|
yield robot
|
||||||
|
if robot.is_connected:
|
||||||
|
robot.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_disconnect(follower):
|
||||||
|
assert not follower.is_connected
|
||||||
|
|
||||||
|
follower.connect()
|
||||||
|
assert follower.is_connected
|
||||||
|
|
||||||
|
follower.disconnect()
|
||||||
|
assert not follower.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_observation(follower):
|
||||||
|
follower.connect()
|
||||||
|
obs = follower.get_observation()
|
||||||
|
|
||||||
|
expected_keys = {f"{m}.pos" for m in follower.bus.motors}
|
||||||
|
assert set(obs.keys()) == expected_keys
|
||||||
|
|
||||||
|
for idx, motor in enumerate(follower.bus.motors, 1):
|
||||||
|
assert obs[f"{motor}.pos"] == idx
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_action(follower):
|
||||||
|
follower.connect()
|
||||||
|
|
||||||
|
action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)}
|
||||||
|
returned = follower.send_action(action)
|
||||||
|
|
||||||
|
assert returned == action
|
||||||
|
|
||||||
|
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)}
|
||||||
|
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)
|
||||||
Reference in New Issue
Block a user