Dataset v2.0 (#461)

Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
Simon Alibert
2024-11-29 19:04:00 +01:00
committed by GitHub
parent 96c7052777
commit 32eb0cec8f
71 changed files with 6115 additions and 2235 deletions

View File

@@ -168,6 +168,7 @@ class IntelRealSenseCameraConfig:
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
use_depth: bool = False
force_hardware_reset: bool = True
rotation: int | None = None
@@ -179,6 +180,8 @@ class IntelRealSenseCameraConfig:
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
if at_least_one_is_not_none and at_least_one_is_none:
@@ -254,6 +257,7 @@ class IntelRealSenseCamera:
self.fps = config.fps
self.width = config.width
self.height = config.height
self.channels = config.channels
self.color_mode = config.color_mode
self.use_depth = config.use_depth
self.force_hardware_reset = config.force_hardware_reset

View File

@@ -192,6 +192,7 @@ class OpenCVCameraConfig:
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
rotation: int | None = None
mock: bool = False
@@ -201,6 +202,8 @@ class OpenCVCameraConfig:
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
@@ -268,6 +271,7 @@ class OpenCVCamera:
self.fps = config.fps
self.width = config.width
self.height = config.height
self.channels = config.channels
self.color_mode = config.color_mode
self.mock = config.mock

View File

@@ -13,9 +13,12 @@ from functools import cache
import cv2
import torch
import tqdm
from deepdiff import DeepDiff
from termcolor import colored
from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
@@ -227,7 +230,7 @@ def control_loop(
control_time_s=None,
teleoperate=False,
display_cameras=False,
dataset=None,
dataset: LeRobotDataset | None = None,
events=None,
policy=None,
device=None,
@@ -247,7 +250,7 @@ def control_loop(
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")
if dataset is not None and fps is not None and dataset["fps"] != fps:
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
timestamp = 0
@@ -268,7 +271,8 @@ def control_loop(
action = {"action": action}
if dataset is not None:
add_frame(dataset, observation, action)
frame = {**observation, **action}
dataset.add_frame(frame)
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
@@ -336,3 +340,24 @@ def sanity_check_dataset_name(repo_id, policy):
raise ValueError(
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy})."
)
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
) -> None:
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
]
mismatches = []
for field, dataset_value, present_value in fields:
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches:
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
)

View File

@@ -226,6 +226,42 @@ class ManipulatorRobot:
self.is_connected = False
self.logs = {}
def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
action_names = self.get_motor_names(self.leader_arms)
state_names = self.get_motor_names(self.leader_arms)
return {
"action": {
"dtype": "float32",
"shape": (len(action_names),),
"names": action_names,
},
"observation.state": {
"dtype": "float32",
"shape": (len(state_names),),
"names": state_names,
},
}
@property
def features(self):
return {**self.motor_features, **self.camera_features}
@property
def has_camera(self):
return len(self.cameras) > 0

View File

@@ -11,6 +11,7 @@ def get_arm_id(name, arm_type):
class Robot(Protocol):
# TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
robot_type: str
features: dict
def connect(self): ...
def run_calibration(self): ...