Merge branch 'main' into user/adil-zouitine/2025-1-7-port-hil-serl-new
This commit is contained in:
3
tests/artifacts/cameras/image_128x128.png
Normal file
3
tests/artifacts/cameras/image_128x128.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9dc9df05797dc0e7b92edc845caab2e4c37c3cfcabb4ee6339c67212b5baba3b
|
||||
size 38023
|
||||
3
tests/artifacts/cameras/image_160x120.png
Normal file
3
tests/artifacts/cameras/image_160x120.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7e11af87616b83c1cdb30330e951b91e86b51c64a1326e1ba5b4a3fbcdec1a11
|
||||
size 55698
|
||||
3
tests/artifacts/cameras/image_320x180.png
Normal file
3
tests/artifacts/cameras/image_320x180.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b8840fb643afe903191248703b1f95a57faf5812ecd9978ac502ee939646fdb2
|
||||
size 121115
|
||||
3
tests/artifacts/cameras/image_480x270.png
Normal file
3
tests/artifacts/cameras/image_480x270.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f79d14daafb1c0cf2fec5d46ee8029a73fe357402fdd31a7cd4a4794d7319a7c
|
||||
size 260367
|
||||
3
tests/artifacts/cameras/test_rs.bag
Normal file
3
tests/artifacts/cameras/test_rs.bag
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a8d6e64d6cb0e02c94ae125630ee758055bd2e695772c0463a30d63ddc6c5e17
|
||||
size 3520862
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0389a716d51c1c615fb2a3bfa386d89f00b0deca08c4fa21b23e020a939d0213
|
||||
oid sha256:6b1e600768a8771c5fe650e038a1193597e3810f032041b2a0d021e4496381c1
|
||||
size 3686488
|
||||
|
||||
@@ -28,7 +28,7 @@ from lerobot.common.datasets.transforms import (
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
|
||||
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
DATASET_REPO_ID = "lerobot/aloha_static_cups_open"
|
||||
|
||||
|
||||
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0dc691503e7d90b2086bb408e89a65f772ce5ee6e3562ef8c127bcb09bd90851
|
||||
oid sha256:9d4ebab73eabddc58879a4e770289d19e00a1a4cf2fa5fa33cd3a3246992bc90
|
||||
size 40551392
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
|
||||
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
|
||||
size 5104
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
|
||||
oid sha256:1a7a8b1a457149109f843c32bcbb047d09de2201847b9b79f7501b447f77ecf4
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
|
||||
oid sha256:5e6ce85296b2009e7c2060d336c0429b1c7197d9adb159e7df0ba18003067b36
|
||||
size 68
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
|
||||
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
|
||||
size 33400
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
|
||||
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
|
||||
size 515400
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
|
||||
oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
|
||||
oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd
|
||||
size 68
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
|
||||
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
|
||||
size 33400
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import cache
|
||||
|
||||
import numpy as np
|
||||
|
||||
CAP_V4L2 = 200
|
||||
CAP_DSHOW = 700
|
||||
CAP_AVFOUNDATION = 1200
|
||||
CAP_ANY = -1
|
||||
|
||||
CAP_PROP_FPS = 5
|
||||
CAP_PROP_FRAME_WIDTH = 3
|
||||
CAP_PROP_FRAME_HEIGHT = 4
|
||||
COLOR_RGB2BGR = 4
|
||||
COLOR_BGR2RGB = 4
|
||||
|
||||
ROTATE_90_COUNTERCLOCKWISE = 2
|
||||
ROTATE_90_CLOCKWISE = 0
|
||||
ROTATE_180 = 1
|
||||
|
||||
|
||||
@cache
|
||||
def _generate_image(width: int, height: int):
|
||||
return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def cvtColor(color_image, color_conversion): # noqa: N802
|
||||
if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]:
|
||||
return color_image[:, :, [2, 1, 0]]
|
||||
else:
|
||||
raise NotImplementedError(color_conversion)
|
||||
|
||||
|
||||
def rotate(color_image, rotation):
|
||||
if rotation is None:
|
||||
return color_image
|
||||
elif rotation == ROTATE_90_CLOCKWISE:
|
||||
return np.rot90(color_image, k=1)
|
||||
elif rotation == ROTATE_180:
|
||||
return np.rot90(color_image, k=2)
|
||||
elif rotation == ROTATE_90_COUNTERCLOCKWISE:
|
||||
return np.rot90(color_image, k=3)
|
||||
else:
|
||||
raise NotImplementedError(rotation)
|
||||
|
||||
|
||||
class VideoCapture:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._mock_dict = {
|
||||
CAP_PROP_FPS: 30,
|
||||
CAP_PROP_FRAME_WIDTH: 640,
|
||||
CAP_PROP_FRAME_HEIGHT: 480,
|
||||
}
|
||||
self._is_opened = True
|
||||
|
||||
def isOpened(self): # noqa: N802
|
||||
return self._is_opened
|
||||
|
||||
def set(self, propId: int, value: float) -> bool: # noqa: N803
|
||||
if not self._is_opened:
|
||||
raise RuntimeError("Camera is not opened")
|
||||
self._mock_dict[propId] = value
|
||||
return True
|
||||
|
||||
def get(self, propId: int) -> float: # noqa: N803
|
||||
if not self._is_opened:
|
||||
raise RuntimeError("Camera is not opened")
|
||||
value = self._mock_dict[propId]
|
||||
if value == 0:
|
||||
if propId == CAP_PROP_FRAME_HEIGHT:
|
||||
value = 480
|
||||
elif propId == CAP_PROP_FRAME_WIDTH:
|
||||
value = 640
|
||||
return value
|
||||
|
||||
def read(self):
|
||||
if not self._is_opened:
|
||||
raise RuntimeError("Camera is not opened")
|
||||
h = self.get(CAP_PROP_FRAME_HEIGHT)
|
||||
w = self.get(CAP_PROP_FRAME_WIDTH)
|
||||
ret = True
|
||||
return ret, _generate_image(width=w, height=h)
|
||||
|
||||
def release(self):
|
||||
self._is_opened = False
|
||||
|
||||
def __del__(self):
|
||||
if self._is_opened:
|
||||
self.release()
|
||||
@@ -1,148 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import enum
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class stream(enum.Enum): # noqa: N801
|
||||
color = 0
|
||||
depth = 1
|
||||
|
||||
|
||||
class format(enum.Enum): # noqa: N801
|
||||
rgb8 = 0
|
||||
z16 = 1
|
||||
|
||||
|
||||
class config: # noqa: N801
|
||||
def enable_device(self, device_id: str):
|
||||
self.device_enabled = device_id
|
||||
|
||||
def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
|
||||
self.stream_type = stream_type
|
||||
# Overwrite default values when possible
|
||||
self.width = 848 if width is None else width
|
||||
self.height = 480 if height is None else height
|
||||
self.color_format = format.rgb8 if color_format is None else color_format
|
||||
self.fps = 30 if fps is None else fps
|
||||
|
||||
|
||||
class RSColorProfile:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def fps(self):
|
||||
return self.config.fps
|
||||
|
||||
def width(self):
|
||||
return self.config.width
|
||||
|
||||
def height(self):
|
||||
return self.config.height
|
||||
|
||||
|
||||
class RSColorStream:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def as_video_stream_profile(self):
|
||||
return RSColorProfile(self.config)
|
||||
|
||||
|
||||
class RSProfile:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_stream(self, color_format):
|
||||
del color_format # unused
|
||||
return RSColorStream(self.config)
|
||||
|
||||
|
||||
class pipeline: # noqa: N801
|
||||
def __init__(self):
|
||||
self.started = False
|
||||
self.config = None
|
||||
|
||||
def start(self, config):
|
||||
self.started = True
|
||||
self.config = config
|
||||
return RSProfile(self.config)
|
||||
|
||||
def stop(self):
|
||||
if not self.started:
|
||||
raise RuntimeError("You need to start the camera before stop.")
|
||||
self.started = False
|
||||
self.config = None
|
||||
|
||||
def wait_for_frames(self, timeout_ms=50000):
|
||||
del timeout_ms # unused
|
||||
return RSFrames(self.config)
|
||||
|
||||
|
||||
class RSFrames:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_color_frame(self):
|
||||
return RSColorFrame(self.config)
|
||||
|
||||
def get_depth_frame(self):
|
||||
return RSDepthFrame(self.config)
|
||||
|
||||
|
||||
class RSColorFrame:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_data(self):
|
||||
data = np.ones((self.config.height, self.config.width, 3), dtype=np.uint8)
|
||||
# Create a difference between rgb and bgr
|
||||
data[:, :, 0] = 2
|
||||
return data
|
||||
|
||||
|
||||
class RSDepthFrame:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_data(self):
|
||||
return np.ones((self.config.height, self.config.width), dtype=np.uint16)
|
||||
|
||||
|
||||
class RSDevice:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_info(self, camera_info) -> str:
|
||||
del camera_info # unused
|
||||
# return fake serial number
|
||||
return "123456789"
|
||||
|
||||
|
||||
class context: # noqa: N801
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def query_devices(self):
|
||||
return [RSDevice()]
|
||||
|
||||
|
||||
class camera_info: # noqa: N801
|
||||
# fake name
|
||||
name = "Intel RealSense D435I"
|
||||
|
||||
def __init__(self, serial_number):
|
||||
del serial_number
|
||||
pass
|
||||
@@ -1,275 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for physical cameras and their mocked versions.
|
||||
If the physical camera is not connected to the computer, or not working,
|
||||
the test will be skipped.
|
||||
|
||||
Example of running a specific test:
|
||||
```bash
|
||||
pytest -sx tests/test_cameras.py::test_camera
|
||||
```
|
||||
|
||||
Example of running test on a real camera connected to the computer:
|
||||
```bash
|
||||
pytest -sx 'tests/test_cameras.py::test_camera[opencv-False]'
|
||||
pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-False]'
|
||||
```
|
||||
|
||||
Example of running test on a mocked version of the camera:
|
||||
```bash
|
||||
pytest -sx 'tests/test_cameras.py::test_camera[opencv-True]'
|
||||
pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-True]'
|
||||
```
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera
|
||||
|
||||
# Maximum absolute difference between two consecutive images recorded by a camera.
|
||||
# This value differs with respect to the camera.
|
||||
MAX_PIXEL_DIFFERENCE = 25
|
||||
|
||||
|
||||
def compute_max_pixel_difference(first_image, second_image):
|
||||
return np.abs(first_image.astype(float) - second_image.astype(float)).max()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
|
||||
@require_camera
|
||||
def test_camera(request, camera_type, mock):
|
||||
"""Test assumes that `camera.read()` returns the same image when called multiple times in a row.
|
||||
So the environment should not change (you shouldnt be in front of the camera) and the camera should not be moving.
|
||||
|
||||
Warning: The tests worked for a macbookpro camera, but I am getting assertion error (`np.allclose(color_image, async_color_image)`)
|
||||
for my iphone camera and my LG monitor camera.
|
||||
"""
|
||||
# TODO(rcadene): measure fps in nightly?
|
||||
# TODO(rcadene): test logs
|
||||
|
||||
if camera_type == "opencv" and not mock:
|
||||
pytest.skip("TODO(rcadene): fix test for opencv physical camera")
|
||||
|
||||
camera_kwargs = {"camera_type": camera_type, "mock": mock}
|
||||
|
||||
# Test instantiating
|
||||
camera = make_camera(**camera_kwargs)
|
||||
|
||||
# Test reading, async reading, disconnecting before connecting raises an error
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
camera.read()
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
camera.async_read()
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
camera.disconnect()
|
||||
|
||||
# Test deleting the object without connecting first
|
||||
del camera
|
||||
|
||||
# Test connecting
|
||||
camera = make_camera(**camera_kwargs)
|
||||
camera.connect()
|
||||
assert camera.is_connected
|
||||
assert camera.fps is not None
|
||||
assert camera.capture_width is not None
|
||||
assert camera.capture_height is not None
|
||||
|
||||
# Test connecting twice raises an error
|
||||
with pytest.raises(RobotDeviceAlreadyConnectedError):
|
||||
camera.connect()
|
||||
|
||||
# Test reading from the camera
|
||||
color_image = camera.read()
|
||||
assert isinstance(color_image, np.ndarray)
|
||||
assert color_image.ndim == 3
|
||||
h, w, c = color_image.shape
|
||||
assert c == 3
|
||||
assert w > h
|
||||
|
||||
# Test read and async_read outputs similar images
|
||||
# ...warming up as the first frames can be black
|
||||
for _ in range(30):
|
||||
camera.read()
|
||||
color_image = camera.read()
|
||||
async_color_image = camera.async_read()
|
||||
error_msg = (
|
||||
"max_pixel_difference between read() and async_read()",
|
||||
compute_max_pixel_difference(color_image, async_color_image),
|
||||
)
|
||||
# TODO(rcadene): properly set `rtol`
|
||||
np.testing.assert_allclose(
|
||||
color_image,
|
||||
async_color_image,
|
||||
rtol=1e-5,
|
||||
atol=MAX_PIXEL_DIFFERENCE,
|
||||
err_msg=error_msg,
|
||||
)
|
||||
|
||||
# Test disconnecting
|
||||
camera.disconnect()
|
||||
assert camera.camera is None
|
||||
assert camera.thread is None
|
||||
|
||||
# Test disconnecting with `__del__`
|
||||
camera = make_camera(**camera_kwargs)
|
||||
camera.connect()
|
||||
del camera
|
||||
|
||||
# Test acquiring a bgr image
|
||||
camera = make_camera(**camera_kwargs, color_mode="bgr")
|
||||
camera.connect()
|
||||
assert camera.color_mode == "bgr"
|
||||
bgr_color_image = camera.read()
|
||||
np.testing.assert_allclose(
|
||||
color_image,
|
||||
bgr_color_image[:, :, [2, 1, 0]],
|
||||
rtol=1e-5,
|
||||
atol=MAX_PIXEL_DIFFERENCE,
|
||||
err_msg=error_msg,
|
||||
)
|
||||
del camera
|
||||
|
||||
# Test acquiring a rotated image
|
||||
camera = make_camera(**camera_kwargs)
|
||||
camera.connect()
|
||||
ori_color_image = camera.read()
|
||||
del camera
|
||||
|
||||
for rotation in [None, 90, 180, -90]:
|
||||
camera = make_camera(**camera_kwargs, rotation=rotation)
|
||||
camera.connect()
|
||||
|
||||
if mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
if rotation is None:
|
||||
manual_rot_img = ori_color_image
|
||||
assert camera.rotation is None
|
||||
elif rotation == 90:
|
||||
manual_rot_img = np.rot90(color_image, k=1)
|
||||
assert camera.rotation == cv2.ROTATE_90_CLOCKWISE
|
||||
elif rotation == 180:
|
||||
manual_rot_img = np.rot90(color_image, k=2)
|
||||
assert camera.rotation == cv2.ROTATE_180
|
||||
elif rotation == -90:
|
||||
manual_rot_img = np.rot90(color_image, k=3)
|
||||
assert camera.rotation == cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
|
||||
rot_color_image = camera.read()
|
||||
|
||||
np.testing.assert_allclose(
|
||||
rot_color_image,
|
||||
manual_rot_img,
|
||||
rtol=1e-5,
|
||||
atol=MAX_PIXEL_DIFFERENCE,
|
||||
err_msg=error_msg,
|
||||
)
|
||||
del camera
|
||||
|
||||
# TODO(rcadene): Add a test for a camera that doesnt support fps=60 and raises an OSError
|
||||
# TODO(rcadene): Add a test for a camera that supports fps=60
|
||||
|
||||
# Test width and height can be set
|
||||
camera = make_camera(**camera_kwargs, fps=30, width=1280, height=720)
|
||||
camera.connect()
|
||||
assert camera.fps == 30
|
||||
assert camera.width == 1280
|
||||
assert camera.height == 720
|
||||
color_image = camera.read()
|
||||
h, w, c = color_image.shape
|
||||
assert h == 720
|
||||
assert w == 1280
|
||||
assert c == 3
|
||||
del camera
|
||||
|
||||
# Test not supported width and height raise an error
|
||||
camera = make_camera(**camera_kwargs, fps=30, width=0, height=0)
|
||||
with pytest.raises(OSError):
|
||||
camera.connect()
|
||||
del camera
|
||||
|
||||
|
||||
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
|
||||
@require_camera
|
||||
def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
|
||||
# TODO(rcadene): refactor
|
||||
if camera_type == "opencv":
|
||||
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
|
||||
elif camera_type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import (
|
||||
save_images_from_cameras,
|
||||
)
|
||||
|
||||
# Small `record_time_s` to speedup unit tests
|
||||
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
|
||||
@require_camera
|
||||
def test_camera_rotation(request, camera_type, mock):
|
||||
config_kwargs = {
|
||||
"camera_type": camera_type,
|
||||
"mock": mock,
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
"fps": 30,
|
||||
}
|
||||
|
||||
# No rotation.
|
||||
camera = make_camera(**config_kwargs, rotation=None)
|
||||
camera.connect()
|
||||
assert camera.capture_width == 640
|
||||
assert camera.capture_height == 480
|
||||
assert camera.width == 640
|
||||
assert camera.height == 480
|
||||
no_rot_img = camera.read()
|
||||
h, w, c = no_rot_img.shape
|
||||
assert h == 480 and w == 640 and c == 3
|
||||
camera.disconnect()
|
||||
|
||||
# Rotation = 90 (clockwise).
|
||||
camera = make_camera(**config_kwargs, rotation=90)
|
||||
camera.connect()
|
||||
# With a 90° rotation, we expect the metadata dimensions to be swapped.
|
||||
assert camera.capture_width == 640
|
||||
assert camera.capture_height == 480
|
||||
assert camera.width == 480
|
||||
assert camera.height == 640
|
||||
import cv2
|
||||
|
||||
assert camera.rotation == cv2.ROTATE_90_CLOCKWISE
|
||||
rot_img = camera.read()
|
||||
h, w, c = rot_img.shape
|
||||
assert h == 640 and w == 480 and c == 3
|
||||
camera.disconnect()
|
||||
|
||||
# Rotation = 180.
|
||||
camera = make_camera(**config_kwargs, rotation=None)
|
||||
camera.connect()
|
||||
assert camera.capture_width == 640
|
||||
assert camera.capture_height == 480
|
||||
assert camera.width == 640
|
||||
assert camera.height == 480
|
||||
no_rot_img = camera.read()
|
||||
h, w, c = no_rot_img.shape
|
||||
assert h == 480 and w == 640 and c == 3
|
||||
camera.disconnect()
|
||||
190
tests/cameras/test_opencv.py
Normal file
190
tests/cameras/test_opencv.py
Normal file
@@ -0,0 +1,190 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Example of running a specific test:
|
||||
# ```bash
|
||||
# pytest tests/cameras/test_opencv.py::test_connect
|
||||
# ```
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.cameras.configs import Cv2Rotation
|
||||
from lerobot.common.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
# NOTE(Steven): more tests + assertions?
|
||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras"
|
||||
DEFAULT_PNG_FILE_PATH = TEST_ARTIFACTS_DIR / "image_160x120.png"
|
||||
TEST_IMAGE_SIZES = ["128x128", "160x120", "320x180", "480x270"]
|
||||
TEST_IMAGE_PATHS = [TEST_ARTIFACTS_DIR / f"image_{size}.png" for size in TEST_IMAGE_SIZES]
|
||||
|
||||
|
||||
def test_abc_implementation():
|
||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
||||
config = OpenCVCameraConfig(index_or_path=0)
|
||||
|
||||
_ = OpenCVCamera(config)
|
||||
|
||||
|
||||
def test_connect():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
|
||||
camera.connect(warmup=False)
|
||||
|
||||
assert camera.is_connected
|
||||
|
||||
|
||||
def test_connect_already_connected():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
with pytest.raises(DeviceAlreadyConnectedError):
|
||||
camera.connect(warmup=False)
|
||||
|
||||
|
||||
def test_connect_invalid_camera_path():
|
||||
config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png")
|
||||
camera = OpenCVCamera(config)
|
||||
|
||||
with pytest.raises(ConnectionError):
|
||||
camera.connect(warmup=False)
|
||||
|
||||
|
||||
def test_invalid_width_connect():
|
||||
config = OpenCVCameraConfig(
|
||||
index_or_path=DEFAULT_PNG_FILE_PATH,
|
||||
width=99999, # Invalid width to trigger error
|
||||
height=480,
|
||||
)
|
||||
camera = OpenCVCamera(config)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
camera.connect(warmup=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
|
||||
def test_read(index_or_path):
|
||||
config = OpenCVCameraConfig(index_or_path=index_or_path)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
img = camera.read()
|
||||
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
|
||||
def test_read_before_connect():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.read()
|
||||
|
||||
|
||||
def test_disconnect():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
camera.disconnect()
|
||||
|
||||
assert not camera.is_connected
|
||||
|
||||
|
||||
def test_disconnect_before_connect():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
|
||||
def test_async_read(index_or_path):
|
||||
config = OpenCVCameraConfig(index_or_path=index_or_path)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
try:
|
||||
img = camera.async_read()
|
||||
|
||||
assert camera.thread is not None
|
||||
assert camera.thread.is_alive()
|
||||
assert isinstance(img, np.ndarray)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends
|
||||
|
||||
|
||||
def test_async_read_timeout():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
try:
|
||||
with pytest.raises(TimeoutError):
|
||||
camera.async_read(
|
||||
timeout_ms=0
|
||||
) # NOTE(Steven): This is flaky as sdometimes we actually get a frame
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_async_read_before_connect():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.async_read()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
|
||||
@pytest.mark.parametrize(
|
||||
"rotation",
|
||||
[
|
||||
Cv2Rotation.NO_ROTATION,
|
||||
Cv2Rotation.ROTATE_90,
|
||||
Cv2Rotation.ROTATE_180,
|
||||
Cv2Rotation.ROTATE_270,
|
||||
],
|
||||
ids=["no_rot", "rot90", "rot180", "rot270"],
|
||||
)
|
||||
def test_rotation(rotation, index_or_path):
|
||||
filename = Path(index_or_path).name
|
||||
dimensions = filename.split("_")[-1].split(".")[0] # Assumes filenames format (_wxh.png)
|
||||
original_width, original_height = map(int, dimensions.split("x"))
|
||||
|
||||
config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
|
||||
assert camera.width == original_height
|
||||
assert camera.height == original_width
|
||||
assert img.shape[:2] == (original_width, original_height)
|
||||
else:
|
||||
assert camera.width == original_width
|
||||
assert camera.height == original_height
|
||||
assert img.shape[:2] == (original_height, original_width)
|
||||
206
tests/cameras/test_realsense.py
Normal file
206
tests/cameras/test_realsense.py
Normal file
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Example of running a specific test:
|
||||
# ```bash
|
||||
# pytest tests/cameras/test_opencv.py::test_connect
|
||||
# ```
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.cameras.configs import Cv2Rotation
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
pytest.importorskip("pyrealsense2")
|
||||
|
||||
from lerobot.common.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
|
||||
|
||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras"
|
||||
BAG_FILE_PATH = TEST_ARTIFACTS_DIR / "test_rs.bag"
|
||||
|
||||
# NOTE(Steven): For some reason these tests take ~20sec in macOS but only ~2sec in Linux.
|
||||
|
||||
|
||||
def mock_rs_config_enable_device_from_file(rs_config_instance, _sn):
|
||||
return rs_config_instance.enable_device_from_file(str(BAG_FILE_PATH), repeat_playback=True)
|
||||
|
||||
|
||||
def mock_rs_config_enable_device_bad_file(rs_config_instance, _sn):
|
||||
return rs_config_instance.enable_device_from_file("non_existent_file.bag", repeat_playback=True)
|
||||
|
||||
|
||||
@pytest.fixture(name="patch_realsense", autouse=True)
|
||||
def fixture_patch_realsense():
|
||||
"""Automatically mock pyrealsense2.config.enable_device for all tests."""
|
||||
with patch(
|
||||
"pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def test_abc_implementation():
|
||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
_ = RealSenseCamera(config)
|
||||
|
||||
|
||||
def test_connect():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
camera = RealSenseCamera(config)
|
||||
|
||||
camera.connect(warmup=False)
|
||||
assert camera.is_connected
|
||||
|
||||
|
||||
def test_connect_already_connected():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
with pytest.raises(DeviceAlreadyConnectedError):
|
||||
camera.connect(warmup=False)
|
||||
|
||||
|
||||
def test_connect_invalid_camera_path(patch_realsense):
|
||||
patch_realsense.side_effect = mock_rs_config_enable_device_bad_file
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
camera = RealSenseCamera(config)
|
||||
|
||||
with pytest.raises(ConnectionError):
|
||||
camera.connect(warmup=False)
|
||||
|
||||
|
||||
def test_invalid_width_connect():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=99999, height=480, fps=30)
|
||||
camera = RealSenseCamera(config)
|
||||
|
||||
with pytest.raises(ConnectionError):
|
||||
camera.connect(warmup=False)
|
||||
|
||||
|
||||
def test_read():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
|
||||
def test_read_depth():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, use_depth=True)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
img = camera.read_depth(timeout_ms=1000) # NOTE(Steven): Reading depth takes longer
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
|
||||
def test_read_before_connect():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
camera = RealSenseCamera(config)
|
||||
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.read()
|
||||
|
||||
|
||||
def test_disconnect():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
camera.disconnect()
|
||||
|
||||
assert not camera.is_connected
|
||||
|
||||
|
||||
def test_disconnect_before_connect():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
camera = RealSenseCamera(config)
|
||||
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_async_read():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
try:
|
||||
img = camera.async_read()
|
||||
|
||||
assert camera.thread is not None
|
||||
assert camera.thread.is_alive()
|
||||
assert isinstance(img, np.ndarray)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends
|
||||
|
||||
|
||||
def test_async_read_timeout():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
try:
|
||||
with pytest.raises(TimeoutError):
|
||||
camera.async_read(
|
||||
timeout_ms=0
|
||||
) # NOTE(Steven): This is flaky as sdometimes we actually get a frame
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_async_read_before_connect():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
camera = RealSenseCamera(config)
|
||||
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.async_read()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"rotation",
|
||||
[
|
||||
Cv2Rotation.NO_ROTATION,
|
||||
Cv2Rotation.ROTATE_90,
|
||||
Cv2Rotation.ROTATE_180,
|
||||
Cv2Rotation.ROTATE_270,
|
||||
],
|
||||
ids=["no_rot", "rot90", "rot180", "rot270"],
|
||||
)
|
||||
def test_rotation(rotation):
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
|
||||
assert camera.width == 480
|
||||
assert camera.height == 640
|
||||
assert img.shape[:2] == (640, 480)
|
||||
else:
|
||||
assert camera.width == 640
|
||||
assert camera.height == 480
|
||||
assert img.shape[:2] == (480, 640)
|
||||
@@ -21,9 +21,8 @@ import pytest
|
||||
import torch
|
||||
from serial import SerialException
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from tests.utils import DEVICE, make_camera, make_motors_bus
|
||||
from lerobot import available_cameras
|
||||
from tests.utils import DEVICE, make_camera
|
||||
|
||||
# Import fixture modules as plugins
|
||||
pytest_plugins = [
|
||||
@@ -66,21 +65,11 @@ def _check_component_availability(component_type, available_components, make_com
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_robot_available(robot_type):
|
||||
return _check_component_availability(robot_type, available_robots, make_robot)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_camera_available(camera_type):
|
||||
return _check_component_availability(camera_type, available_cameras, make_camera)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_motor_available(motor_type):
|
||||
return _check_component_availability(motor_type, available_motors, make_motors_bus)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_builtins_input(monkeypatch):
|
||||
def print_text(text=None):
|
||||
|
||||
@@ -72,7 +72,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
# Instantiate both ways
|
||||
robot = make_robot("koch", mock=True)
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
dataset_create = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
|
||||
)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
@@ -127,7 +129,9 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||
ValueError,
|
||||
match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n",
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
|
||||
dataset.add_frame(
|
||||
{"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -137,7 +141,9 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
||||
ValueError,
|
||||
match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n",
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
|
||||
dataset.add_frame(
|
||||
{"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -145,7 +151,9 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
||||
match=re.escape(
|
||||
"The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"
|
||||
),
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
||||
|
||||
@@ -167,7 +175,9 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
||||
match=re.escape(
|
||||
"The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"
|
||||
),
|
||||
):
|
||||
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
|
||||
|
||||
@@ -461,7 +471,9 @@ def test_flatten_unflatten_dict():
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), (
|
||||
f"{original_d} != {d}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -515,7 +527,13 @@ def test_backward_compatibility(repo_id):
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
i = int(
|
||||
(
|
||||
dataset.episode_data_index["to"][0].item()
|
||||
- dataset.episode_data_index["from"][0].item()
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
@@ -257,7 +258,14 @@ def test_backward_compatibility_single_transforms(
|
||||
|
||||
|
||||
@require_x86_64_kernel
|
||||
@pytest.mark.skipif(
|
||||
version.parse(torch.__version__) < version.parse("2.7.0"),
|
||||
reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior",
|
||||
)
|
||||
def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
||||
# NOTE: PyTorch versions have different randomness, it might break this test.
|
||||
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
|
||||
|
||||
cfg = ImageTransformsConfig(enable=True)
|
||||
default_tf = ImageTransforms(cfg)
|
||||
|
||||
|
||||
580
tests/mocks/mock_dynamixel.py
Normal file
580
tests/mocks/mock_dynamixel.py
Normal file
@@ -0,0 +1,580 @@
|
||||
import abc
|
||||
from typing import Callable
|
||||
|
||||
import dynamixel_sdk as dxl
|
||||
import serial
|
||||
from mock_serial.mock_serial import MockSerial
|
||||
|
||||
from lerobot.common.motors.dynamixel.dynamixel import _split_into_byte_chunks
|
||||
|
||||
from .mock_serial_patch import WaitableStub
|
||||
|
||||
# https://emanual.robotis.com/docs/en/dxl/crc/
|
||||
DXL_CRC_TABLE = [
|
||||
0x0000, 0x8005, 0x800F, 0x000A, 0x801B, 0x001E, 0x0014, 0x8011,
|
||||
0x8033, 0x0036, 0x003C, 0x8039, 0x0028, 0x802D, 0x8027, 0x0022,
|
||||
0x8063, 0x0066, 0x006C, 0x8069, 0x0078, 0x807D, 0x8077, 0x0072,
|
||||
0x0050, 0x8055, 0x805F, 0x005A, 0x804B, 0x004E, 0x0044, 0x8041,
|
||||
0x80C3, 0x00C6, 0x00CC, 0x80C9, 0x00D8, 0x80DD, 0x80D7, 0x00D2,
|
||||
0x00F0, 0x80F5, 0x80FF, 0x00FA, 0x80EB, 0x00EE, 0x00E4, 0x80E1,
|
||||
0x00A0, 0x80A5, 0x80AF, 0x00AA, 0x80BB, 0x00BE, 0x00B4, 0x80B1,
|
||||
0x8093, 0x0096, 0x009C, 0x8099, 0x0088, 0x808D, 0x8087, 0x0082,
|
||||
0x8183, 0x0186, 0x018C, 0x8189, 0x0198, 0x819D, 0x8197, 0x0192,
|
||||
0x01B0, 0x81B5, 0x81BF, 0x01BA, 0x81AB, 0x01AE, 0x01A4, 0x81A1,
|
||||
0x01E0, 0x81E5, 0x81EF, 0x01EA, 0x81FB, 0x01FE, 0x01F4, 0x81F1,
|
||||
0x81D3, 0x01D6, 0x01DC, 0x81D9, 0x01C8, 0x81CD, 0x81C7, 0x01C2,
|
||||
0x0140, 0x8145, 0x814F, 0x014A, 0x815B, 0x015E, 0x0154, 0x8151,
|
||||
0x8173, 0x0176, 0x017C, 0x8179, 0x0168, 0x816D, 0x8167, 0x0162,
|
||||
0x8123, 0x0126, 0x012C, 0x8129, 0x0138, 0x813D, 0x8137, 0x0132,
|
||||
0x0110, 0x8115, 0x811F, 0x011A, 0x810B, 0x010E, 0x0104, 0x8101,
|
||||
0x8303, 0x0306, 0x030C, 0x8309, 0x0318, 0x831D, 0x8317, 0x0312,
|
||||
0x0330, 0x8335, 0x833F, 0x033A, 0x832B, 0x032E, 0x0324, 0x8321,
|
||||
0x0360, 0x8365, 0x836F, 0x036A, 0x837B, 0x037E, 0x0374, 0x8371,
|
||||
0x8353, 0x0356, 0x035C, 0x8359, 0x0348, 0x834D, 0x8347, 0x0342,
|
||||
0x03C0, 0x83C5, 0x83CF, 0x03CA, 0x83DB, 0x03DE, 0x03D4, 0x83D1,
|
||||
0x83F3, 0x03F6, 0x03FC, 0x83F9, 0x03E8, 0x83ED, 0x83E7, 0x03E2,
|
||||
0x83A3, 0x03A6, 0x03AC, 0x83A9, 0x03B8, 0x83BD, 0x83B7, 0x03B2,
|
||||
0x0390, 0x8395, 0x839F, 0x039A, 0x838B, 0x038E, 0x0384, 0x8381,
|
||||
0x0280, 0x8285, 0x828F, 0x028A, 0x829B, 0x029E, 0x0294, 0x8291,
|
||||
0x82B3, 0x02B6, 0x02BC, 0x82B9, 0x02A8, 0x82AD, 0x82A7, 0x02A2,
|
||||
0x82E3, 0x02E6, 0x02EC, 0x82E9, 0x02F8, 0x82FD, 0x82F7, 0x02F2,
|
||||
0x02D0, 0x82D5, 0x82DF, 0x02DA, 0x82CB, 0x02CE, 0x02C4, 0x82C1,
|
||||
0x8243, 0x0246, 0x024C, 0x8249, 0x0258, 0x825D, 0x8257, 0x0252,
|
||||
0x0270, 0x8275, 0x827F, 0x027A, 0x826B, 0x026E, 0x0264, 0x8261,
|
||||
0x0220, 0x8225, 0x822F, 0x022A, 0x823B, 0x023E, 0x0234, 0x8231,
|
||||
0x8213, 0x0216, 0x021C, 0x8219, 0x0208, 0x820D, 0x8207, 0x0202
|
||||
] # fmt: skip
|
||||
|
||||
|
||||
class MockDynamixelPacketv2(abc.ABC):
|
||||
@classmethod
|
||||
def build(cls, dxl_id: int, params: list[int], length: list[int], *args, **kwargs) -> bytes:
|
||||
packet = cls._build(dxl_id, params, length, *args, **kwargs)
|
||||
packet = cls._add_stuffing(packet)
|
||||
packet = cls._add_crc(packet)
|
||||
return bytes(packet)
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def _build(cls, dxl_id: int, params: list[int], length: int, *args, **kwargs) -> list[int]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _add_stuffing(packet: list[int]) -> list[int]:
|
||||
"""
|
||||
Byte stuffing is a method of adding additional data to generated instruction packets to ensure that
|
||||
the packets are processed successfully. When the byte pattern "0xFF 0xFF 0xFD" appears in a packet,
|
||||
byte stuffing adds 0xFD to the end of the pattern to convert it to “0xFF 0xFF 0xFD 0xFD” to ensure
|
||||
that it is not interpreted as the header at the start of another packet.
|
||||
|
||||
Source: https://emanual.robotis.com/docs/en/dxl/protocol2/#transmission-process
|
||||
|
||||
Args:
|
||||
packet (list[int]): The raw packet without stuffing.
|
||||
|
||||
Returns:
|
||||
list[int]: The packet stuffed if it contained a "0xFF 0xFF 0xFD" byte sequence in its data bytes.
|
||||
"""
|
||||
packet_length_in = dxl.DXL_MAKEWORD(packet[dxl.PKT_LENGTH_L], packet[dxl.PKT_LENGTH_H])
|
||||
packet_length_out = packet_length_in
|
||||
|
||||
temp = [0] * dxl.TXPACKET_MAX_LEN
|
||||
|
||||
# FF FF FD XX ID LEN_L LEN_H
|
||||
temp[dxl.PKT_HEADER0 : dxl.PKT_HEADER0 + dxl.PKT_LENGTH_H + 1] = packet[
|
||||
dxl.PKT_HEADER0 : dxl.PKT_HEADER0 + dxl.PKT_LENGTH_H + 1
|
||||
]
|
||||
|
||||
index = dxl.PKT_INSTRUCTION
|
||||
|
||||
for i in range(0, packet_length_in - 2): # except CRC
|
||||
temp[index] = packet[i + dxl.PKT_INSTRUCTION]
|
||||
index = index + 1
|
||||
if (
|
||||
packet[i + dxl.PKT_INSTRUCTION] == 0xFD
|
||||
and packet[i + dxl.PKT_INSTRUCTION - 1] == 0xFF
|
||||
and packet[i + dxl.PKT_INSTRUCTION - 2] == 0xFF
|
||||
):
|
||||
# FF FF FD
|
||||
temp[index] = 0xFD
|
||||
index = index + 1
|
||||
packet_length_out = packet_length_out + 1
|
||||
|
||||
temp[index] = packet[dxl.PKT_INSTRUCTION + packet_length_in - 2]
|
||||
temp[index + 1] = packet[dxl.PKT_INSTRUCTION + packet_length_in - 1]
|
||||
index = index + 2
|
||||
|
||||
if packet_length_in != packet_length_out:
|
||||
packet = [0] * index
|
||||
|
||||
packet[0:index] = temp[0:index]
|
||||
|
||||
packet[dxl.PKT_LENGTH_L] = dxl.DXL_LOBYTE(packet_length_out)
|
||||
packet[dxl.PKT_LENGTH_H] = dxl.DXL_HIBYTE(packet_length_out)
|
||||
|
||||
return packet
|
||||
|
||||
@staticmethod
|
||||
def _add_crc(packet: list[int]) -> list[int]:
|
||||
"""Computes and add CRC to the packet.
|
||||
|
||||
https://emanual.robotis.com/docs/en/dxl/crc/
|
||||
https://en.wikipedia.org/wiki/Cyclic_redundancy_check
|
||||
|
||||
Args:
|
||||
packet (list[int]): The raw packet without CRC (but with placeholders for it).
|
||||
|
||||
Returns:
|
||||
list[int]: The raw packet with a valid CRC.
|
||||
"""
|
||||
crc = 0
|
||||
for j in range(len(packet) - 2):
|
||||
i = ((crc >> 8) ^ packet[j]) & 0xFF
|
||||
crc = ((crc << 8) ^ DXL_CRC_TABLE[i]) & 0xFFFF
|
||||
|
||||
packet[-2] = dxl.DXL_LOBYTE(crc)
|
||||
packet[-1] = dxl.DXL_HIBYTE(crc)
|
||||
|
||||
return packet
|
||||
|
||||
|
||||
class MockInstructionPacket(MockDynamixelPacketv2):
|
||||
"""
|
||||
Helper class to build valid Dynamixel Protocol 2.0 Instruction Packets.
|
||||
|
||||
Protocol 2.0 Instruction Packet structure
|
||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#instruction-packet
|
||||
|
||||
| Header | Packet ID | Length | Instruction | Params | CRC |
|
||||
| ------------------- | --------- | ----------- | ----------- | ----------------- | ----------- |
|
||||
| 0xFF 0xFF 0xFD 0x00 | ID | Len_L Len_H | Instr | Param 1 … Param N | CRC_L CRC_H |
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _build(cls, dxl_id: int, params: list[int], length: int, instruction: int) -> list[int]:
|
||||
length = len(params) + 3
|
||||
return [
|
||||
0xFF, 0xFF, 0xFD, 0x00, # header
|
||||
dxl_id, # servo id
|
||||
dxl.DXL_LOBYTE(length), # length_l
|
||||
dxl.DXL_HIBYTE(length), # length_h
|
||||
instruction, # instruction type
|
||||
*params, # data bytes
|
||||
0x00, 0x00 # placeholder for CRC
|
||||
] # fmt: skip
|
||||
|
||||
@classmethod
|
||||
def ping(
|
||||
cls,
|
||||
dxl_id: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Ping" broadcast instruction.
|
||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01
|
||||
|
||||
No parameters required.
|
||||
"""
|
||||
return cls.build(dxl_id=dxl_id, params=[], length=3, instruction=dxl.INST_PING)
|
||||
|
||||
@classmethod
|
||||
def read(
|
||||
cls,
|
||||
dxl_id: int,
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Read" instruction.
|
||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
|
||||
|
||||
The parameters for Read (Protocol 2.0) are:
|
||||
param[0] = start_address L
|
||||
param[1] = start_address H
|
||||
param[2] = data_length L
|
||||
param[3] = data_length H
|
||||
|
||||
And 'length' = data_length + 5, where:
|
||||
+1 is for instruction byte,
|
||||
+2 is for the length bytes,
|
||||
+2 is for the CRC at the end.
|
||||
"""
|
||||
params = [
|
||||
dxl.DXL_LOBYTE(start_address),
|
||||
dxl.DXL_HIBYTE(start_address),
|
||||
dxl.DXL_LOBYTE(data_length),
|
||||
dxl.DXL_HIBYTE(data_length),
|
||||
]
|
||||
length = len(params) + 3
|
||||
# length = data_length + 5
|
||||
return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_READ)
|
||||
|
||||
@classmethod
|
||||
def write(
|
||||
cls,
|
||||
dxl_id: int,
|
||||
value: int,
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Write" instruction.
|
||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#write-0x03
|
||||
|
||||
The parameters for Write (Protocol 2.0) are:
|
||||
param[0] = start_address L
|
||||
param[1] = start_address H
|
||||
param[2] = 1st Byte
|
||||
param[3] = 2nd Byte
|
||||
...
|
||||
param[1+X] = X-th Byte
|
||||
|
||||
And 'length' = data_length + 5, where:
|
||||
+1 is for instruction byte,
|
||||
+2 is for the length bytes,
|
||||
+2 is for the CRC at the end.
|
||||
"""
|
||||
data = _split_into_byte_chunks(value, data_length)
|
||||
params = [
|
||||
dxl.DXL_LOBYTE(start_address),
|
||||
dxl.DXL_HIBYTE(start_address),
|
||||
*data,
|
||||
]
|
||||
length = data_length + 5
|
||||
return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_WRITE)
|
||||
|
||||
@classmethod
|
||||
def sync_read(
|
||||
cls,
|
||||
dxl_ids: list[int],
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Sync_Read" broadcast instruction.
|
||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-read-0x82
|
||||
|
||||
The parameters for Sync_Read (Protocol 2.0) are:
|
||||
param[0] = start_address L
|
||||
param[1] = start_address H
|
||||
param[2] = data_length L
|
||||
param[3] = data_length H
|
||||
param[4+] = motor IDs to read from
|
||||
|
||||
And 'length' = (number_of_params + 7), where:
|
||||
+1 is for instruction byte,
|
||||
+2 is for the address bytes,
|
||||
+2 is for the length bytes,
|
||||
+2 is for the CRC at the end.
|
||||
"""
|
||||
params = [
|
||||
dxl.DXL_LOBYTE(start_address),
|
||||
dxl.DXL_HIBYTE(start_address),
|
||||
dxl.DXL_LOBYTE(data_length),
|
||||
dxl.DXL_HIBYTE(data_length),
|
||||
*dxl_ids,
|
||||
]
|
||||
length = len(dxl_ids) + 7
|
||||
return cls.build(
|
||||
dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_READ
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sync_write(
|
||||
cls,
|
||||
ids_values: dict[int],
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Sync_Write" broadcast instruction.
|
||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-write-0x83
|
||||
|
||||
The parameters for Sync_Write (Protocol 2.0) are:
|
||||
param[0] = start_address L
|
||||
param[1] = start_address H
|
||||
param[2] = data_length L
|
||||
param[3] = data_length H
|
||||
param[5] = [1st motor] ID
|
||||
param[5+1] = [1st motor] 1st Byte
|
||||
param[5+2] = [1st motor] 2nd Byte
|
||||
...
|
||||
param[5+X] = [1st motor] X-th Byte
|
||||
param[6] = [2nd motor] ID
|
||||
param[6+1] = [2nd motor] 1st Byte
|
||||
param[6+2] = [2nd motor] 2nd Byte
|
||||
...
|
||||
param[6+X] = [2nd motor] X-th Byte
|
||||
|
||||
And 'length' = ((number_of_params * 1 + data_length) + 7), where:
|
||||
+1 is for instruction byte,
|
||||
+2 is for the address bytes,
|
||||
+2 is for the length bytes,
|
||||
+2 is for the CRC at the end.
|
||||
"""
|
||||
data = []
|
||||
for id_, value in ids_values.items():
|
||||
split_value = _split_into_byte_chunks(value, data_length)
|
||||
data += [id_, *split_value]
|
||||
params = [
|
||||
dxl.DXL_LOBYTE(start_address),
|
||||
dxl.DXL_HIBYTE(start_address),
|
||||
dxl.DXL_LOBYTE(data_length),
|
||||
dxl.DXL_HIBYTE(data_length),
|
||||
*data,
|
||||
]
|
||||
length = len(ids_values) * (1 + data_length) + 7
|
||||
return cls.build(
|
||||
dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_WRITE
|
||||
)
|
||||
|
||||
|
||||
class MockStatusPacket(MockDynamixelPacketv2):
|
||||
"""
|
||||
Helper class to build valid Dynamixel Protocol 2.0 Status Packets.
|
||||
|
||||
Protocol 2.0 Status Packet structure
|
||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#status-packet
|
||||
|
||||
| Header | Packet ID | Length | Instruction | Error | Params | CRC |
|
||||
| ------------------- | --------- | ----------- | ----------- | ----- | ----------------- | ----------- |
|
||||
| 0xFF 0xFF 0xFD 0x00 | ID | Len_L Len_H | 0x55 | Err | Param 1 … Param N | CRC_L CRC_H |
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _build(cls, dxl_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
|
||||
return [
|
||||
0xFF, 0xFF, 0xFD, 0x00, # header
|
||||
dxl_id, # servo id
|
||||
dxl.DXL_LOBYTE(length), # length_l
|
||||
dxl.DXL_HIBYTE(length), # length_h
|
||||
0x55, # instruction = 'status'
|
||||
error, # error
|
||||
*params, # data bytes
|
||||
0x00, 0x00 # placeholder for CRC
|
||||
] # fmt: skip
|
||||
|
||||
@classmethod
|
||||
def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50, error: int = 0) -> bytes:
|
||||
"""
|
||||
Builds a 'Ping' status packet.
|
||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01
|
||||
|
||||
Args:
|
||||
dxl_id (int): ID of the servo responding.
|
||||
model_nb (int, optional): Desired 'model number' to be returned in the packet. Defaults to 1190
|
||||
which corresponds to a XL330-M077-T.
|
||||
firm_ver (int, optional): Desired 'firmware version' to be returned in the packet.
|
||||
Defaults to 50.
|
||||
|
||||
Returns:
|
||||
bytes: The raw 'Ping' status packet ready to be sent through serial.
|
||||
"""
|
||||
params = [dxl.DXL_LOBYTE(model_nb), dxl.DXL_HIBYTE(model_nb), firm_ver]
|
||||
length = 7
|
||||
return cls.build(dxl_id, params=params, length=length, error=error)
|
||||
|
||||
@classmethod
|
||||
def read(cls, dxl_id: int, value: int, param_length: int, error: int = 0) -> bytes:
|
||||
"""
|
||||
Builds a 'Read' status packet (also works for 'Sync Read')
|
||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
|
||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-read-0x82
|
||||
|
||||
Args:
|
||||
dxl_id (int): ID of the servo responding.
|
||||
value (int): Desired value to be returned in the packet.
|
||||
param_length (int): The address length as reported in the control table.
|
||||
|
||||
Returns:
|
||||
bytes: The raw 'Present_Position' status packet ready to be sent through serial.
|
||||
"""
|
||||
params = _split_into_byte_chunks(value, param_length)
|
||||
length = param_length + 4
|
||||
return cls.build(dxl_id, params=params, length=length, error=error)
|
||||
|
||||
|
||||
class MockPortHandler(dxl.PortHandler):
|
||||
"""
|
||||
This class overwrite the 'setupPort' method of the Dynamixel PortHandler because it can specify
|
||||
baudrates that are not supported with a serial port on MacOS.
|
||||
"""
|
||||
|
||||
def setupPort(self, cflag_baud): # noqa: N802
|
||||
if self.is_open:
|
||||
self.closePort()
|
||||
|
||||
self.ser = serial.Serial(
|
||||
port=self.port_name,
|
||||
# baudrate=self.baudrate, <- This will fail on MacOS
|
||||
# parity = serial.PARITY_ODD,
|
||||
# stopbits = serial.STOPBITS_TWO,
|
||||
bytesize=serial.EIGHTBITS,
|
||||
timeout=0,
|
||||
)
|
||||
self.is_open = True
|
||||
self.ser.reset_input_buffer()
|
||||
self.tx_time_per_byte = (1000.0 / self.baudrate) * 10.0
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class MockMotors(MockSerial):
|
||||
"""
|
||||
This class will simulate physical motors by responding with valid status packets upon receiving some
|
||||
instruction packets. It is meant to test MotorsBus classes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def stubs(self) -> dict[str, WaitableStub]:
|
||||
return super().stubs
|
||||
|
||||
def stub(self, *, name=None, **kwargs):
|
||||
new_stub = WaitableStub(**kwargs)
|
||||
self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub
|
||||
return new_stub
|
||||
|
||||
def build_broadcast_ping_stub(
|
||||
self, ids_models: dict[int, list[int]] | None = None, num_invalid_try: int = 0
|
||||
) -> str:
|
||||
ping_request = MockInstructionPacket.ping(dxl.BROADCAST_ID)
|
||||
return_packets = b"".join(MockStatusPacket.ping(id_, model) for id_, model in ids_models.items())
|
||||
ping_response = self._build_send_fn(return_packets, num_invalid_try)
|
||||
|
||||
stub_name = "Ping_" + "_".join([str(id_) for id_ in ids_models])
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=ping_request,
|
||||
send_fn=ping_response,
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_ping_stub(
|
||||
self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0, error: int = 0
|
||||
) -> str:
|
||||
ping_request = MockInstructionPacket.ping(dxl_id)
|
||||
return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver, error)
|
||||
ping_response = self._build_send_fn(return_packet, num_invalid_try)
|
||||
stub_name = f"Ping_{dxl_id}"
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=ping_request,
|
||||
send_fn=ping_response,
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_read_stub(
|
||||
self,
|
||||
address: int,
|
||||
length: int,
|
||||
dxl_id: int,
|
||||
value: int,
|
||||
reply: bool = True,
|
||||
error: int = 0,
|
||||
num_invalid_try: int = 0,
|
||||
) -> str:
|
||||
read_request = MockInstructionPacket.read(dxl_id, address, length)
|
||||
return_packet = MockStatusPacket.read(dxl_id, value, length, error) if reply else b""
|
||||
read_response = self._build_send_fn(return_packet, num_invalid_try)
|
||||
stub_name = f"Read_{address}_{length}_{dxl_id}_{value}_{error}"
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=read_request,
|
||||
send_fn=read_response,
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_write_stub(
|
||||
self,
|
||||
address: int,
|
||||
length: int,
|
||||
dxl_id: int,
|
||||
value: int,
|
||||
reply: bool = True,
|
||||
error: int = 0,
|
||||
num_invalid_try: int = 0,
|
||||
) -> str:
|
||||
sync_read_request = MockInstructionPacket.write(dxl_id, value, address, length)
|
||||
return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error) if reply else b""
|
||||
stub_name = f"Write_{address}_{length}_{dxl_id}"
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=sync_read_request,
|
||||
send_fn=self._build_send_fn(return_packet, num_invalid_try),
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_sync_read_stub(
|
||||
self,
|
||||
address: int,
|
||||
length: int,
|
||||
ids_values: dict[int, int],
|
||||
reply: bool = True,
|
||||
num_invalid_try: int = 0,
|
||||
) -> str:
|
||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
||||
return_packets = (
|
||||
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
|
||||
if reply
|
||||
else b""
|
||||
)
|
||||
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
|
||||
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=sync_read_request,
|
||||
send_fn=sync_read_response,
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_sequential_sync_read_stub(
|
||||
self, address: int, length: int, ids_values: dict[int, list[int]] | None = None
|
||||
) -> str:
|
||||
sequence_length = len(next(iter(ids_values.values())))
|
||||
assert all(len(positions) == sequence_length for positions in ids_values.values())
|
||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
||||
sequential_packets = []
|
||||
for count in range(sequence_length):
|
||||
return_packets = b"".join(
|
||||
MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items()
|
||||
)
|
||||
sequential_packets.append(return_packets)
|
||||
|
||||
sync_read_response = self._build_sequential_send_fn(sequential_packets)
|
||||
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=sync_read_request,
|
||||
send_fn=sync_read_response,
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_sync_write_stub(
|
||||
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
|
||||
) -> str:
|
||||
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
|
||||
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=sync_read_request,
|
||||
send_fn=self._build_send_fn(b"", num_invalid_try),
|
||||
)
|
||||
return stub_name
|
||||
|
||||
@staticmethod
|
||||
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
|
||||
def send_fn(_call_count: int) -> bytes:
|
||||
if num_invalid_try >= _call_count:
|
||||
return b""
|
||||
return packet
|
||||
|
||||
return send_fn
|
||||
|
||||
@staticmethod
|
||||
def _build_sequential_send_fn(packets: list[bytes]) -> Callable[[int], bytes]:
|
||||
def send_fn(_call_count: int) -> bytes:
|
||||
return packets[_call_count - 1]
|
||||
|
||||
return send_fn
|
||||
428
tests/mocks/mock_feetech.py
Normal file
428
tests/mocks/mock_feetech.py
Normal file
@@ -0,0 +1,428 @@
|
||||
import abc
|
||||
from typing import Callable
|
||||
|
||||
import scservo_sdk as scs
|
||||
import serial
|
||||
from mock_serial import MockSerial
|
||||
|
||||
from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout
|
||||
|
||||
from .mock_serial_patch import WaitableStub
|
||||
|
||||
|
||||
class MockFeetechPacket(abc.ABC):
|
||||
@classmethod
|
||||
def build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> bytes:
|
||||
packet = cls._build(scs_id, params, length, *args, **kwargs)
|
||||
packet = cls._add_checksum(packet)
|
||||
return bytes(packet)
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def _build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> list[int]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _add_checksum(packet: list[int]) -> list[int]:
|
||||
checksum = 0
|
||||
for id_ in range(2, len(packet) - 1): # except header & checksum
|
||||
checksum += packet[id_]
|
||||
|
||||
packet[-1] = ~checksum & 0xFF
|
||||
|
||||
return packet
|
||||
|
||||
|
||||
class MockInstructionPacket(MockFeetechPacket):
|
||||
"""
|
||||
Helper class to build valid Feetech Instruction Packets.
|
||||
|
||||
Instruction Packet structure
|
||||
(from https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf)
|
||||
|
||||
| Header | Packet ID | Length | Instruction | Params | Checksum |
|
||||
| --------- | --------- | ------ | ----------- | ----------------- | -------- |
|
||||
| 0xFF 0xFF | ID | Len | Instr | Param 1 … Param N | Sum |
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _build(cls, scs_id: int, params: list[int], length: int, instruction: int) -> list[int]:
|
||||
return [
|
||||
0xFF, 0xFF, # header
|
||||
scs_id, # servo id
|
||||
length, # length
|
||||
instruction, # instruction type
|
||||
*params, # data bytes
|
||||
0x00, # placeholder for checksum
|
||||
] # fmt: skip
|
||||
|
||||
@classmethod
|
||||
def ping(
|
||||
cls,
|
||||
scs_id: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Ping" broadcast instruction.
|
||||
|
||||
No parameters required.
|
||||
"""
|
||||
return cls.build(scs_id=scs_id, params=[], length=2, instruction=scs.INST_PING)
|
||||
|
||||
@classmethod
|
||||
def read(
|
||||
cls,
|
||||
scs_id: int,
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Read" instruction.
|
||||
|
||||
The parameters for Read are:
|
||||
param[0] = start_address
|
||||
param[1] = data_length
|
||||
|
||||
And 'length' = 4, where:
|
||||
+1 is for instruction byte,
|
||||
+1 is for the address byte,
|
||||
+1 is for the length bytes,
|
||||
+1 is for the checksum at the end.
|
||||
"""
|
||||
params = [start_address, data_length]
|
||||
length = 4
|
||||
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_READ)
|
||||
|
||||
@classmethod
|
||||
def write(
|
||||
cls,
|
||||
scs_id: int,
|
||||
value: int,
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Write" instruction.
|
||||
|
||||
The parameters for Write are:
|
||||
param[0] = start_address L
|
||||
param[1] = start_address H
|
||||
param[2] = 1st Byte
|
||||
param[3] = 2nd Byte
|
||||
...
|
||||
param[1+X] = X-th Byte
|
||||
|
||||
And 'length' = data_length + 3, where:
|
||||
+1 is for instruction byte,
|
||||
+1 is for the length bytes,
|
||||
+1 is for the checksum at the end.
|
||||
"""
|
||||
data = _split_into_byte_chunks(value, data_length)
|
||||
params = [start_address, *data]
|
||||
length = data_length + 3
|
||||
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_WRITE)
|
||||
|
||||
@classmethod
|
||||
def sync_read(
|
||||
cls,
|
||||
scs_ids: list[int],
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Sync_Read" broadcast instruction.
|
||||
|
||||
The parameters for Sync Read are:
|
||||
param[0] = start_address
|
||||
param[1] = data_length
|
||||
param[2+] = motor IDs to read from
|
||||
|
||||
And 'length' = (number_of_params + 4), where:
|
||||
+1 is for instruction byte,
|
||||
+1 is for the address byte,
|
||||
+1 is for the length bytes,
|
||||
+1 is for the checksum at the end.
|
||||
"""
|
||||
params = [start_address, data_length, *scs_ids]
|
||||
length = len(scs_ids) + 4
|
||||
return cls.build(
|
||||
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_READ
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sync_write(
|
||||
cls,
|
||||
ids_values: dict[int],
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Sync_Write" broadcast instruction.
|
||||
|
||||
The parameters for Sync_Write are:
|
||||
param[0] = start_address
|
||||
param[1] = data_length
|
||||
param[2] = [1st motor] ID
|
||||
param[2+1] = [1st motor] 1st Byte
|
||||
param[2+2] = [1st motor] 2nd Byte
|
||||
...
|
||||
param[5+X] = [1st motor] X-th Byte
|
||||
param[6] = [2nd motor] ID
|
||||
param[6+1] = [2nd motor] 1st Byte
|
||||
param[6+2] = [2nd motor] 2nd Byte
|
||||
...
|
||||
param[6+X] = [2nd motor] X-th Byte
|
||||
|
||||
And 'length' = ((number_of_params * 1 + data_length) + 4), where:
|
||||
+1 is for instruction byte,
|
||||
+1 is for the address byte,
|
||||
+1 is for the length bytes,
|
||||
+1 is for the checksum at the end.
|
||||
"""
|
||||
data = []
|
||||
for id_, value in ids_values.items():
|
||||
split_value = _split_into_byte_chunks(value, data_length)
|
||||
data += [id_, *split_value]
|
||||
params = [start_address, data_length, *data]
|
||||
length = len(ids_values) * (1 + data_length) + 4
|
||||
return cls.build(
|
||||
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_WRITE
|
||||
)
|
||||
|
||||
|
||||
class MockStatusPacket(MockFeetechPacket):
|
||||
"""
|
||||
Helper class to build valid Feetech Status Packets.
|
||||
|
||||
Status Packet structure
|
||||
(from https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf)
|
||||
|
||||
| Header | Packet ID | Length | Error | Params | Checksum |
|
||||
| --------- | --------- | ------ | ----- | ----------------- | -------- |
|
||||
| 0xFF 0xFF | ID | Len | Err | Param 1 … Param N | Sum |
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _build(cls, scs_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
|
||||
return [
|
||||
0xFF, 0xFF, # header
|
||||
scs_id, # servo id
|
||||
length, # length
|
||||
error, # status
|
||||
*params, # data bytes
|
||||
0x00, # placeholder for checksum
|
||||
] # fmt: skip
|
||||
|
||||
@classmethod
|
||||
def ping(cls, scs_id: int, error: int = 0) -> bytes:
|
||||
"""Builds a 'Ping' status packet.
|
||||
|
||||
Args:
|
||||
scs_id (int): ID of the servo responding.
|
||||
error (str, optional): Error to be returned. Defaults to "Success".
|
||||
|
||||
Returns:
|
||||
bytes: The raw 'Ping' status packet ready to be sent through serial.
|
||||
"""
|
||||
return cls.build(scs_id, params=[], length=2, error=error)
|
||||
|
||||
@classmethod
|
||||
def read(cls, scs_id: int, value: int, param_length: int, error: int = 0) -> bytes:
|
||||
"""Builds a 'Read' status packet.
|
||||
|
||||
Args:
|
||||
scs_id (int): ID of the servo responding.
|
||||
value (int): Desired value to be returned in the packet.
|
||||
param_length (int): The address length as reported in the control table.
|
||||
|
||||
Returns:
|
||||
bytes: The raw 'Sync Read' status packet ready to be sent through serial.
|
||||
"""
|
||||
params = _split_into_byte_chunks(value, param_length)
|
||||
length = param_length + 2
|
||||
return cls.build(scs_id, params=params, length=length, error=error)
|
||||
|
||||
|
||||
class MockPortHandler(scs.PortHandler):
|
||||
"""
|
||||
This class overwrite the 'setupPort' method of the Feetech PortHandler because it can specify
|
||||
baudrates that are not supported with a serial port on MacOS.
|
||||
"""
|
||||
|
||||
def setupPort(self, cflag_baud): # noqa: N802
|
||||
if self.is_open:
|
||||
self.closePort()
|
||||
|
||||
self.ser = serial.Serial(
|
||||
port=self.port_name,
|
||||
# baudrate=self.baudrate, <- This will fail on MacOS
|
||||
# parity = serial.PARITY_ODD,
|
||||
# stopbits = serial.STOPBITS_TWO,
|
||||
bytesize=serial.EIGHTBITS,
|
||||
timeout=0,
|
||||
)
|
||||
self.is_open = True
|
||||
self.ser.reset_input_buffer()
|
||||
self.tx_time_per_byte = (1000.0 / self.baudrate) * 10.0
|
||||
|
||||
return True
|
||||
|
||||
def setPacketTimeout(self, packet_length): # noqa: N802
|
||||
return patch_setPacketTimeout(self, packet_length)
|
||||
|
||||
|
||||
class MockMotors(MockSerial):
|
||||
"""
|
||||
This class will simulate physical motors by responding with valid status packets upon receiving some
|
||||
instruction packets. It is meant to test MotorsBus classes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def stubs(self) -> dict[str, WaitableStub]:
|
||||
return super().stubs
|
||||
|
||||
def stub(self, *, name=None, **kwargs):
|
||||
new_stub = WaitableStub(**kwargs)
|
||||
self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub
|
||||
return new_stub
|
||||
|
||||
def build_broadcast_ping_stub(self, ids: list[int] | None = None, num_invalid_try: int = 0) -> str:
|
||||
ping_request = MockInstructionPacket.ping(scs.BROADCAST_ID)
|
||||
return_packets = b"".join(MockStatusPacket.ping(id_) for id_ in ids)
|
||||
ping_response = self._build_send_fn(return_packets, num_invalid_try)
|
||||
stub_name = "Ping_" + "_".join([str(id_) for id_ in ids])
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=ping_request,
|
||||
send_fn=ping_response,
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0, error: int = 0) -> str:
|
||||
ping_request = MockInstructionPacket.ping(scs_id)
|
||||
return_packet = MockStatusPacket.ping(scs_id, error)
|
||||
ping_response = self._build_send_fn(return_packet, num_invalid_try)
|
||||
stub_name = f"Ping_{scs_id}_{error}"
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=ping_request,
|
||||
send_fn=ping_response,
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_read_stub(
|
||||
self,
|
||||
address: int,
|
||||
length: int,
|
||||
scs_id: int,
|
||||
value: int,
|
||||
reply: bool = True,
|
||||
error: int = 0,
|
||||
num_invalid_try: int = 0,
|
||||
) -> str:
|
||||
read_request = MockInstructionPacket.read(scs_id, address, length)
|
||||
return_packet = MockStatusPacket.read(scs_id, value, length, error) if reply else b""
|
||||
read_response = self._build_send_fn(return_packet, num_invalid_try)
|
||||
stub_name = f"Read_{address}_{length}_{scs_id}_{value}_{error}"
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=read_request,
|
||||
send_fn=read_response,
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_write_stub(
|
||||
self,
|
||||
address: int,
|
||||
length: int,
|
||||
scs_id: int,
|
||||
value: int,
|
||||
reply: bool = True,
|
||||
error: int = 0,
|
||||
num_invalid_try: int = 0,
|
||||
) -> str:
|
||||
sync_read_request = MockInstructionPacket.write(scs_id, value, address, length)
|
||||
return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) if reply else b""
|
||||
stub_name = f"Write_{address}_{length}_{scs_id}"
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=sync_read_request,
|
||||
send_fn=self._build_send_fn(return_packet, num_invalid_try),
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_sync_read_stub(
|
||||
self,
|
||||
address: int,
|
||||
length: int,
|
||||
ids_values: dict[int, int],
|
||||
reply: bool = True,
|
||||
num_invalid_try: int = 0,
|
||||
) -> str:
|
||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
||||
return_packets = (
|
||||
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
|
||||
if reply
|
||||
else b""
|
||||
)
|
||||
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
|
||||
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=sync_read_request,
|
||||
send_fn=sync_read_response,
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_sequential_sync_read_stub(
|
||||
self, address: int, length: int, ids_values: dict[int, list[int]] | None = None
|
||||
) -> str:
|
||||
sequence_length = len(next(iter(ids_values.values())))
|
||||
assert all(len(positions) == sequence_length for positions in ids_values.values())
|
||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
||||
sequential_packets = []
|
||||
for count in range(sequence_length):
|
||||
return_packets = b"".join(
|
||||
MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items()
|
||||
)
|
||||
sequential_packets.append(return_packets)
|
||||
|
||||
sync_read_response = self._build_sequential_send_fn(sequential_packets)
|
||||
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=sync_read_request,
|
||||
send_fn=sync_read_response,
|
||||
)
|
||||
return stub_name
|
||||
|
||||
def build_sync_write_stub(
|
||||
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
|
||||
) -> str:
|
||||
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
|
||||
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=sync_read_request,
|
||||
send_fn=self._build_send_fn(b"", num_invalid_try),
|
||||
)
|
||||
return stub_name
|
||||
|
||||
@staticmethod
|
||||
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
|
||||
def send_fn(_call_count: int) -> bytes:
|
||||
if num_invalid_try >= _call_count:
|
||||
return b""
|
||||
return packet
|
||||
|
||||
return send_fn
|
||||
|
||||
@staticmethod
|
||||
def _build_sequential_send_fn(packets: list[bytes]) -> Callable[[int], bytes]:
|
||||
def send_fn(_call_count: int) -> bytes:
|
||||
return packets[_call_count - 1]
|
||||
|
||||
return send_fn
|
||||
138
tests/mocks/mock_motors_bus.py
Normal file
138
tests/mocks/mock_motors_bus.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# ruff: noqa: N802
|
||||
|
||||
from lerobot.common.motors.motors_bus import (
|
||||
Motor,
|
||||
MotorsBus,
|
||||
)
|
||||
|
||||
DUMMY_CTRL_TABLE_1 = {
|
||||
"Firmware_Version": (0, 1),
|
||||
"Model_Number": (1, 2),
|
||||
"Present_Position": (3, 4),
|
||||
"Goal_Position": (11, 2),
|
||||
}
|
||||
|
||||
DUMMY_CTRL_TABLE_2 = {
|
||||
"Model_Number": (0, 2),
|
||||
"Firmware_Version": (2, 1),
|
||||
"Present_Position": (3, 4),
|
||||
"Present_Velocity": (7, 4),
|
||||
"Goal_Position": (11, 4),
|
||||
"Goal_Velocity": (15, 4),
|
||||
"Lock": (19, 1),
|
||||
}
|
||||
|
||||
DUMMY_MODEL_CTRL_TABLE = {
|
||||
"model_1": DUMMY_CTRL_TABLE_1,
|
||||
"model_2": DUMMY_CTRL_TABLE_2,
|
||||
"model_3": DUMMY_CTRL_TABLE_2,
|
||||
}
|
||||
|
||||
DUMMY_BAUDRATE_TABLE = {
|
||||
0: 1_000_000,
|
||||
1: 500_000,
|
||||
2: 250_000,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_BAUDRATE_TABLE = {
|
||||
"model_1": DUMMY_BAUDRATE_TABLE,
|
||||
"model_2": DUMMY_BAUDRATE_TABLE,
|
||||
"model_3": DUMMY_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
DUMMY_ENCODING_TABLE = {
|
||||
"Present_Position": 8,
|
||||
"Goal_Position": 10,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_ENCODING_TABLE = {
|
||||
"model_1": DUMMY_ENCODING_TABLE,
|
||||
"model_2": DUMMY_ENCODING_TABLE,
|
||||
"model_3": DUMMY_ENCODING_TABLE,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_NUMBER_TABLE = {
|
||||
"model_1": 1234,
|
||||
"model_2": 5678,
|
||||
"model_3": 5799,
|
||||
}
|
||||
|
||||
DUMMY_MODEL_RESOLUTION_TABLE = {
|
||||
"model_1": 4096,
|
||||
"model_2": 1024,
|
||||
"model_3": 4096,
|
||||
}
|
||||
|
||||
|
||||
class MockPortHandler:
|
||||
def __init__(self, port_name):
|
||||
self.is_open: bool = False
|
||||
self.baudrate: int
|
||||
self.packet_start_time: float
|
||||
self.packet_timeout: float
|
||||
self.tx_time_per_byte: float
|
||||
self.is_using: bool = False
|
||||
self.port_name: str = port_name
|
||||
self.ser = None
|
||||
|
||||
def openPort(self):
|
||||
self.is_open = True
|
||||
return self.is_open
|
||||
|
||||
def closePort(self):
|
||||
self.is_open = False
|
||||
|
||||
def clearPort(self): ...
|
||||
def setPortName(self, port_name):
|
||||
self.port_name = port_name
|
||||
|
||||
def getPortName(self):
|
||||
return self.port_name
|
||||
|
||||
def setBaudRate(self, baudrate):
|
||||
self.baudrate: baudrate
|
||||
|
||||
def getBaudRate(self):
|
||||
return self.baudrate
|
||||
|
||||
def getBytesAvailable(self): ...
|
||||
def readPort(self, length): ...
|
||||
def writePort(self, packet): ...
|
||||
def setPacketTimeout(self, packet_length): ...
|
||||
def setPacketTimeoutMillis(self, msec): ...
|
||||
def isPacketTimeout(self): ...
|
||||
def getCurrentTime(self): ...
|
||||
def getTimeSinceStart(self): ...
|
||||
def setupPort(self, cflag_baud): ...
|
||||
def getCFlagBaud(self, baudrate): ...
|
||||
|
||||
|
||||
class MockMotorsBus(MotorsBus):
|
||||
available_baudrates = [500_000, 1_000_000]
|
||||
default_timeout = 1000
|
||||
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
|
||||
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
|
||||
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
|
||||
model_number_table = DUMMY_MODEL_NUMBER_TABLE
|
||||
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
|
||||
normalized_data = ["Present_Position", "Goal_Position"]
|
||||
|
||||
def __init__(self, port: str, motors: dict[str, Motor]):
|
||||
super().__init__(port, motors)
|
||||
self.port_handler = MockPortHandler(port)
|
||||
|
||||
def _assert_protocol_is_compatible(self, instruction_name): ...
|
||||
def _handshake(self): ...
|
||||
def _find_single_motor(self, motor, initial_baudrate): ...
|
||||
def configure_motors(self): ...
|
||||
def is_calibrated(self): ...
|
||||
def read_calibration(self): ...
|
||||
def write_calibration(self, calibration_dict): ...
|
||||
def disable_torque(self, motors, num_retry): ...
|
||||
def _disable_torque(self, motor, model, num_retry): ...
|
||||
def enable_torque(self, motors, num_retry): ...
|
||||
def _get_half_turn_homings(self, positions): ...
|
||||
def _encode_sign(self, data_name, ids_values): ...
|
||||
def _decode_sign(self, data_name, ids_values): ...
|
||||
def _split_into_byte_chunks(self, value, length): ...
|
||||
def broadcast_ping(self, num_retry, raise_on_error): ...
|
||||
112
tests/mocks/mock_robot.py
Normal file
112
tests/mocks/mock_robot.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.cameras import CameraConfig, make_cameras_from_configs
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.robots import Robot, RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("mock_robot")
|
||||
@dataclass
|
||||
class MockRobotConfig(RobotConfig):
|
||||
n_motors: int = 3
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
random_values: bool = True
|
||||
static_values: list[float] | None = None
|
||||
calibrated: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_motors < 1:
|
||||
raise ValueError(self.n_motors)
|
||||
|
||||
if self.random_values and self.static_values is not None:
|
||||
raise ValueError("Choose either random values or static values")
|
||||
|
||||
if self.static_values is not None and len(self.static_values) != self.n_motors:
|
||||
raise ValueError("Specify the same number of static values as motors")
|
||||
|
||||
if len(self.cameras) > 0:
|
||||
raise NotImplementedError # TODO with the cameras refactor
|
||||
|
||||
|
||||
class MockRobot(Robot):
|
||||
"""Mock Robot to be used for testing."""
|
||||
|
||||
config_class = MockRobotConfig
|
||||
name = "mock_robot"
|
||||
|
||||
def __init__(self, config: MockRobotConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._is_connected = False
|
||||
self._is_calibrated = config.calibrated
|
||||
self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)]
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
return {f"{motor}.pos": float for motor in self.motors}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._is_connected
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self._is_connected = True
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_calibrated = True
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.config.random_values:
|
||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
||||
else:
|
||||
return {
|
||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
||||
}
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
return action
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_connected = False
|
||||
35
tests/mocks/mock_serial_patch.py
Normal file
35
tests/mocks/mock_serial_patch.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
from mock_serial.mock_serial import Stub
|
||||
|
||||
|
||||
class WaitableStub(Stub):
|
||||
"""
|
||||
In some situations, a test might be checking if a stub has been called before `MockSerial` thread had time
|
||||
to read, match, and call the stub. In these situations, the test can fail randomly.
|
||||
|
||||
Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions.
|
||||
|
||||
Proposed fix:
|
||||
https://github.com/benthorner/mock_serial/pull/3
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._event = threading.Event()
|
||||
|
||||
def call(self):
|
||||
self._event.set()
|
||||
return super().call()
|
||||
|
||||
def wait_called(self, timeout: float = 1.0):
|
||||
return self._event.wait(timeout)
|
||||
|
||||
def wait_calls(self, min_calls: int = 1, timeout: float = 1.0):
|
||||
start = time.perf_counter()
|
||||
while time.perf_counter() - start < timeout:
|
||||
if self.calls >= min_calls:
|
||||
return self.calls
|
||||
time.sleep(0.005)
|
||||
raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.")
|
||||
94
tests/mocks/mock_teleop.py
Normal file
94
tests/mocks/mock_teleop.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.teleoperators import Teleoperator, TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("mock_teleop")
|
||||
@dataclass
|
||||
class MockTeleopConfig(TeleoperatorConfig):
|
||||
n_motors: int = 3
|
||||
random_values: bool = True
|
||||
static_values: list[float] | None = None
|
||||
calibrated: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_motors < 1:
|
||||
raise ValueError(self.n_motors)
|
||||
|
||||
if self.random_values and self.static_values is not None:
|
||||
raise ValueError("Choose either random values or static values")
|
||||
|
||||
if self.static_values is not None and len(self.static_values) != self.n_motors:
|
||||
raise ValueError("Specify the same number of static values as motors")
|
||||
|
||||
|
||||
class MockTeleop(Teleoperator):
|
||||
"""Mock Teleoperator to be used for testing."""
|
||||
|
||||
config_class = MockTeleopConfig
|
||||
name = "mock_teleop"
|
||||
|
||||
def __init__(self, config: MockTeleopConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._is_connected = False
|
||||
self._is_calibrated = config.calibrated
|
||||
self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)]
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{motor}.pos": float for motor in self.motors}
|
||||
|
||||
@cached_property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {f"{motor}.pos": float for motor in self.motors}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._is_connected
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self._is_connected = True
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_calibrated = True
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.config.random_values:
|
||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
||||
else:
|
||||
return {
|
||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
||||
}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_connected = False
|
||||
@@ -1,107 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
|
||||
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
|
||||
|
||||
Warning: These mocked versions are minimalist. They do not exactly mock every behaviors
|
||||
from the original classes and functions (e.g. return types might be None instead of boolean).
|
||||
"""
|
||||
|
||||
# from dynamixel_sdk import COMM_SUCCESS
|
||||
|
||||
DEFAULT_BAUDRATE = 9_600
|
||||
COMM_SUCCESS = 0 # tx or rx packet communication success
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes):
|
||||
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
||||
# `convert_bytes_to_value`
|
||||
del bytes # unused
|
||||
return value
|
||||
|
||||
|
||||
def get_default_motor_values(motor_index):
|
||||
return {
|
||||
# Key (int) are from X_SERIES_CONTROL_TABLE
|
||||
7: motor_index, # ID
|
||||
8: DEFAULT_BAUDRATE, # Baud_rate
|
||||
10: 0, # Drive_Mode
|
||||
64: 0, # Torque_Enable
|
||||
# Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144
|
||||
# For other joints, 2560 will be autocorrected to be in calibration range
|
||||
132: 2560, # Present_Position
|
||||
}
|
||||
|
||||
|
||||
class PortHandler:
|
||||
def __init__(self, port):
|
||||
self.port = port
|
||||
# factory default baudrate
|
||||
self.baudrate = DEFAULT_BAUDRATE
|
||||
|
||||
def openPort(self): # noqa: N802
|
||||
return True
|
||||
|
||||
def closePort(self): # noqa: N802
|
||||
pass
|
||||
|
||||
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
|
||||
del timeout_ms # unused
|
||||
|
||||
def getBaudRate(self): # noqa: N802
|
||||
return self.baudrate
|
||||
|
||||
def setBaudRate(self, baudrate): # noqa: N802
|
||||
self.baudrate = baudrate
|
||||
|
||||
|
||||
class PacketHandler:
|
||||
def __init__(self, protocol_version):
|
||||
del protocol_version # unused
|
||||
# Use packet_handler.data to communicate across Read and Write
|
||||
self.data = {}
|
||||
|
||||
|
||||
class GroupSyncRead:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
self.packet_handler = packet_handler
|
||||
|
||||
def addParam(self, motor_index): # noqa: N802
|
||||
# Initialize motor default values
|
||||
if motor_index not in self.packet_handler.data:
|
||||
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
|
||||
|
||||
def txRxPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
def getData(self, index, address, bytes): # noqa: N802
|
||||
return self.packet_handler.data[index][address]
|
||||
|
||||
|
||||
class GroupSyncWrite:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
self.packet_handler = packet_handler
|
||||
self.address = address
|
||||
|
||||
def addParam(self, index, data): # noqa: N802
|
||||
# Initialize motor default values
|
||||
if index not in self.packet_handler.data:
|
||||
self.packet_handler.data[index] = get_default_motor_values(index)
|
||||
self.changeParam(index, data)
|
||||
|
||||
def txPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
def changeParam(self, index, data): # noqa: N802
|
||||
self.packet_handler.data[index][self.address] = data
|
||||
@@ -1,125 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
|
||||
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
|
||||
|
||||
Warning: These mocked versions are minimalist. They do not exactly mock every behaviors
|
||||
from the original classes and functions (e.g. return types might be None instead of boolean).
|
||||
"""
|
||||
|
||||
# from dynamixel_sdk import COMM_SUCCESS
|
||||
|
||||
DEFAULT_BAUDRATE = 1_000_000
|
||||
COMM_SUCCESS = 0 # tx or rx packet communication success
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes):
|
||||
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
||||
# `convert_bytes_to_value`
|
||||
del bytes # unused
|
||||
return value
|
||||
|
||||
|
||||
def get_default_motor_values(motor_index):
|
||||
return {
|
||||
# Key (int) are from SCS_SERIES_CONTROL_TABLE
|
||||
5: motor_index, # ID
|
||||
6: DEFAULT_BAUDRATE, # Baud_rate
|
||||
10: 0, # Drive_Mode
|
||||
21: 32, # P_Coefficient
|
||||
22: 32, # D_Coefficient
|
||||
23: 0, # I_Coefficient
|
||||
40: 0, # Torque_Enable
|
||||
41: 254, # Acceleration
|
||||
31: -2047, # Offset
|
||||
33: 0, # Mode
|
||||
55: 1, # Lock
|
||||
# Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144
|
||||
# For other joints, 2560 will be autocorrected to be in calibration range
|
||||
56: 2560, # Present_Position
|
||||
58: 0, # Present_Speed
|
||||
69: 0, # Present_Current
|
||||
85: 150, # Maximum_Acceleration
|
||||
}
|
||||
|
||||
|
||||
class PortHandler:
|
||||
def __init__(self, port):
|
||||
self.port = port
|
||||
# factory default baudrate
|
||||
self.baudrate = DEFAULT_BAUDRATE
|
||||
self.ser = SerialMock()
|
||||
|
||||
def openPort(self): # noqa: N802
|
||||
return True
|
||||
|
||||
def closePort(self): # noqa: N802
|
||||
pass
|
||||
|
||||
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
|
||||
del timeout_ms # unused
|
||||
|
||||
def getBaudRate(self): # noqa: N802
|
||||
return self.baudrate
|
||||
|
||||
def setBaudRate(self, baudrate): # noqa: N802
|
||||
self.baudrate = baudrate
|
||||
|
||||
|
||||
class PacketHandler:
|
||||
def __init__(self, protocol_version):
|
||||
del protocol_version # unused
|
||||
# Use packet_handler.data to communicate across Read and Write
|
||||
self.data = {}
|
||||
|
||||
|
||||
class GroupSyncRead:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
self.packet_handler = packet_handler
|
||||
|
||||
def addParam(self, motor_index): # noqa: N802
|
||||
# Initialize motor default values
|
||||
if motor_index not in self.packet_handler.data:
|
||||
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
|
||||
|
||||
def txRxPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
def getData(self, index, address, bytes): # noqa: N802
|
||||
return self.packet_handler.data[index][address]
|
||||
|
||||
|
||||
class GroupSyncWrite:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
self.packet_handler = packet_handler
|
||||
self.address = address
|
||||
|
||||
def addParam(self, index, data): # noqa: N802
|
||||
if index not in self.packet_handler.data:
|
||||
self.packet_handler.data[index] = get_default_motor_values(index)
|
||||
self.changeParam(index, data)
|
||||
|
||||
def txPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
def changeParam(self, index, data): # noqa: N802
|
||||
self.packet_handler.data[index][self.address] = data
|
||||
|
||||
|
||||
class SerialMock:
|
||||
def reset_output_buffer(self):
|
||||
pass
|
||||
|
||||
def reset_input_buffer(self):
|
||||
pass
|
||||
400
tests/motors/test_dynamixel.py
Normal file
400
tests/motors/test_dynamixel.py
Normal file
@@ -0,0 +1,400 @@
|
||||
import re
|
||||
import sys
|
||||
from typing import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.common.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus
|
||||
from lerobot.common.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE
|
||||
from lerobot.common.utils.encoding_utils import encode_twos_complement
|
||||
|
||||
try:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip("dynamixel_sdk not available", allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_port_handler():
|
||||
if sys.platform == "darwin":
|
||||
with patch.object(dxl, "PortHandler", MockPortHandler):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_motors() -> Generator[MockMotors, None, None]:
|
||||
motors = MockMotors()
|
||||
motors.open()
|
||||
yield motors
|
||||
motors.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_motors() -> dict[str, Motor]:
|
||||
return {
|
||||
"dummy_1": Motor(1, "xl430-w250", MotorNormMode.RANGE_M100_100),
|
||||
"dummy_2": Motor(2, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"dummy_3": Motor(3, "xl330-m077", MotorNormMode.RANGE_M100_100),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_calibration(dummy_motors) -> dict[str, MotorCalibration]:
|
||||
drive_modes = [0, 1, 0]
|
||||
homings = [-709, -2006, 1624]
|
||||
mins = [43, 27, 145]
|
||||
maxes = [1335, 3608, 3999]
|
||||
calibration = {}
|
||||
for motor, m in dummy_motors.items():
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=drive_modes[m.id - 1],
|
||||
homing_offset=homings[m.id - 1],
|
||||
range_min=mins[m.id - 1],
|
||||
range_max=maxes[m.id - 1],
|
||||
)
|
||||
return calibration
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
|
||||
def test_autouse_patch():
|
||||
"""Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler."""
|
||||
assert dxl.PortHandler is MockPortHandler
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, length, expected",
|
||||
[
|
||||
(0x12, 1, [0x12]),
|
||||
(0x1234, 2, [0x34, 0x12]),
|
||||
(0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
|
||||
],
|
||||
ids=[
|
||||
"1 byte",
|
||||
"2 bytes",
|
||||
"4 bytes",
|
||||
],
|
||||
) # fmt: skip
|
||||
def test__split_into_byte_chunks(value, length, expected):
|
||||
bus = DynamixelMotorsBus("", {})
|
||||
assert bus._split_into_byte_chunks(value, length) == expected
|
||||
|
||||
|
||||
def test_abc_implementation(dummy_motors):
|
||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
||||
DynamixelMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("id_", [1, 2, 3])
|
||||
def test_ping(id_, mock_motors, dummy_motors):
|
||||
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
|
||||
stub = mock_motors.build_ping_stub(id_, expected_model_nb)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
ping_model_nb = bus.ping(id_)
|
||||
|
||||
assert ping_model_nb == expected_model_nb
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
def test_broadcast_ping(mock_motors, dummy_motors):
|
||||
models = {m.id: m.model for m in dummy_motors.values()}
|
||||
expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
|
||||
stub = mock_motors.build_broadcast_ping_stub(expected_model_nbs)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
ping_model_nbs = bus.broadcast_ping()
|
||||
|
||||
assert ping_model_nbs == expected_model_nbs
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"addr, length, id_, value",
|
||||
[
|
||||
(0, 1, 1, 2),
|
||||
(10, 2, 2, 999),
|
||||
(42, 4, 3, 1337),
|
||||
],
|
||||
)
|
||||
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
|
||||
stub = mock_motors.build_read_stub(addr, length, id_, value)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
read_value, _, _ = bus._read(addr, length, id_)
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
assert read_value == value
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||
def test__read_error(raise_on_error, mock_motors, dummy_motors):
|
||||
addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT)
|
||||
stub = mock_motors.build_read_stub(addr, length, id_, value, error=error)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
if raise_on_error:
|
||||
with pytest.raises(
|
||||
RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!")
|
||||
):
|
||||
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||
else:
|
||||
_, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||
assert read_error == error
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
|
||||
addr, length, id_, value = (10, 4, 1, 1337)
|
||||
stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
if raise_on_error:
|
||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||
else:
|
||||
_, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||
assert read_comm == dxl.COMM_RX_TIMEOUT
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"addr, length, id_, value",
|
||||
[
|
||||
(0, 1, 1, 2),
|
||||
(10, 2, 2, 999),
|
||||
(42, 4, 3, 1337),
|
||||
],
|
||||
)
|
||||
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
|
||||
stub = mock_motors.build_write_stub(addr, length, id_, value)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
comm, error = bus._write(addr, length, id_, value)
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
assert comm == dxl.COMM_SUCCESS
|
||||
assert error == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||
def test__write_error(raise_on_error, mock_motors, dummy_motors):
|
||||
addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT)
|
||||
stub = mock_motors.build_write_stub(addr, length, id_, value, error=error)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
if raise_on_error:
|
||||
with pytest.raises(
|
||||
RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!")
|
||||
):
|
||||
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||
else:
|
||||
_, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||
assert write_error == error
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
|
||||
addr, length, id_, value = (10, 4, 1, 1337)
|
||||
stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
if raise_on_error:
|
||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||
else:
|
||||
write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||
assert write_comm == dxl.COMM_RX_TIMEOUT
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"addr, length, ids_values",
|
||||
[
|
||||
(0, 1, {1: 4}),
|
||||
(10, 2, {1: 1337, 2: 42}),
|
||||
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
||||
],
|
||||
ids=["1 motor", "2 motors", "3 motors"],
|
||||
)
|
||||
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
|
||||
stub = mock_motors.build_sync_read_stub(addr, length, ids_values)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
read_values, _ = bus._sync_read(addr, length, list(ids_values))
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
assert read_values == ids_values
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
|
||||
addr, length, ids_values = (10, 4, {1: 1337})
|
||||
stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
if raise_on_error:
|
||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||
bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
||||
else:
|
||||
_, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
||||
assert read_comm == dxl.COMM_RX_TIMEOUT
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"addr, length, ids_values",
|
||||
[
|
||||
(0, 1, {1: 4}),
|
||||
(10, 2, {1: 1337, 2: 42}),
|
||||
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
||||
],
|
||||
ids=["1 motor", "2 motors", "3 motors"],
|
||||
)
|
||||
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
|
||||
stub = mock_motors.build_sync_write_stub(addr, length, ids_values)
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
comm = bus._sync_write(addr, length, ids_values)
|
||||
|
||||
assert mock_motors.stubs[stub].wait_called()
|
||||
assert comm == dxl.COMM_SUCCESS
|
||||
|
||||
|
||||
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
|
||||
drive_modes = {m.id: m.drive_mode for m in dummy_calibration.values()}
|
||||
encoded_homings = {m.id: encode_twos_complement(m.homing_offset, 4) for m in dummy_calibration.values()}
|
||||
mins = {m.id: m.range_min for m in dummy_calibration.values()}
|
||||
maxes = {m.id: m.range_max for m in dummy_calibration.values()}
|
||||
drive_modes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Drive_Mode"], drive_modes)
|
||||
offsets_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings)
|
||||
mins_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins)
|
||||
maxes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes)
|
||||
bus = DynamixelMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors=dummy_motors,
|
||||
calibration=dummy_calibration,
|
||||
)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
is_calibrated = bus.is_calibrated
|
||||
|
||||
assert is_calibrated
|
||||
assert mock_motors.stubs[drive_modes_stub].called
|
||||
assert mock_motors.stubs[offsets_stub].called
|
||||
assert mock_motors.stubs[mins_stub].called
|
||||
assert mock_motors.stubs[maxes_stub].called
|
||||
|
||||
|
||||
def test_reset_calibration(mock_motors, dummy_motors):
|
||||
write_homing_stubs = []
|
||||
write_mins_stubs = []
|
||||
write_maxes_stubs = []
|
||||
for motor in dummy_motors.values():
|
||||
write_homing_stubs.append(
|
||||
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
|
||||
)
|
||||
write_mins_stubs.append(
|
||||
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
|
||||
)
|
||||
write_maxes_stubs.append(
|
||||
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
|
||||
)
|
||||
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
bus.reset_calibration()
|
||||
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_maxes_stubs)
|
||||
|
||||
|
||||
def test_set_half_turn_homings(mock_motors, dummy_motors):
|
||||
"""
|
||||
For this test, we assume that the homing offsets are already 0 such that
|
||||
Present_Position == Actual_Position
|
||||
"""
|
||||
current_positions = {
|
||||
1: 1337,
|
||||
2: 42,
|
||||
3: 3672,
|
||||
}
|
||||
expected_homings = {
|
||||
1: 710, # 2047 - 1337
|
||||
2: 2005, # 2047 - 42
|
||||
3: -1625, # 2047 - 3672
|
||||
}
|
||||
read_pos_stub = mock_motors.build_sync_read_stub(
|
||||
*X_SERIES_CONTROL_TABLE["Present_Position"], current_positions
|
||||
)
|
||||
write_homing_stubs = []
|
||||
for id_, homing in expected_homings.items():
|
||||
encoded_homing = encode_twos_complement(homing, 4)
|
||||
stub = mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing)
|
||||
write_homing_stubs.append(stub)
|
||||
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
bus.reset_calibration = MagicMock()
|
||||
|
||||
bus.set_half_turn_homings()
|
||||
|
||||
bus.reset_calibration.assert_called_once()
|
||||
assert mock_motors.stubs[read_pos_stub].called
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
||||
|
||||
|
||||
def test_record_ranges_of_motion(mock_motors, dummy_motors):
|
||||
positions = {
|
||||
1: [351, 42, 1337],
|
||||
2: [28, 3600, 2444],
|
||||
3: [4002, 2999, 146],
|
||||
}
|
||||
expected_mins = {
|
||||
"dummy_1": 42,
|
||||
"dummy_2": 28,
|
||||
"dummy_3": 146,
|
||||
}
|
||||
expected_maxes = {
|
||||
"dummy_1": 1337,
|
||||
"dummy_2": 3600,
|
||||
"dummy_3": 4002,
|
||||
}
|
||||
read_pos_stub = mock_motors.build_sequential_sync_read_stub(
|
||||
*X_SERIES_CONTROL_TABLE["Present_Position"], positions
|
||||
)
|
||||
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
mins, maxes = bus.record_ranges_of_motion(display_values=False)
|
||||
|
||||
assert mock_motors.stubs[read_pos_stub].calls == 3
|
||||
assert mins == expected_mins
|
||||
assert maxes == expected_maxes
|
||||
443
tests/motors/test_feetech.py
Normal file
443
tests/motors/test_feetech.py
Normal file
@@ -0,0 +1,443 @@
|
||||
import re
|
||||
import sys
|
||||
from typing import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.common.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus
|
||||
from lerobot.common.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE
|
||||
from lerobot.common.utils.encoding_utils import encode_sign_magnitude
|
||||
|
||||
try:
|
||||
import scservo_sdk as scs
|
||||
|
||||
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip("scservo_sdk not available", allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_port_handler():
|
||||
if sys.platform == "darwin":
|
||||
with patch.object(scs, "PortHandler", MockPortHandler):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_motors() -> Generator[MockMotors, None, None]:
|
||||
motors = MockMotors()
|
||||
motors.open()
|
||||
yield motors
|
||||
motors.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_motors() -> dict[str, Motor]:
|
||||
return {
|
||||
"dummy_1": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"dummy_2": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"dummy_3": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_calibration(dummy_motors) -> dict[str, MotorCalibration]:
|
||||
homings = [-709, -2006, 1624]
|
||||
mins = [43, 27, 145]
|
||||
maxes = [1335, 3608, 3999]
|
||||
calibration = {}
|
||||
for motor, m in dummy_motors.items():
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=0,
|
||||
homing_offset=homings[m.id - 1],
|
||||
range_min=mins[m.id - 1],
|
||||
range_max=maxes[m.id - 1],
|
||||
)
|
||||
return calibration
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
|
||||
def test_autouse_patch():
|
||||
"""Ensures that the autouse fixture correctly patches scs.PortHandler with MockPortHandler."""
|
||||
assert scs.PortHandler is MockPortHandler
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"protocol, value, length, expected",
|
||||
[
|
||||
(0, 0x12, 1, [0x12]),
|
||||
(1, 0x12, 1, [0x12]),
|
||||
(0, 0x1234, 2, [0x34, 0x12]),
|
||||
(1, 0x1234, 2, [0x12, 0x34]),
|
||||
(0, 0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
|
||||
(1, 0x12345678, 4, [0x56, 0x78, 0x12, 0x34]),
|
||||
],
|
||||
ids=[
|
||||
"P0: 1 byte",
|
||||
"P1: 1 byte",
|
||||
"P0: 2 bytes",
|
||||
"P1: 2 bytes",
|
||||
"P0: 4 bytes",
|
||||
"P1: 4 bytes",
|
||||
],
|
||||
) # fmt: skip
|
||||
def test__split_into_byte_chunks(protocol, value, length, expected):
|
||||
bus = FeetechMotorsBus("", {}, protocol_version=protocol)
|
||||
assert bus._split_into_byte_chunks(value, length) == expected
|
||||
|
||||
|
||||
def test_abc_implementation(dummy_motors):
|
||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
||||
FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("id_", [1, 2, 3])
|
||||
def test_ping(id_, mock_motors, dummy_motors):
|
||||
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
|
||||
addr, length = MODEL_NUMBER
|
||||
ping_stub = mock_motors.build_ping_stub(id_)
|
||||
mobel_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb)
|
||||
bus = FeetechMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors=dummy_motors,
|
||||
)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
ping_model_nb = bus.ping(id_)
|
||||
|
||||
assert ping_model_nb == expected_model_nb
|
||||
assert mock_motors.stubs[ping_stub].called
|
||||
assert mock_motors.stubs[mobel_nb_stub].called
|
||||
|
||||
|
||||
def test_broadcast_ping(mock_motors, dummy_motors):
|
||||
models = {m.id: m.model for m in dummy_motors.values()}
|
||||
addr, length = MODEL_NUMBER
|
||||
ping_stub = mock_motors.build_broadcast_ping_stub(list(models))
|
||||
mobel_nb_stubs = []
|
||||
expected_model_nbs = {}
|
||||
for id_, model in models.items():
|
||||
model_nb = MODEL_NUMBER_TABLE[model]
|
||||
stub = mock_motors.build_read_stub(addr, length, id_, model_nb)
|
||||
expected_model_nbs[id_] = model_nb
|
||||
mobel_nb_stubs.append(stub)
|
||||
bus = FeetechMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors=dummy_motors,
|
||||
)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
ping_model_nbs = bus.broadcast_ping()
|
||||
|
||||
assert ping_model_nbs == expected_model_nbs
|
||||
assert mock_motors.stubs[ping_stub].called
|
||||
assert all(mock_motors.stubs[stub].called for stub in mobel_nb_stubs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"addr, length, id_, value",
|
||||
[
|
||||
(0, 1, 1, 2),
|
||||
(10, 2, 2, 999),
|
||||
(42, 4, 3, 1337),
|
||||
],
|
||||
)
|
||||
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
|
||||
stub = mock_motors.build_read_stub(addr, length, id_, value)
|
||||
bus = FeetechMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors=dummy_motors,
|
||||
)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
read_value, _, _ = bus._read(addr, length, id_)
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
assert read_value == value
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||
def test__read_error(raise_on_error, mock_motors, dummy_motors):
|
||||
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
|
||||
stub = mock_motors.build_read_stub(addr, length, id_, value, error=error)
|
||||
bus = FeetechMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors=dummy_motors,
|
||||
)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
if raise_on_error:
|
||||
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
|
||||
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||
else:
|
||||
_, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||
assert read_error == error
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
|
||||
addr, length, id_, value = (10, 4, 1, 1337)
|
||||
stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
|
||||
bus = FeetechMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors=dummy_motors,
|
||||
)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
if raise_on_error:
|
||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||
else:
|
||||
_, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||
assert read_comm == scs.COMM_RX_TIMEOUT
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"addr, length, id_, value",
|
||||
[
|
||||
(0, 1, 1, 2),
|
||||
(10, 2, 2, 999),
|
||||
(42, 4, 3, 1337),
|
||||
],
|
||||
)
|
||||
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
|
||||
stub = mock_motors.build_write_stub(addr, length, id_, value)
|
||||
bus = FeetechMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors=dummy_motors,
|
||||
)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
comm, error = bus._write(addr, length, id_, value)
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
assert comm == scs.COMM_SUCCESS
|
||||
assert error == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||
def test__write_error(raise_on_error, mock_motors, dummy_motors):
|
||||
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
|
||||
stub = mock_motors.build_write_stub(addr, length, id_, value, error=error)
|
||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
if raise_on_error:
|
||||
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
|
||||
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||
else:
|
||||
_, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||
assert write_error == error
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
|
||||
addr, length, id_, value = (10, 4, 1, 1337)
|
||||
stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
|
||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
if raise_on_error:
|
||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||
else:
|
||||
write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||
assert write_comm == scs.COMM_RX_TIMEOUT
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"addr, length, ids_values",
|
||||
[
|
||||
(0, 1, {1: 4}),
|
||||
(10, 2, {1: 1337, 2: 42}),
|
||||
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
||||
],
|
||||
ids=["1 motor", "2 motors", "3 motors"],
|
||||
)
|
||||
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
|
||||
stub = mock_motors.build_sync_read_stub(addr, length, ids_values)
|
||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
read_values, _ = bus._sync_read(addr, length, list(ids_values))
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
assert read_values == ids_values
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
|
||||
addr, length, ids_values = (10, 4, {1: 1337})
|
||||
stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
|
||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
if raise_on_error:
|
||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||
bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
||||
else:
|
||||
_, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
||||
assert read_comm == scs.COMM_RX_TIMEOUT
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"addr, length, ids_values",
|
||||
[
|
||||
(0, 1, {1: 4}),
|
||||
(10, 2, {1: 1337, 2: 42}),
|
||||
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
||||
],
|
||||
ids=["1 motor", "2 motors", "3 motors"],
|
||||
)
|
||||
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
|
||||
stub = mock_motors.build_sync_write_stub(addr, length, ids_values)
|
||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
comm = bus._sync_write(addr, length, ids_values)
|
||||
|
||||
assert mock_motors.stubs[stub].wait_called()
|
||||
assert comm == scs.COMM_SUCCESS
|
||||
|
||||
|
||||
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
|
||||
mins_stubs, maxes_stubs, homings_stubs = [], [], []
|
||||
for cal in dummy_calibration.values():
|
||||
mins_stubs.append(
|
||||
mock_motors.build_read_stub(
|
||||
*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], cal.id, cal.range_min
|
||||
)
|
||||
)
|
||||
maxes_stubs.append(
|
||||
mock_motors.build_read_stub(
|
||||
*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], cal.id, cal.range_max
|
||||
)
|
||||
)
|
||||
homings_stubs.append(
|
||||
mock_motors.build_read_stub(
|
||||
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"],
|
||||
cal.id,
|
||||
encode_sign_magnitude(cal.homing_offset, 11),
|
||||
)
|
||||
)
|
||||
|
||||
bus = FeetechMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors=dummy_motors,
|
||||
calibration=dummy_calibration,
|
||||
)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
is_calibrated = bus.is_calibrated
|
||||
|
||||
assert is_calibrated
|
||||
assert all(mock_motors.stubs[stub].called for stub in mins_stubs)
|
||||
assert all(mock_motors.stubs[stub].called for stub in maxes_stubs)
|
||||
assert all(mock_motors.stubs[stub].called for stub in homings_stubs)
|
||||
|
||||
|
||||
def test_reset_calibration(mock_motors, dummy_motors):
|
||||
write_homing_stubs = []
|
||||
write_mins_stubs = []
|
||||
write_maxes_stubs = []
|
||||
for motor in dummy_motors.values():
|
||||
write_homing_stubs.append(
|
||||
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
|
||||
)
|
||||
write_mins_stubs.append(
|
||||
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
|
||||
)
|
||||
write_maxes_stubs.append(
|
||||
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
|
||||
)
|
||||
|
||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
bus.reset_calibration()
|
||||
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_maxes_stubs)
|
||||
|
||||
|
||||
def test_set_half_turn_homings(mock_motors, dummy_motors):
|
||||
"""
|
||||
For this test, we assume that the homing offsets are already 0 such that
|
||||
Present_Position == Actual_Position
|
||||
"""
|
||||
current_positions = {
|
||||
1: 1337,
|
||||
2: 42,
|
||||
3: 3672,
|
||||
}
|
||||
expected_homings = {
|
||||
1: -710, # 1337 - 2047
|
||||
2: -2005, # 42 - 2047
|
||||
3: 1625, # 3672 - 2047
|
||||
}
|
||||
read_pos_stub = mock_motors.build_sync_read_stub(
|
||||
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], current_positions
|
||||
)
|
||||
write_homing_stubs = []
|
||||
for id_, homing in expected_homings.items():
|
||||
encoded_homing = encode_sign_magnitude(homing, 11)
|
||||
stub = mock_motors.build_write_stub(
|
||||
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing
|
||||
)
|
||||
write_homing_stubs.append(stub)
|
||||
|
||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
bus.reset_calibration = MagicMock()
|
||||
|
||||
bus.set_half_turn_homings()
|
||||
|
||||
bus.reset_calibration.assert_called_once()
|
||||
assert mock_motors.stubs[read_pos_stub].called
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
||||
|
||||
|
||||
def test_record_ranges_of_motion(mock_motors, dummy_motors):
|
||||
positions = {
|
||||
1: [351, 42, 1337],
|
||||
2: [28, 3600, 2444],
|
||||
3: [4002, 2999, 146],
|
||||
}
|
||||
expected_mins = {
|
||||
"dummy_1": 42,
|
||||
"dummy_2": 28,
|
||||
"dummy_3": 146,
|
||||
}
|
||||
expected_maxes = {
|
||||
"dummy_1": 1337,
|
||||
"dummy_2": 3600,
|
||||
"dummy_3": 4002,
|
||||
}
|
||||
stub = mock_motors.build_sequential_sync_read_stub(
|
||||
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions
|
||||
)
|
||||
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
mins, maxes = bus.record_ranges_of_motion(display_values=False)
|
||||
|
||||
assert mock_motors.stubs[stub].calls == 3
|
||||
assert mins == expected_mins
|
||||
assert maxes == expected_maxes
|
||||
@@ -1,160 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for physical motors and their mocked versions.
|
||||
If the physical motors are not connected to the computer, or not working,
|
||||
the test will be skipped.
|
||||
|
||||
Example of running a specific test:
|
||||
```bash
|
||||
pytest -sx tests/test_motors.py::test_find_port
|
||||
pytest -sx tests/test_motors.py::test_motors_bus
|
||||
```
|
||||
|
||||
Example of running test on real dynamixel motors connected to the computer:
|
||||
```bash
|
||||
pytest -sx 'tests/test_motors.py::test_motors_bus[dynamixel-False]'
|
||||
```
|
||||
|
||||
Example of running test on a mocked version of dynamixel motors:
|
||||
```bash
|
||||
pytest -sx 'tests/test_motors.py::test_motors_bus[dynamixel-True]'
|
||||
```
|
||||
"""
|
||||
|
||||
# TODO(rcadene): measure fps in nightly?
|
||||
# TODO(rcadene): test logs
|
||||
# TODO(rcadene): test calibration
|
||||
# TODO(rcadene): add compatibility with other motors bus
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.scripts.find_motors_bus_port import find_port
|
||||
from tests.utils import TEST_MOTOR_TYPES, make_motors_bus, require_motor
|
||||
|
||||
|
||||
@pytest.mark.parametrize("motor_type, mock", TEST_MOTOR_TYPES)
|
||||
@require_motor
|
||||
def test_find_port(request, motor_type, mock):
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
with pytest.raises(OSError):
|
||||
find_port()
|
||||
else:
|
||||
find_port()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("motor_type, mock", TEST_MOTOR_TYPES)
|
||||
@require_motor
|
||||
def test_configure_motors_all_ids_1(request, motor_type, mock):
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
if motor_type == "dynamixel":
|
||||
# see X_SERIES_BAUDRATE_TABLE
|
||||
smaller_baudrate = 9_600
|
||||
smaller_baudrate_value = 0
|
||||
elif motor_type == "feetech":
|
||||
# see SCS_SERIES_BAUDRATE_TABLE
|
||||
smaller_baudrate = 19_200
|
||||
smaller_baudrate_value = 7
|
||||
else:
|
||||
raise ValueError(motor_type)
|
||||
|
||||
input("Are you sure you want to re-configure the motors? Press enter to continue...")
|
||||
# This test expect the configuration was already correct.
|
||||
motors_bus = make_motors_bus(motor_type, mock=mock)
|
||||
motors_bus.connect()
|
||||
motors_bus.write("Baud_Rate", [smaller_baudrate_value] * len(motors_bus.motors))
|
||||
|
||||
motors_bus.set_bus_baudrate(smaller_baudrate)
|
||||
motors_bus.write("ID", [1] * len(motors_bus.motors))
|
||||
del motors_bus
|
||||
|
||||
# Test configure
|
||||
motors_bus = make_motors_bus(motor_type, mock=mock)
|
||||
motors_bus.connect()
|
||||
assert motors_bus.are_motors_configured()
|
||||
del motors_bus
|
||||
|
||||
|
||||
@pytest.mark.parametrize("motor_type, mock", TEST_MOTOR_TYPES)
|
||||
@require_motor
|
||||
def test_motors_bus(request, motor_type, mock):
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
motors_bus = make_motors_bus(motor_type, mock=mock)
|
||||
|
||||
# Test reading and writing before connecting raises an error
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
motors_bus.read("Torque_Enable")
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
motors_bus.write("Torque_Enable", 1)
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
motors_bus.disconnect()
|
||||
|
||||
# Test deleting the object without connecting first
|
||||
del motors_bus
|
||||
|
||||
# Test connecting
|
||||
motors_bus = make_motors_bus(motor_type, mock=mock)
|
||||
motors_bus.connect()
|
||||
|
||||
# Test connecting twice raises an error
|
||||
with pytest.raises(RobotDeviceAlreadyConnectedError):
|
||||
motors_bus.connect()
|
||||
|
||||
# Test disabling torque and reading torque on all motors
|
||||
motors_bus.write("Torque_Enable", 0)
|
||||
values = motors_bus.read("Torque_Enable")
|
||||
assert isinstance(values, np.ndarray)
|
||||
assert len(values) == len(motors_bus.motors)
|
||||
assert (values == 0).all()
|
||||
|
||||
# Test writing torque on a specific motor
|
||||
motors_bus.write("Torque_Enable", 1, "gripper")
|
||||
|
||||
# Test reading torque from this specific motor. It is now 1
|
||||
values = motors_bus.read("Torque_Enable", "gripper")
|
||||
assert len(values) == 1
|
||||
assert values[0] == 1
|
||||
|
||||
# Test reading torque from all motors. It is 1 for the specific motor,
|
||||
# and 0 on the others.
|
||||
values = motors_bus.read("Torque_Enable")
|
||||
gripper_index = motors_bus.motor_names.index("gripper")
|
||||
assert values[gripper_index] == 1
|
||||
assert values.sum() == 1 # gripper is the only motor to have torque 1
|
||||
|
||||
# Test writing torque on all motors and it is 1 for all.
|
||||
motors_bus.write("Torque_Enable", 1)
|
||||
values = motors_bus.read("Torque_Enable")
|
||||
assert (values == 1).all()
|
||||
|
||||
# Test ordering the motors to move slightly (+1 value among 4096) and this move
|
||||
# can be executed and seen by the motor position sensor
|
||||
values = motors_bus.read("Present_Position")
|
||||
motors_bus.write("Goal_Position", values + 1)
|
||||
# Give time for the motors to move to the goal position
|
||||
time.sleep(1)
|
||||
new_values = motors_bus.read("Present_Position")
|
||||
assert (new_values == values).all()
|
||||
342
tests/motors/test_motors_bus.py
Normal file
342
tests/motors/test_motors_bus.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import re
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.motors.motors_bus import (
|
||||
Motor,
|
||||
MotorNormMode,
|
||||
assert_same_address,
|
||||
get_address,
|
||||
get_ctrl_table,
|
||||
)
|
||||
from tests.mocks.mock_motors_bus import (
|
||||
DUMMY_CTRL_TABLE_1,
|
||||
DUMMY_CTRL_TABLE_2,
|
||||
DUMMY_MODEL_CTRL_TABLE,
|
||||
MockMotorsBus,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_motors() -> dict[str, Motor]:
|
||||
return {
|
||||
"dummy_1": Motor(1, "model_2", MotorNormMode.RANGE_M100_100),
|
||||
"dummy_2": Motor(2, "model_3", MotorNormMode.RANGE_M100_100),
|
||||
"dummy_3": Motor(3, "model_2", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
|
||||
def test_get_ctrl_table():
|
||||
model = "model_1"
|
||||
ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
|
||||
assert ctrl_table == DUMMY_CTRL_TABLE_1
|
||||
|
||||
|
||||
def test_get_ctrl_table_error():
|
||||
model = "model_99"
|
||||
with pytest.raises(KeyError, match=f"Control table for {model=} not found."):
|
||||
get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
|
||||
|
||||
|
||||
def test_get_address():
|
||||
addr, n_bytes = get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", "Firmware_Version")
|
||||
assert addr == 0
|
||||
assert n_bytes == 1
|
||||
|
||||
|
||||
def test_get_address_error():
|
||||
model = "model_1"
|
||||
data_name = "Lock"
|
||||
with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."):
|
||||
get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", data_name)
|
||||
|
||||
|
||||
def test_assert_same_address():
|
||||
models = ["model_1", "model_2"]
|
||||
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position")
|
||||
|
||||
|
||||
def test_assert_same_length_different_addresses():
|
||||
models = ["model_1", "model_2"]
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match=re.escape("At least two motor models use a different address"),
|
||||
):
|
||||
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number")
|
||||
|
||||
|
||||
def test_assert_same_address_different_length():
|
||||
models = ["model_1", "model_2"]
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match=re.escape("At least two motor models use a different bytes representation"),
|
||||
):
|
||||
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Goal_Position")
|
||||
|
||||
|
||||
def test__serialize_data_invalid_length():
|
||||
bus = MockMotorsBus("", {})
|
||||
with pytest.raises(NotImplementedError):
|
||||
bus._serialize_data(100, 3)
|
||||
|
||||
|
||||
def test__serialize_data_negative_numbers():
|
||||
bus = MockMotorsBus("", {})
|
||||
with pytest.raises(ValueError):
|
||||
bus._serialize_data(-1, 1)
|
||||
|
||||
|
||||
def test__serialize_data_large_number():
|
||||
bus = MockMotorsBus("", {})
|
||||
with pytest.raises(ValueError):
|
||||
bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_name, id_, value",
|
||||
[
|
||||
("Firmware_Version", 1, 14),
|
||||
("Model_Number", 1, 5678),
|
||||
("Present_Position", 2, 1337),
|
||||
("Present_Velocity", 3, 42),
|
||||
],
|
||||
)
|
||||
def test_read(data_name, id_, value, dummy_motors):
|
||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||
|
||||
with (
|
||||
patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0)) as mock__read,
|
||||
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
|
||||
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
|
||||
):
|
||||
returned_value = bus.read(data_name, f"dummy_{id_}")
|
||||
|
||||
assert returned_value == value
|
||||
mock__read.assert_called_once_with(
|
||||
addr,
|
||||
length,
|
||||
id_,
|
||||
num_retry=0,
|
||||
raise_on_error=True,
|
||||
err_msg=f"Failed to read '{data_name}' on {id_=} after 1 tries.",
|
||||
)
|
||||
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
|
||||
if data_name in bus.normalized_data:
|
||||
mock__normalize.assert_called_once_with({id_: value})
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_name, id_, value",
|
||||
[
|
||||
("Goal_Position", 1, 1337),
|
||||
("Goal_Velocity", 2, 3682),
|
||||
("Lock", 3, 1),
|
||||
],
|
||||
)
|
||||
def test_write(data_name, id_, value, dummy_motors):
|
||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||
|
||||
with (
|
||||
patch.object(MockMotorsBus, "_write", return_value=(0, 0)) as mock__write,
|
||||
patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value}) as mock__encode_sign,
|
||||
patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value}) as mock__unnormalize,
|
||||
):
|
||||
bus.write(data_name, f"dummy_{id_}", value)
|
||||
|
||||
mock__write.assert_called_once_with(
|
||||
addr,
|
||||
length,
|
||||
id_,
|
||||
value,
|
||||
num_retry=0,
|
||||
raise_on_error=True,
|
||||
err_msg=f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries.",
|
||||
)
|
||||
mock__encode_sign.assert_called_once_with(data_name, {id_: value})
|
||||
if data_name in bus.normalized_data:
|
||||
mock__unnormalize.assert_called_once_with({id_: value})
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_name, id_, value",
|
||||
[
|
||||
("Firmware_Version", 1, 14),
|
||||
("Model_Number", 1, 5678),
|
||||
("Present_Position", 2, 1337),
|
||||
("Present_Velocity", 3, 42),
|
||||
],
|
||||
)
|
||||
def test_sync_read_by_str(data_name, id_, value, dummy_motors):
|
||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||
ids = [id_]
|
||||
expected_value = {f"dummy_{id_}": value}
|
||||
|
||||
with (
|
||||
patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0)) as mock__sync_read,
|
||||
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
|
||||
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
|
||||
):
|
||||
returned_dict = bus.sync_read(data_name, f"dummy_{id_}")
|
||||
|
||||
assert returned_dict == expected_value
|
||||
mock__sync_read.assert_called_once_with(
|
||||
addr,
|
||||
length,
|
||||
ids,
|
||||
num_retry=0,
|
||||
raise_on_error=True,
|
||||
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
|
||||
)
|
||||
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
|
||||
if data_name in bus.normalized_data:
|
||||
mock__normalize.assert_called_once_with({id_: value})
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_name, ids_values",
|
||||
[
|
||||
("Model_Number", {1: 5678}),
|
||||
("Present_Position", {1: 1337, 2: 42}),
|
||||
("Present_Velocity", {1: 1337, 2: 42, 3: 4016}),
|
||||
],
|
||||
ids=["1 motor", "2 motors", "3 motors"],
|
||||
)
|
||||
def test_sync_read_by_list(data_name, ids_values, dummy_motors):
|
||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||
ids = list(ids_values)
|
||||
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
|
||||
|
||||
with (
|
||||
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
|
||||
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
|
||||
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
|
||||
):
|
||||
returned_dict = bus.sync_read(data_name, [f"dummy_{id_}" for id_ in ids])
|
||||
|
||||
assert returned_dict == expected_values
|
||||
mock__sync_read.assert_called_once_with(
|
||||
addr,
|
||||
length,
|
||||
ids,
|
||||
num_retry=0,
|
||||
raise_on_error=True,
|
||||
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
|
||||
)
|
||||
mock__decode_sign.assert_called_once_with(data_name, ids_values)
|
||||
if data_name in bus.normalized_data:
|
||||
mock__normalize.assert_called_once_with(ids_values)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_name, ids_values",
|
||||
[
|
||||
("Model_Number", {1: 5678, 2: 5799, 3: 5678}),
|
||||
("Present_Position", {1: 1337, 2: 42, 3: 4016}),
|
||||
("Goal_Position", {1: 4008, 2: 199, 3: 3446}),
|
||||
],
|
||||
ids=["Model_Number", "Present_Position", "Goal_Position"],
|
||||
)
|
||||
def test_sync_read_by_none(data_name, ids_values, dummy_motors):
|
||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||
ids = list(ids_values)
|
||||
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
|
||||
|
||||
with (
|
||||
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
|
||||
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
|
||||
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
|
||||
):
|
||||
returned_dict = bus.sync_read(data_name)
|
||||
|
||||
assert returned_dict == expected_values
|
||||
mock__sync_read.assert_called_once_with(
|
||||
addr,
|
||||
length,
|
||||
ids,
|
||||
num_retry=0,
|
||||
raise_on_error=True,
|
||||
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
|
||||
)
|
||||
mock__decode_sign.assert_called_once_with(data_name, ids_values)
|
||||
if data_name in bus.normalized_data:
|
||||
mock__normalize.assert_called_once_with(ids_values)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_name, value",
|
||||
[
|
||||
("Goal_Position", 500),
|
||||
("Goal_Velocity", 4010),
|
||||
("Lock", 0),
|
||||
],
|
||||
)
|
||||
def test_sync_write_by_single_value(data_name, value, dummy_motors):
|
||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||
ids_values = {m.id: value for m in dummy_motors.values()}
|
||||
|
||||
with (
|
||||
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
|
||||
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
|
||||
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
|
||||
):
|
||||
bus.sync_write(data_name, value)
|
||||
|
||||
mock__sync_write.assert_called_once_with(
|
||||
addr,
|
||||
length,
|
||||
ids_values,
|
||||
num_retry=0,
|
||||
raise_on_error=True,
|
||||
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
|
||||
)
|
||||
mock__encode_sign.assert_called_once_with(data_name, ids_values)
|
||||
if data_name in bus.normalized_data:
|
||||
mock__unnormalize.assert_called_once_with(ids_values)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_name, ids_values",
|
||||
[
|
||||
("Goal_Position", {1: 1337, 2: 42, 3: 4016}),
|
||||
("Goal_Velocity", {1: 50, 2: 83, 3: 2777}),
|
||||
("Lock", {1: 0, 2: 0, 3: 1}),
|
||||
],
|
||||
ids=["Goal_Position", "Goal_Velocity", "Lock"],
|
||||
)
|
||||
def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors):
|
||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||
values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
|
||||
|
||||
with (
|
||||
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
|
||||
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
|
||||
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
|
||||
):
|
||||
bus.sync_write(data_name, values)
|
||||
|
||||
mock__sync_write.assert_called_once_with(
|
||||
addr,
|
||||
length,
|
||||
ids_values,
|
||||
num_retry=0,
|
||||
raise_on_error=True,
|
||||
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
|
||||
)
|
||||
mock__encode_sign.assert_called_once_with(data_name, ids_values)
|
||||
if data_name in bus.normalized_data:
|
||||
mock__unnormalize.assert_called_once_with(ids_values)
|
||||
@@ -37,7 +37,6 @@ def test_diffuser_scheduler(optimizer):
|
||||
"base_lrs": [0.001],
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
"verbose": False,
|
||||
}
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
@@ -56,7 +55,6 @@ def test_vqbet_scheduler(optimizer):
|
||||
"base_lrs": [0.001],
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
"verbose": False,
|
||||
}
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
@@ -77,7 +75,6 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
|
||||
"base_lrs": [0.001],
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
"verbose": False,
|
||||
}
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from pathlib import Path
|
||||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot import available_policies
|
||||
@@ -408,7 +409,16 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
4. Check that this test now passes.
|
||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
||||
|
||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
||||
https://github.com/huggingface/lerobot/pull/1127.
|
||||
|
||||
"""
|
||||
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
|
||||
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
|
||||
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
|
||||
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for physical robots and their mocked versions.
|
||||
If the physical robots are not connected to the computer, or not working,
|
||||
the test will be skipped.
|
||||
|
||||
Example of running a specific test:
|
||||
```bash
|
||||
pytest -sx tests/test_control_robot.py::test_teleoperate
|
||||
```
|
||||
|
||||
Example of running test on real robots connected to the computer:
|
||||
```bash
|
||||
pytest -sx 'tests/test_control_robot.py::test_teleoperate[koch-False]'
|
||||
pytest -sx 'tests/test_control_robot.py::test_teleoperate[koch_bimanual-False]'
|
||||
pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-False]'
|
||||
```
|
||||
|
||||
Example of running test on a mocked version of robots:
|
||||
```bash
|
||||
pytest -sx 'tests/test_control_robot.py::test_teleoperate[koch-True]'
|
||||
pytest -sx 'tests/test_control_robot.py::test_teleoperate[koch_bimanual-True]'
|
||||
pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
|
||||
```
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
RecordControlConfig,
|
||||
ReplayControlConfig,
|
||||
TeleoperateControlConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||
from tests.robots.test_robots import make_robot
|
||||
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_teleoperate(tmp_path, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmp_path / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
teleoperate(robot, TeleoperateControlConfig(teleop_time_s=1))
|
||||
teleoperate(robot, TeleoperateControlConfig(fps=30, teleop_time_s=1))
|
||||
teleoperate(robot, TeleoperateControlConfig(fps=60, teleop_time_s=1))
|
||||
del robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_calibrate(tmp_path, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
calibration_dir = tmp_path / robot_type
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
calib_cfg = CalibrateControlConfig(arms=robot.available_arms)
|
||||
calibrate(robot, calib_cfg)
|
||||
del robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_record_without_cameras(tmp_path, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
# Avoid using cameras
|
||||
robot_kwargs["cameras"] = {}
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmp_path / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
repo_id = "lerobot/debug"
|
||||
root = tmp_path / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
single_task=single_task,
|
||||
root=root,
|
||||
fps=30,
|
||||
warmup_time_s=0.1,
|
||||
episode_time_s=1,
|
||||
reset_time_s=0.1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
record(robot, rec_cfg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmp_path / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
repo_id = "lerobot_test/debug"
|
||||
root = tmp_path / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
single_task=single_task,
|
||||
root=root,
|
||||
fps=1,
|
||||
warmup_time_s=0.1,
|
||||
episode_time_s=1,
|
||||
reset_time_s=0.1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
# TODO(rcadene, aliberts): test video=True
|
||||
video=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
dataset = record(robot, rec_cfg)
|
||||
assert dataset.meta.total_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
|
||||
replay(robot, replay_cfg)
|
||||
|
||||
policy_cfg = ACTConfig()
|
||||
policy = make_policy(policy_cfg, ds_meta=dataset.meta)
|
||||
|
||||
out_dir = tmp_path / "logger"
|
||||
|
||||
pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model"
|
||||
policy.save_pretrained(pretrained_policy_path)
|
||||
|
||||
# In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
|
||||
# during inference, to reach constant fps, so we test this here.
|
||||
if robot_type == "aloha":
|
||||
num_image_writer_processes = 1
|
||||
|
||||
# `multiprocessing.set_start_method("spawn", force=True)` avoids a hanging issue
|
||||
# before exiting pytest. However, it outputs the following error in the log:
|
||||
# Traceback (most recent call last):
|
||||
# File "<string>", line 1, in <module>
|
||||
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
|
||||
# exitcode = _main(fd, parent_sentinel)
|
||||
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
|
||||
# self = reduction.pickle.load(from_parent)
|
||||
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/synchronize.py", line 110, in __setstate__
|
||||
# self._semlock = _multiprocessing.SemLock._rebuild(*state)
|
||||
# FileNotFoundError: [Errno 2] No such file or directory
|
||||
# TODO(rcadene, aliberts): fix FileNotFoundError in multiprocessing
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
else:
|
||||
num_image_writer_processes = 0
|
||||
|
||||
eval_repo_id = "lerobot/eval_debug"
|
||||
eval_root = tmp_path / "data" / eval_repo_id
|
||||
|
||||
rec_eval_cfg = RecordControlConfig(
|
||||
repo_id=eval_repo_id,
|
||||
root=eval_root,
|
||||
single_task=single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0.1,
|
||||
episode_time_s=1,
|
||||
reset_time_s=0.1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
|
||||
rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)
|
||||
rec_eval_cfg.policy.pretrained_path = pretrained_policy_path
|
||||
|
||||
dataset = record(robot, rec_eval_cfg)
|
||||
assert dataset.num_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
del robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_resume_record(tmp_path, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmp_path / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
|
||||
repo_id = "lerobot/debug"
|
||||
root = tmp_path / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
single_task=single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
num_episodes=1,
|
||||
)
|
||||
|
||||
dataset = record(robot, rec_cfg)
|
||||
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
|
||||
|
||||
with pytest.raises(FileExistsError):
|
||||
# Dataset already exists, but resume=False by default
|
||||
record(robot, rec_cfg)
|
||||
|
||||
rec_cfg.resume = True
|
||||
dataset = record(robot, rec_cfg)
|
||||
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmp_path / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = True
|
||||
mock_events["stop_recording"] = False
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
repo_id = "lerobot/debug"
|
||||
root = tmp_path / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
single_task=single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
dataset = record(robot, rec_cfg)
|
||||
|
||||
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmp_path / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = False
|
||||
mock_events["stop_recording"] = False
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
repo_id = "lerobot/debug"
|
||||
root = tmp_path / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
single_task=single_task,
|
||||
fps=2,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
|
||||
dataset = record(robot, rec_cfg)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
|
||||
)
|
||||
@require_robot
|
||||
def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmp_path / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = False
|
||||
mock_events["stop_recording"] = True
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
repo_id = "lerobot/debug"
|
||||
root = tmp_path / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
single_task=single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
reset_time_s=0.1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
|
||||
dataset = record(robot, rec_cfg)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
@@ -1,144 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for physical robots and their mocked versions.
|
||||
If the physical robots are not connected to the computer, or not working,
|
||||
the test will be skipped.
|
||||
|
||||
Example of running a specific test:
|
||||
```bash
|
||||
pytest -sx tests/test_robots.py::test_robot
|
||||
```
|
||||
|
||||
Example of running test on real robots connected to the computer:
|
||||
```bash
|
||||
pytest -sx 'tests/test_robots.py::test_robot[koch-False]'
|
||||
pytest -sx 'tests/test_robots.py::test_robot[koch_bimanual-False]'
|
||||
pytest -sx 'tests/test_robots.py::test_robot[aloha-False]'
|
||||
```
|
||||
|
||||
Example of running test on a mocked version of robots:
|
||||
```bash
|
||||
pytest -sx 'tests/test_robots.py::test_robot[koch-True]'
|
||||
pytest -sx 'tests/test_robots.py::test_robot[koch_bimanual-True]'
|
||||
pytest -sx 'tests/test_robots.py::test_robot[aloha-True]'
|
||||
```
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_robot(tmp_path, request, robot_type, mock):
|
||||
# TODO(rcadene): measure fps in nightly?
|
||||
# TODO(rcadene): test logs
|
||||
# TODO(rcadene): add compatibility with other robots
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if robot_type == "aloha" and mock:
|
||||
# To simplify unit test, we do not rerun manual calibration for Aloha mock=True.
|
||||
# Instead, we use the files from '.cache/calibration/aloha_default'
|
||||
pass
|
||||
else:
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
calibration_dir = tmp_path / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
|
||||
# Test using robot before connecting raises an error
|
||||
robot = make_robot(**robot_kwargs)
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
robot.teleop_step()
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
robot.teleop_step(record_data=True)
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
robot.capture_observation()
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
robot.send_action(None)
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
robot.disconnect()
|
||||
|
||||
# Test deleting the object without connecting first
|
||||
del robot
|
||||
|
||||
# Test connecting (triggers manual calibration)
|
||||
robot = make_robot(**robot_kwargs)
|
||||
robot.connect()
|
||||
assert robot.is_connected
|
||||
|
||||
# Test connecting twice raises an error
|
||||
with pytest.raises(RobotDeviceAlreadyConnectedError):
|
||||
robot.connect()
|
||||
|
||||
# TODO(rcadene, aliberts): Test disconnecting with `__del__` instead of `disconnect`
|
||||
# del robot
|
||||
robot.disconnect()
|
||||
|
||||
# Test teleop can run
|
||||
robot = make_robot(**robot_kwargs)
|
||||
robot.connect()
|
||||
robot.teleop_step()
|
||||
|
||||
# Test data recorded during teleop are well formatted
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
# State
|
||||
assert "observation.state" in observation
|
||||
assert isinstance(observation["observation.state"], torch.Tensor)
|
||||
assert observation["observation.state"].ndim == 1
|
||||
dim_state = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
|
||||
assert observation["observation.state"].shape[0] == dim_state
|
||||
# Cameras
|
||||
for name in robot.cameras:
|
||||
assert f"observation.images.{name}" in observation
|
||||
assert isinstance(observation[f"observation.images.{name}"], torch.Tensor)
|
||||
assert observation[f"observation.images.{name}"].ndim == 3
|
||||
# Action
|
||||
assert "action" in action
|
||||
assert isinstance(action["action"], torch.Tensor)
|
||||
assert action["action"].ndim == 1
|
||||
dim_action = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
|
||||
assert action["action"].shape[0] == dim_action
|
||||
# TODO(rcadene): test if observation and action data are returned as expected
|
||||
|
||||
# Test capture_observation can run and observation returned are the same (since the arm didnt move)
|
||||
captured_observation = robot.capture_observation()
|
||||
assert set(captured_observation.keys()) == set(observation.keys())
|
||||
for name in captured_observation:
|
||||
if "image" in name:
|
||||
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
|
||||
continue
|
||||
torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
|
||||
assert captured_observation[name].shape == observation[name].shape
|
||||
|
||||
# Test send_action can run
|
||||
robot.send_action(action["action"])
|
||||
|
||||
# Test disconnecting
|
||||
robot.disconnect()
|
||||
assert not robot.is_connected
|
||||
for name in robot.follower_arms:
|
||||
assert not robot.follower_arms[name].is_connected
|
||||
for name in robot.leader_arms:
|
||||
assert not robot.leader_arms[name].is_connected
|
||||
for name in robot.cameras:
|
||||
assert not robot.cameras[name].is_connected
|
||||
95
tests/robots/test_so100_follower.py
Normal file
95
tests/robots/test_so100_follower.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.robots.so100_follower import (
|
||||
SO100Follower,
|
||||
SO100FollowerConfig,
|
||||
)
|
||||
|
||||
|
||||
def _make_bus_mock() -> MagicMock:
|
||||
"""Return a bus mock with just the attributes used by the robot."""
|
||||
bus = MagicMock(name="FeetechBusMock")
|
||||
bus.is_connected = False
|
||||
|
||||
def _connect():
|
||||
bus.is_connected = True
|
||||
|
||||
def _disconnect(_disable=True):
|
||||
bus.is_connected = False
|
||||
|
||||
bus.connect.side_effect = _connect
|
||||
bus.disconnect.side_effect = _disconnect
|
||||
|
||||
@contextmanager
|
||||
def _dummy_cm():
|
||||
yield
|
||||
|
||||
bus.torque_disabled.side_effect = _dummy_cm
|
||||
|
||||
return bus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def follower():
|
||||
bus_mock = _make_bus_mock()
|
||||
|
||||
def _bus_side_effect(*_args, **kwargs):
|
||||
bus_mock.motors = kwargs["motors"]
|
||||
motors_order: list[str] = list(bus_mock.motors)
|
||||
|
||||
bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)}
|
||||
bus_mock.sync_write.return_value = None
|
||||
bus_mock.write.return_value = None
|
||||
bus_mock.disable_torque.return_value = None
|
||||
bus_mock.enable_torque.return_value = None
|
||||
bus_mock.is_calibrated = True
|
||||
return bus_mock
|
||||
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
|
||||
side_effect=_bus_side_effect,
|
||||
),
|
||||
patch.object(SO100Follower, "configure", lambda self: None),
|
||||
):
|
||||
cfg = SO100FollowerConfig(port="/dev/null")
|
||||
robot = SO100Follower(cfg)
|
||||
yield robot
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_connect_disconnect(follower):
|
||||
assert not follower.is_connected
|
||||
|
||||
follower.connect()
|
||||
assert follower.is_connected
|
||||
|
||||
follower.disconnect()
|
||||
assert not follower.is_connected
|
||||
|
||||
|
||||
def test_get_observation(follower):
|
||||
follower.connect()
|
||||
obs = follower.get_observation()
|
||||
|
||||
expected_keys = {f"{m}.pos" for m in follower.bus.motors}
|
||||
assert set(obs.keys()) == expected_keys
|
||||
|
||||
for idx, motor in enumerate(follower.bus.motors, 1):
|
||||
assert obs[f"{motor}.pos"] == idx
|
||||
|
||||
|
||||
def test_send_action(follower):
|
||||
follower.connect()
|
||||
|
||||
action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)}
|
||||
returned = follower.send_action(action)
|
||||
|
||||
assert returned == action
|
||||
|
||||
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)}
|
||||
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)
|
||||
97
tests/test_control_robot.py
Normal file
97
tests/test_control_robot.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import time
|
||||
|
||||
from lerobot.calibrate import CalibrateConfig, calibrate
|
||||
from lerobot.record import DatasetRecordConfig, RecordConfig, record
|
||||
from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay
|
||||
from lerobot.teleoperate import TeleoperateConfig, teleoperate
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.mocks.mock_teleop import MockTeleopConfig
|
||||
|
||||
|
||||
def test_calibrate():
|
||||
robot_cfg = MockRobotConfig()
|
||||
cfg = CalibrateConfig(robot=robot_cfg)
|
||||
calibrate(cfg)
|
||||
|
||||
|
||||
def test_teleoperate():
|
||||
robot_cfg = MockRobotConfig()
|
||||
teleop_cfg = MockTeleopConfig()
|
||||
expected_duration = 0.1
|
||||
cfg = TeleoperateConfig(
|
||||
robot=robot_cfg,
|
||||
teleop=teleop_cfg,
|
||||
teleop_time_s=expected_duration,
|
||||
)
|
||||
start = time.perf_counter()
|
||||
teleoperate(cfg)
|
||||
actual_duration = time.perf_counter() - start
|
||||
|
||||
assert actual_duration <= expected_duration * 1.1
|
||||
|
||||
|
||||
def test_record_and_resume(tmp_path):
|
||||
robot_cfg = MockRobotConfig()
|
||||
teleop_cfg = MockTeleopConfig()
|
||||
dataset_cfg = DatasetRecordConfig(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
single_task="Dummy task",
|
||||
root=tmp_path / "record",
|
||||
num_episodes=1,
|
||||
episode_time_s=0.1,
|
||||
reset_time_s=0,
|
||||
push_to_hub=False,
|
||||
)
|
||||
cfg = RecordConfig(
|
||||
robot=robot_cfg,
|
||||
dataset=dataset_cfg,
|
||||
teleop=teleop_cfg,
|
||||
play_sounds=False,
|
||||
)
|
||||
|
||||
dataset = record(cfg)
|
||||
|
||||
assert dataset.fps == 30
|
||||
assert dataset.meta.total_episodes == dataset.num_episodes == 1
|
||||
assert dataset.meta.total_frames == dataset.num_frames == 3
|
||||
assert dataset.meta.total_tasks == 1
|
||||
|
||||
cfg.resume = True
|
||||
dataset = record(cfg)
|
||||
|
||||
assert dataset.meta.total_episodes == dataset.num_episodes == 2
|
||||
assert dataset.meta.total_frames == dataset.num_frames == 6
|
||||
assert dataset.meta.total_tasks == 1
|
||||
|
||||
|
||||
def test_record_and_replay(tmp_path):
|
||||
robot_cfg = MockRobotConfig()
|
||||
teleop_cfg = MockTeleopConfig()
|
||||
record_dataset_cfg = DatasetRecordConfig(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
single_task="Dummy task",
|
||||
root=tmp_path / "record_and_replay",
|
||||
num_episodes=1,
|
||||
episode_time_s=0.1,
|
||||
push_to_hub=False,
|
||||
)
|
||||
record_cfg = RecordConfig(
|
||||
robot=robot_cfg,
|
||||
dataset=record_dataset_cfg,
|
||||
teleop=teleop_cfg,
|
||||
play_sounds=False,
|
||||
)
|
||||
replay_dataset_cfg = DatasetReplayConfig(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
episode=0,
|
||||
root=tmp_path / "record_and_replay",
|
||||
)
|
||||
replay_cfg = ReplayConfig(
|
||||
robot=robot_cfg,
|
||||
dataset=replay_dataset_cfg,
|
||||
play_sounds=False,
|
||||
)
|
||||
|
||||
record(record_cfg)
|
||||
replay(replay_cfg)
|
||||
110
tests/utils.py
110
tests/utils.py
@@ -13,20 +13,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
|
||||
from lerobot.common.cameras import Camera
|
||||
from lerobot.common.motors.motors_bus import MotorsBus
|
||||
from lerobot.common.motors.utils import make_motors_bus as make_motors_bus_device
|
||||
from lerobot.common.utils.import_utils import is_package_available
|
||||
|
||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
||||
@@ -190,46 +187,6 @@ def require_package(package_name):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_robot(func):
|
||||
"""
|
||||
Decorator that skips the test if a robot is not available
|
||||
|
||||
The decorated function must have two arguments `request` and `robot_type`.
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
@pytest.mark.parametrize(
|
||||
"robot_type", ["koch", "aloha"]
|
||||
)
|
||||
@require_robot
|
||||
def test_require_robot(request, robot_type):
|
||||
pass
|
||||
```
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Access the pytest request context to get the is_robot_available fixture
|
||||
request = kwargs.get("request")
|
||||
robot_type = kwargs.get("robot_type")
|
||||
mock = kwargs.get("mock")
|
||||
|
||||
if robot_type is None:
|
||||
raise ValueError("The 'robot_type' must be an argument of the test function.")
|
||||
if request is None:
|
||||
raise ValueError("The 'request' fixture must be an argument of the test function.")
|
||||
if mock is None:
|
||||
raise ValueError("The 'mock' variable must be an argument of the test function.")
|
||||
|
||||
# Run test with a real robot. Skip test if robot connection fails.
|
||||
if not mock and not request.getfixturevalue("is_robot_available"):
|
||||
pytest.skip(f"A {robot_type} robot is not available.")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_camera(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
@@ -253,64 +210,23 @@ def require_camera(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_motor(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Access the pytest request context to get the is_motor_available fixture
|
||||
request = kwargs.get("request")
|
||||
motor_type = kwargs.get("motor_type")
|
||||
mock = kwargs.get("mock")
|
||||
|
||||
if request is None:
|
||||
raise ValueError("The 'request' fixture must be an argument of the test function.")
|
||||
if motor_type is None:
|
||||
raise ValueError("The 'motor_type' must be an argument of the test function.")
|
||||
if mock is None:
|
||||
raise ValueError("The 'mock' variable must be an argument of the test function.")
|
||||
|
||||
if not mock and not request.getfixturevalue("is_motor_available"):
|
||||
pytest.skip(f"A {motor_type} motor is not available.")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def mock_calibration_dir(calibration_dir):
|
||||
# TODO(rcadene): remove this hack
|
||||
# calibration file produced with Moss v1, but works with Koch, Koch bimanual and SO-100
|
||||
example_calib = {
|
||||
"homing_offset": [-1416, -845, 2130, 2872, 1950, -2211],
|
||||
"drive_mode": [0, 0, 1, 1, 1, 0],
|
||||
"start_pos": [1442, 843, 2166, 2849, 1988, 1835],
|
||||
"end_pos": [2440, 1869, -1106, -1848, -926, 3235],
|
||||
"calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"],
|
||||
"motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
}
|
||||
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
|
||||
with open(calibration_dir / "main_follower.json", "w") as f:
|
||||
json.dump(example_calib, f)
|
||||
with open(calibration_dir / "main_leader.json", "w") as f:
|
||||
json.dump(example_calib, f)
|
||||
with open(calibration_dir / "left_follower.json", "w") as f:
|
||||
json.dump(example_calib, f)
|
||||
with open(calibration_dir / "left_leader.json", "w") as f:
|
||||
json.dump(example_calib, f)
|
||||
with open(calibration_dir / "right_follower.json", "w") as f:
|
||||
json.dump(example_calib, f)
|
||||
with open(calibration_dir / "right_leader.json", "w") as f:
|
||||
json.dump(example_calib, f)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): remove this dark pattern that overrides
|
||||
def make_camera(camera_type: str, **kwargs) -> Camera:
|
||||
if camera_type == "opencv":
|
||||
camera_index = kwargs.pop("camera_index", OPENCV_CAMERA_INDEX)
|
||||
return make_camera_device(camera_type, camera_index=camera_index, **kwargs)
|
||||
kwargs["camera_index"] = camera_index
|
||||
from lerobot.common.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
|
||||
|
||||
config = OpenCVCameraConfig(**kwargs)
|
||||
return OpenCVCamera(config)
|
||||
|
||||
elif camera_type == "intelrealsense":
|
||||
serial_number = kwargs.pop("serial_number", INTELREALSENSE_SERIAL_NUMBER)
|
||||
return make_camera_device(camera_type, serial_number=serial_number, **kwargs)
|
||||
kwargs["serial_number"] = serial_number
|
||||
from lerobot.common.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
|
||||
|
||||
config = RealSenseCameraConfig(**kwargs)
|
||||
return RealSenseCamera(config)
|
||||
else:
|
||||
raise ValueError(f"The camera type '{camera_type}' is not valid.")
|
||||
|
||||
|
||||
155
tests/utils/test_encoding_utils.py
Normal file
155
tests/utils/test_encoding_utils.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import pytest
|
||||
|
||||
from lerobot.common.utils.encoding_utils import (
|
||||
decode_sign_magnitude,
|
||||
decode_twos_complement,
|
||||
encode_sign_magnitude,
|
||||
encode_twos_complement,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, sign_bit_index, expected",
|
||||
[
|
||||
(5, 4, 5),
|
||||
(0, 4, 0),
|
||||
(7, 3, 7),
|
||||
(-1, 4, 17),
|
||||
(-8, 4, 24),
|
||||
(-3, 3, 11),
|
||||
],
|
||||
)
|
||||
def test_encode_sign_magnitude(value, sign_bit_index, expected):
|
||||
assert encode_sign_magnitude(value, sign_bit_index) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"encoded, sign_bit_index, expected",
|
||||
[
|
||||
(5, 4, 5),
|
||||
(0, 4, 0),
|
||||
(7, 3, 7),
|
||||
(17, 4, -1),
|
||||
(24, 4, -8),
|
||||
(11, 3, -3),
|
||||
],
|
||||
)
|
||||
def test_decode_sign_magnitude(encoded, sign_bit_index, expected):
|
||||
assert decode_sign_magnitude(encoded, sign_bit_index) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"encoded, sign_bit_index",
|
||||
[
|
||||
(16, 4),
|
||||
(-9, 3),
|
||||
],
|
||||
)
|
||||
def test_encode_raises_on_overflow(encoded, sign_bit_index):
|
||||
with pytest.raises(ValueError):
|
||||
encode_sign_magnitude(encoded, sign_bit_index)
|
||||
|
||||
|
||||
def test_encode_decode_sign_magnitude():
|
||||
for sign_bit_index in range(2, 6):
|
||||
max_val = (1 << sign_bit_index) - 1
|
||||
for value in range(-max_val, max_val + 1):
|
||||
encoded = encode_sign_magnitude(value, sign_bit_index)
|
||||
decoded = decode_sign_magnitude(encoded, sign_bit_index)
|
||||
assert decoded == value, f"Failed at value={value}, index={sign_bit_index}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, n_bytes, expected",
|
||||
[
|
||||
(0, 1, 0),
|
||||
(5, 1, 5),
|
||||
(-1, 1, 255),
|
||||
(-128, 1, 128),
|
||||
(-2, 1, 254),
|
||||
(127, 1, 127),
|
||||
(0, 2, 0),
|
||||
(5, 2, 5),
|
||||
(-1, 2, 65_535),
|
||||
(-32_768, 2, 32_768),
|
||||
(-2, 2, 65_534),
|
||||
(32_767, 2, 32_767),
|
||||
(0, 4, 0),
|
||||
(5, 4, 5),
|
||||
(-1, 4, 4_294_967_295),
|
||||
(-2_147_483_648, 4, 2_147_483_648),
|
||||
(-2, 4, 4_294_967_294),
|
||||
(2_147_483_647, 4, 2_147_483_647),
|
||||
],
|
||||
)
|
||||
def test_encode_twos_complement(value, n_bytes, expected):
|
||||
assert encode_twos_complement(value, n_bytes) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, n_bytes, expected",
|
||||
[
|
||||
(0, 1, 0),
|
||||
(5, 1, 5),
|
||||
(255, 1, -1),
|
||||
(128, 1, -128),
|
||||
(254, 1, -2),
|
||||
(127, 1, 127),
|
||||
(0, 2, 0),
|
||||
(5, 2, 5),
|
||||
(65_535, 2, -1),
|
||||
(32_768, 2, -32_768),
|
||||
(65_534, 2, -2),
|
||||
(32_767, 2, 32_767),
|
||||
(0, 4, 0),
|
||||
(5, 4, 5),
|
||||
(4_294_967_295, 4, -1),
|
||||
(2_147_483_648, 4, -2_147_483_648),
|
||||
(4_294_967_294, 4, -2),
|
||||
(2_147_483_647, 4, 2_147_483_647),
|
||||
],
|
||||
)
|
||||
def test_decode_twos_complement(value, n_bytes, expected):
|
||||
assert decode_twos_complement(value, n_bytes) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, n_bytes",
|
||||
[
|
||||
(-129, 1),
|
||||
(128, 1),
|
||||
(-32_769, 2),
|
||||
(32_768, 2),
|
||||
(-2_147_483_649, 4),
|
||||
(2_147_483_648, 4),
|
||||
],
|
||||
)
|
||||
def test_encode_twos_complement_out_of_range(value, n_bytes):
|
||||
with pytest.raises(ValueError):
|
||||
encode_twos_complement(value, n_bytes)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, n_bytes",
|
||||
[
|
||||
(-128, 1),
|
||||
(-1, 1),
|
||||
(0, 1),
|
||||
(1, 1),
|
||||
(127, 1),
|
||||
(-32_768, 2),
|
||||
(-1, 2),
|
||||
(0, 2),
|
||||
(1, 2),
|
||||
(32_767, 2),
|
||||
(-2_147_483_648, 4),
|
||||
(-1, 4),
|
||||
(0, 4),
|
||||
(1, 4),
|
||||
(2_147_483_647, 4),
|
||||
],
|
||||
)
|
||||
def test_encode_decode_twos_complement(value, n_bytes):
|
||||
encoded = encode_twos_complement(value, n_bytes)
|
||||
decoded = decode_twos_complement(encoded, n_bytes)
|
||||
assert decoded == value, f"Failed at value={value}, n_bytes={n_bytes}"
|
||||
Reference in New Issue
Block a user