(fix): test

This commit is contained in:
AdilZouitine
2025-06-03 18:42:41 +02:00
parent 8d4fe1ad6a
commit 00e9f61509
2 changed files with 230 additions and 33 deletions

View File

@@ -0,0 +1,207 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
########################################################################################
# Utilities
########################################################################################
import logging
import traceback
from contextlib import nullcontext
from copy import copy
from functools import cache
import numpy as np
import torch
from deepdiff import DeepDiff
from termcolor import colored
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import DEFAULT_FEATURES
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.robots import Robot
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
if frame_index is not None:
log_items.append(f"frame:{frame_index}")
def log_dt(shortname, dt_val_s):
nonlocal log_items, fps
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
if fps is not None:
actual_fps = 1 / dt_val_s
if actual_fps < fps - 1:
info_str = colored(info_str, "yellow")
log_items.append(info_str)
# total step time displayed in milliseconds and its frequency
log_dt("dt", dt_s)
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not robot.robot_type.startswith("stretch"):
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRlead", robot.logs[key])
for name in robot.follower_arms:
key = f"write_follower_{name}_goal_pos_dt_s"
if key in robot.logs:
log_dt("dtWfoll", robot.logs[key])
key = f"read_follower_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRfoll", robot.logs[key])
for name in robot.cameras:
key = f"read_camera_{name}_dt_s"
if key in robot.logs:
log_dt(f"dtR{name}", robot.logs[key])
info_str = " ".join(log_items)
logging.info(info_str)
@cache
def is_headless():
"""Detects if python is running without a monitor."""
try:
import pynput # noqa
return False
except Exception:
print(
"Error trying to import pynput. Switching to headless mode. "
"As a result, the video stream from the cameras won't be shown, "
"and you won't be able to change the control flow with keyboards. "
"For more info, see traceback below.\n"
)
traceback.print_exc()
print()
return True
def predict_action(
observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool
):
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
return action
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def sanity_check_dataset_name(repo_id, policy_cfg):
_, dataset_name = repo_id.split("/")
# either repo_id doesnt start with "eval_" and there is no policy
# or repo_id starts with "eval_" and there is a policy
# Check if dataset_name starts with "eval_" but policy is missing
if dataset_name.startswith("eval_") and policy_cfg is None:
raise ValueError(
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
)
# Check if dataset_name does not start with "eval_" but policy is provided
if not dataset_name.startswith("eval_") and policy_cfg is not None:
raise ValueError(
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
)
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, features: dict
) -> None:
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, {**features, **DEFAULT_FEATURES}),
]
mismatches = []
for field, dataset_value, present_value in fields:
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches:
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
)

View File

@@ -41,7 +41,6 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
@@ -70,9 +69,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
objects have the same sets of attributes defined.
"""
# Instantiate both ways
robot = make_robot("koch", mock=True)
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
@@ -100,22 +99,13 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
assert dataset.num_frames == len(dataset)
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
):
dataset.add_frame({"state": torch.randn(1)})
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
):
dataset.add_frame({"task": "Dummy task"})
dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task")
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
@@ -124,7 +114,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task")
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
@@ -133,7 +123,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
with pytest.raises(
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
):
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task")
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
@@ -143,7 +133,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
ValueError,
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
@@ -155,7 +145,7 @@ def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_fact
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
),
):
dataset.add_frame({"state": 1.0, "task": "Dummy task"})
dataset.add_frame({"state": 1.0}, task="Dummy task")
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
@@ -165,7 +155,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
ValueError,
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
):
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task")
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
@@ -177,13 +167,13 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
),
):
dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})
dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task")
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
dataset.save_episode()
assert len(dataset) == 1
@@ -195,7 +185,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(2)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2])
@@ -204,7 +194,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4])
@@ -213,7 +203,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
@@ -222,7 +212,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
@@ -231,7 +221,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
@@ -240,7 +230,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].ndim == 0
@@ -249,7 +239,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["caption"] == "Dummy caption"
@@ -264,7 +254,7 @@ def test_add_frame_image_wrong_shape(image_dataset):
),
):
c, h, w = DUMMY_CHW
dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task")
def test_add_frame_image_wrong_range(image_dataset):
@@ -277,14 +267,14 @@ def test_add_frame_image_wrong_range(image_dataset):
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
"""
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task")
with pytest.raises(FileNotFoundError):
dataset.save_episode()
def test_add_frame_image(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -292,7 +282,7 @@ def test_add_frame_image(image_dataset):
def test_add_frame_image_h_w_c(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -301,7 +291,7 @@ def test_add_frame_image_h_w_c(image_dataset):
def test_add_frame_image_uint8(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": image, "task": "Dummy task"})
dataset.add_frame({"image": image}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -310,7 +300,7 @@ def test_add_frame_image_uint8(image_dataset):
def test_add_frame_image_pil(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)