@@ -13,9 +13,12 @@ from functools import cache
|
||||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
|
||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import get_features_from_robot
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
@@ -227,7 +230,7 @@ def control_loop(
|
||||
control_time_s=None,
|
||||
teleoperate=False,
|
||||
display_cameras=False,
|
||||
dataset=None,
|
||||
dataset: LeRobotDataset | None = None,
|
||||
events=None,
|
||||
policy=None,
|
||||
device=None,
|
||||
@@ -247,7 +250,7 @@ def control_loop(
|
||||
if teleoperate and policy is not None:
|
||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset["fps"] != fps:
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
timestamp = 0
|
||||
@@ -268,7 +271,8 @@ def control_loop(
|
||||
action = {"action": action}
|
||||
|
||||
if dataset is not None:
|
||||
add_frame(dataset, observation, action)
|
||||
frame = {**observation, **action}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
@@ -336,3 +340,24 @@ def sanity_check_dataset_name(repo_id, policy):
|
||||
raise ValueError(
|
||||
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy})."
|
||||
)
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
|
||||
) -> None:
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
||||
]
|
||||
|
||||
mismatches = []
|
||||
for field, dataset_value, present_value in fields:
|
||||
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
|
||||
if diff:
|
||||
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
||||
|
||||
if mismatches:
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user