Merge remote-tracking branch 'origin/main' into user/rcadene/2024_09_10_train_aloha
This commit is contained in:
@@ -8,6 +8,10 @@ CAP_PROP_FRAME_HEIGHT = 4
|
||||
COLOR_RGB2BGR = 4
|
||||
COLOR_BGR2RGB = 4
|
||||
|
||||
ROTATE_90_COUNTERCLOCKWISE = 2
|
||||
ROTATE_90_CLOCKWISE = 0
|
||||
ROTATE_180 = 1
|
||||
|
||||
|
||||
@cache
|
||||
def _generate_image(width: int, height: int):
|
||||
@@ -21,6 +25,19 @@ def cvtColor(color_image, color_convertion): # noqa: N802
|
||||
raise NotImplementedError(color_convertion)
|
||||
|
||||
|
||||
def rotate(color_image, rotation):
|
||||
if rotation is None:
|
||||
return color_image
|
||||
elif rotation == ROTATE_90_CLOCKWISE:
|
||||
return np.rot90(color_image, k=1)
|
||||
elif rotation == ROTATE_180:
|
||||
return np.rot90(color_image, k=2)
|
||||
elif rotation == ROTATE_90_COUNTERCLOCKWISE:
|
||||
return np.rot90(color_image, k=3)
|
||||
else:
|
||||
raise NotImplementedError(rotation)
|
||||
|
||||
|
||||
class VideoCapture:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._mock_dict = {
|
||||
|
||||
@@ -3,33 +3,31 @@ import enum
|
||||
import numpy as np
|
||||
|
||||
|
||||
class RSStream(enum.Enum):
|
||||
class stream(enum.Enum): # noqa: N801
|
||||
color = 0
|
||||
depth = 1
|
||||
|
||||
|
||||
class RSFormat(enum.Enum):
|
||||
class format(enum.Enum): # noqa: N801
|
||||
rgb8 = 0
|
||||
z16 = 1
|
||||
|
||||
|
||||
class RSConfig:
|
||||
class config: # noqa: N801
|
||||
def enable_device(self, device_id: str):
|
||||
self.device_enabled = device_id
|
||||
|
||||
def enable_stream(
|
||||
self, stream_type: RSStream, width=None, height=None, color_format: RSFormat = None, fps=None
|
||||
):
|
||||
def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
|
||||
self.stream_type = stream_type
|
||||
# Overwrite default values when possible
|
||||
self.width = 848 if width is None else width
|
||||
self.height = 480 if height is None else height
|
||||
self.color_format = RSFormat.rgb8 if color_format is None else color_format
|
||||
self.color_format = format.rgb8 if color_format is None else color_format
|
||||
self.fps = 30 if fps is None else fps
|
||||
|
||||
|
||||
class RSColorProfile:
|
||||
def __init__(self, config: RSConfig):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def fps(self):
|
||||
@@ -43,7 +41,7 @@ class RSColorProfile:
|
||||
|
||||
|
||||
class RSColorStream:
|
||||
def __init__(self, config: RSConfig):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def as_video_stream_profile(self):
|
||||
@@ -51,20 +49,20 @@ class RSColorStream:
|
||||
|
||||
|
||||
class RSProfile:
|
||||
def __init__(self, config: RSConfig):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_stream(self, color_format: RSFormat):
|
||||
def get_stream(self, color_format):
|
||||
del color_format # unused
|
||||
return RSColorStream(self.config)
|
||||
|
||||
|
||||
class RSPipeline:
|
||||
class pipeline: # noqa: N801
|
||||
def __init__(self):
|
||||
self.started = False
|
||||
self.config = None
|
||||
|
||||
def start(self, config: RSConfig):
|
||||
def start(self, config):
|
||||
self.started = True
|
||||
self.config = config
|
||||
return RSProfile(self.config)
|
||||
@@ -81,7 +79,7 @@ class RSPipeline:
|
||||
|
||||
|
||||
class RSFrames:
|
||||
def __init__(self, config: RSConfig):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_color_frame(self):
|
||||
@@ -92,7 +90,7 @@ class RSFrames:
|
||||
|
||||
|
||||
class RSColorFrame:
|
||||
def __init__(self, config: RSConfig):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_data(self):
|
||||
@@ -103,7 +101,7 @@ class RSColorFrame:
|
||||
|
||||
|
||||
class RSDepthFrame:
|
||||
def __init__(self, config: RSConfig):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_data(self):
|
||||
@@ -120,7 +118,7 @@ class RSDevice:
|
||||
return "123456789"
|
||||
|
||||
|
||||
class RSContext:
|
||||
class context: # noqa: N801
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -128,7 +126,10 @@ class RSContext:
|
||||
return [RSDevice()]
|
||||
|
||||
|
||||
class RSCameraInfo:
|
||||
class camera_info: # noqa: N801
|
||||
# fake name
|
||||
name = "Intel RealSense D435I"
|
||||
|
||||
def __init__(self, serial_number):
|
||||
del serial_number
|
||||
pass
|
||||
|
||||
@@ -120,6 +120,41 @@ def test_camera(request, camera_type, mock):
|
||||
)
|
||||
del camera
|
||||
|
||||
# Test acquiring a rotated image
|
||||
camera = make_camera(camera_type, mock=mock)
|
||||
camera.connect()
|
||||
ori_color_image = camera.read()
|
||||
del camera
|
||||
|
||||
for rotation in [None, 90, 180, -90]:
|
||||
camera = make_camera(camera_type, rotation=rotation, mock=mock)
|
||||
camera.connect()
|
||||
|
||||
if mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
if rotation is None:
|
||||
manual_rot_img = ori_color_image
|
||||
assert camera.rotation is None
|
||||
elif rotation == 90:
|
||||
manual_rot_img = np.rot90(color_image, k=1)
|
||||
assert camera.rotation == cv2.ROTATE_90_CLOCKWISE
|
||||
elif rotation == 180:
|
||||
manual_rot_img = np.rot90(color_image, k=2)
|
||||
assert camera.rotation == cv2.ROTATE_180
|
||||
elif rotation == -90:
|
||||
manual_rot_img = np.rot90(color_image, k=3)
|
||||
assert camera.rotation == cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
|
||||
rot_color_image = camera.read()
|
||||
|
||||
np.testing.assert_allclose(
|
||||
rot_color_image, manual_rot_img, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
|
||||
)
|
||||
del camera
|
||||
|
||||
# TODO(rcadene): Add a test for a camera that doesnt support fps=60 and raises an OSError
|
||||
# TODO(rcadene): Add a test for a camera that supports fps=60
|
||||
|
||||
@@ -152,4 +187,5 @@ def test_save_images_from_cameras(tmpdir, request, camera_type, mock):
|
||||
elif camera_type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
|
||||
|
||||
save_images_from_cameras(tmpdir, record_time_s=1, mock=mock)
|
||||
# Small `record_time_s` to speedup unit tests
|
||||
save_images_from_cameras(tmpdir, record_time_s=0.02, mock=mock)
|
||||
|
||||
Reference in New Issue
Block a user