From 00e9f61509b17e58e201f2ba8a52ead720c4ea3a Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 3 Jun 2025 18:42:41 +0200 Subject: [PATCH] (fix): test --- lerobot/common/utils/control_utils.py | 207 ++++++++++++++++++++++++++ tests/datasets/test_datasets.py | 56 +++---- 2 files changed, 230 insertions(+), 33 deletions(-) diff --git a/lerobot/common/utils/control_utils.py b/lerobot/common/utils/control_utils.py index e69de29bb..ca3aa6ae4 100644 --- a/lerobot/common/utils/control_utils.py +++ b/lerobot/common/utils/control_utils.py @@ -0,0 +1,207 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +######################################################################################## +# Utilities +######################################################################################## + + +import logging +import traceback +from contextlib import nullcontext +from copy import copy +from functools import cache + +import numpy as np +import torch +from deepdiff import DeepDiff +from termcolor import colored + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import DEFAULT_FEATURES +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.robots import Robot + + +def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): + log_items = [] + if episode_index is not None: + log_items.append(f"ep:{episode_index}") + if frame_index is not None: + log_items.append(f"frame:{frame_index}") + + def log_dt(shortname, dt_val_s): + nonlocal log_items, fps + info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)" + if fps is not None: + actual_fps = 1 / dt_val_s + if actual_fps < fps - 1: + info_str = colored(info_str, "yellow") + log_items.append(info_str) + + # total step time displayed in milliseconds and its frequency + log_dt("dt", dt_s) + + # TODO(aliberts): move robot-specific logs logic in robot.print_logs() + if not robot.robot_type.startswith("stretch"): + for name in robot.leader_arms: + key = f"read_leader_{name}_pos_dt_s" + if key in robot.logs: + log_dt("dtRlead", robot.logs[key]) + + for name in robot.follower_arms: + key = f"write_follower_{name}_goal_pos_dt_s" + if key in robot.logs: + log_dt("dtWfoll", robot.logs[key]) + + key = f"read_follower_{name}_pos_dt_s" + if key in robot.logs: + log_dt("dtRfoll", robot.logs[key]) + + for name in robot.cameras: + key = f"read_camera_{name}_dt_s" + if key in robot.logs: + log_dt(f"dtR{name}", robot.logs[key]) + + info_str = " ".join(log_items) + logging.info(info_str) + + +@cache +def is_headless(): + """Detects if python is running without a monitor.""" + try: + import pynput # noqa + + return False + except Exception: + print( + "Error trying to import pynput. Switching to headless mode. " + "As a result, the video stream from the cameras won't be shown, " + "and you won't be able to change the control flow with keyboards. " + "For more info, see traceback below.\n" + ) + traceback.print_exc() + print() + return True + + +def predict_action( + observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool +): + observation = copy(observation) + with ( + torch.inference_mode(), + torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + ): + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension + for name in observation: + observation[name] = torch.from_numpy(observation[name]) + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + + # Remove batch dimension + action = action.squeeze(0) + + # Move to cpu, if not already the case + action = action.to("cpu") + + return action + + +def init_keyboard_listener(): + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + events = {} + events["exit_early"] = False + events["rerecord_episode"] = False + events["stop_recording"] = False + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("Escape key pressed. Stopping data recording...") + events["stop_recording"] = True + events["exit_early"] = True + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events + + +def sanity_check_dataset_name(repo_id, policy_cfg): + _, dataset_name = repo_id.split("/") + # either repo_id doesnt start with "eval_" and there is no policy + # or repo_id starts with "eval_" and there is a policy + + # Check if dataset_name starts with "eval_" but policy is missing + if dataset_name.startswith("eval_") and policy_cfg is None: + raise ValueError( + f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})." + ) + + # Check if dataset_name does not start with "eval_" but policy is provided + if not dataset_name.startswith("eval_") and policy_cfg is not None: + raise ValueError( + f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})." + ) + + +def sanity_check_dataset_robot_compatibility( + dataset: LeRobotDataset, robot: Robot, fps: int, features: dict +) -> None: + fields = [ + ("robot_type", dataset.meta.robot_type, robot.robot_type), + ("fps", dataset.fps, fps), + ("features", dataset.features, {**features, **DEFAULT_FEATURES}), + ] + + mismatches = [] + for field, dataset_value, present_value in fields: + diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]) + if diff: + mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") + + if mismatches: + raise ValueError( + "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) + ) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 814470892..55a417c30 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -41,7 +41,6 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.envs.factory import make_env_config from lerobot.common.policies.factory import make_policy_config -from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID @@ -70,9 +69,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): objects have the same sets of attributes defined. """ # Instantiate both ways - robot = make_robot("koch", mock=True) + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} root_create = tmp_path / "create" - dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) + dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create) root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init) @@ -100,22 +99,13 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory): assert dataset.num_frames == len(dataset) -def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - with pytest.raises( - ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n" - ): - dataset.add_frame({"state": torch.randn(1)}) - - def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n" ): - dataset.add_frame({"task": "Dummy task"}) + dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task") def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): @@ -124,7 +114,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): with pytest.raises( ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n" ): - dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}) + dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task") def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): @@ -133,7 +123,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): with pytest.raises( ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n" ): - dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task") def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): @@ -143,7 +133,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): ValueError, match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"), ): - dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(1)}, task="Dummy task") def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory): @@ -155,7 +145,7 @@ def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_fact "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '' provided instead.\n" ), ): - dataset.add_frame({"state": 1.0, "task": "Dummy task"}) + dataset.add_frame({"state": 1.0}, task="Dummy task") def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory): @@ -165,7 +155,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact ValueError, match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"), ): - dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"}) + dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task") def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory): @@ -177,13 +167,13 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '' provided instead.\n" ), ): - dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"}) + dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task") def test_add_frame(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(1)}, task="Dummy task") dataset.save_episode() assert len(dataset) == 1 @@ -195,7 +185,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(2)}, task="Dummy task") dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2]) @@ -204,7 +194,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task") dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4]) @@ -213,7 +203,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task") dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) @@ -222,7 +212,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task") dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) @@ -231,7 +221,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task") dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) @@ -240,7 +230,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"}) + dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task") dataset.save_episode() assert dataset[0]["state"].ndim == 0 @@ -249,7 +239,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory): features = {"caption": {"dtype": "string", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"}) + dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task") dataset.save_episode() assert dataset[0]["caption"] == "Dummy caption" @@ -264,7 +254,7 @@ def test_add_frame_image_wrong_shape(image_dataset): ), ): c, h, w = DUMMY_CHW - dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"}) + dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task") def test_add_frame_image_wrong_range(image_dataset): @@ -277,14 +267,14 @@ def test_add_frame_image_wrong_range(image_dataset): Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`. """ dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"}) + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task") with pytest.raises(FileNotFoundError): dataset.save_episode() def test_add_frame_image(image_dataset): dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task") dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -292,7 +282,7 @@ def test_add_frame_image(image_dataset): def test_add_frame_image_h_w_c(image_dataset): dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) + dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task") dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -301,7 +291,7 @@ def test_add_frame_image_h_w_c(image_dataset): def test_add_frame_image_uint8(image_dataset): dataset = image_dataset image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) - dataset.add_frame({"image": image, "task": "Dummy task"}) + dataset.add_frame({"image": image}, task="Dummy task") dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -310,7 +300,7 @@ def test_add_frame_image_uint8(image_dataset): def test_add_frame_image_pil(image_dataset): dataset = image_dataset image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) - dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) + dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task") dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)