diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index bdad5cbb..eb626675 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -38,6 +38,7 @@ from lerobot.policies.sac.configuration_sac import SACConfig
 from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
 from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
 from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
+from lerobot.policies.utils import validate_visual_features_consistency
 from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
 from lerobot.processor import PolicyAction, PolicyProcessorPipeline
 from lerobot.processor.converters import (
@@ -420,20 +421,7 @@ def make_policy(
         # policy = torch.compile(policy, mode="reduce-overhead")

     if not rename_map:
-        expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys())
-        provided_features = set(features.keys())
-        if expected_features and provided_features != expected_features:
-            missing = expected_features - provided_features
-            extra = provided_features - expected_features
-            # TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
-            raise ValueError(
-                f"Feature mismatch between dataset/environment and policy config.\n"
-                f"- Missing features: {sorted(missing) if missing else 'None'}\n"
-                f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
-                f"Please ensure your dataset and policy use consistent feature names.\n"
-                f"If your dataset uses different observation keys (e.g., cameras named differently), "
-                f"use the `--rename_map` argument, for example:\n"
-                f'  --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
-                f'"observation.images.top": "observation.images.camera2"}}\''
-            )
+        validate_visual_features_consistency(cfg, features)
+        # TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
+

     return policy
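Worth noting the behavioral change in this hunk: the removed inline check demanded exact equality between the provided features and the config's input/output features, whereas the new `validate_visual_features_consistency` (added in `utils.py` below) only rejects *extra* visual keys; missing cameras are tolerated, e.g. when `empty_cameras` pads them (case 1 of the test below). A minimal sketch of the two predicates, using illustrative key names:

```python
# Illustrative feature-name sets; the real keys come from cfg and the dataset.
expected = {"observation.images.camera1", "observation.images.camera2"}
provided = {"observation.images.camera1"}  # one expected camera is missing

# Old inline check (removed above): any difference between the sets was fatal.
print(bool(expected) and provided != expected)  # True  -> would have raised

# New visual check (see utils.py below): only undeclared keys are fatal.
print(not provided.issubset(expected))  # False -> passes; empty_cameras can pad
```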
diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py
index 21b39a80..c4ca35b7 100644
--- a/src/lerobot/policies/utils.py
+++ b/src/lerobot/policies/utils.py
@@ -22,6 +22,8 @@ import numpy as np
 import torch
 from torch import nn

+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, PolicyFeature
 from lerobot.datasets.utils import build_dataset_frame
 from lerobot.processor import PolicyAction, RobotAction, RobotObservation
 from lerobot.utils.constants import ACTION, OBS_STR
@@ -198,3 +200,42 @@ def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict])
         f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
     }
     return act_processed_policy
+
+
+def raise_feature_mismatch_error(
+    provided_features: set[str],
+    expected_features: set[str],
+) -> None:
+    """
+    Raises a standardized ValueError for feature mismatches between dataset/environment and policy config.
+    """
+    missing = expected_features - provided_features
+    extra = provided_features - expected_features
+    # TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
+    raise ValueError(
+        f"Feature mismatch between dataset/environment and policy config.\n"
+        f"- Missing features: {sorted(missing) if missing else 'None'}\n"
+        f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
+        f"Please ensure your dataset and policy use consistent feature names.\n"
+        f"If your dataset uses different observation keys (e.g., cameras named differently), "
+        f"use the `--rename_map` argument, for example:\n"
+        f'  --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
+        f'"observation.images.top": "observation.images.camera2"}}\''
+    )
+
+
+def validate_visual_features_consistency(
+    cfg: PreTrainedConfig,
+    features: dict[str, PolicyFeature],
+) -> None:
+    """
+    Validates visual feature consistency between a policy config and the provided dataset/environment features.
+
+    Args:
+        cfg (PreTrainedConfig): The policy configuration whose `input_features` declare the expected inputs.
+        features (dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
+    """
+    expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
+    provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
+    if not provided_visuals.issubset(expected_visuals):
+        raise_feature_mismatch_error(provided_visuals, expected_visuals)
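A minimal usage sketch of the new validator (not part of the diff). The `SimpleNamespace` is a stand-in for a real `PreTrainedConfig`, since the function only reads `cfg.input_features`; it assumes `PolicyFeature` accepts `type=` and `shape=` keywords as elsewhere in lerobot configs:

```python
from types import SimpleNamespace

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.utils import validate_visual_features_consistency

cam = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 8, 8))
cfg = SimpleNamespace(  # stub: the validator only reads .input_features
    input_features={"observation.images.camera1": cam, "observation.images.camera2": cam}
)

# A subset of the expected cameras passes silently:
validate_visual_features_consistency(cfg, {"observation.images.camera1": cam})

# A camera key the config does not declare raises the standardized error:
try:
    validate_visual_features_consistency(cfg, {"observation.images.top": cam})
except ValueError as err:
    print(err)  # lists missing/extra keys and suggests --rename_map
```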
+""" + +from pathlib import Path + +import numpy as np +import pytest + +from lerobot.configs.default import DatasetConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.train import TrainPipelineConfig +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.policies.factory import make_policy_config +from lerobot.scripts.lerobot_train import train +from lerobot.utils.utils import auto_select_torch_device + +pytest.importorskip("transformers") + +DUMMY_REPO_ID = "dummy/repo" + + +@pytest.fixture +def temp_dir(tmp_path): + return tmp_path + + +DUMMY_STATE_DIM = 6 +DUMMY_ACTION_DIM = 6 +IMAGE_SIZE = 8 +DEVICE = auto_select_torch_device() + + +def make_dummy_dataset(camera_keys, tmp_path): + """Creates a minimal dummy dataset for testing rename_mapping logic.""" + features = { + "action": {"dtype": "float32", "shape": (DUMMY_ACTION_DIM,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (DUMMY_STATE_DIM,), "names": None}, + } + for cam in camera_keys: + features[f"observation.images.{cam}"] = { + "dtype": "image", + "shape": (IMAGE_SIZE, IMAGE_SIZE, 3), + "names": ["height", "width", "channel"], + } + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, + fps=30, + features=features, + root=tmp_path / "_dataset", + ) + root = tmp_path / "_dataset" + for ep_idx in range(2): + for _ in range(3): + frame = { + "action": np.random.randn(DUMMY_ACTION_DIM).astype(np.float32), + "observation.state": np.random.randn(DUMMY_STATE_DIM).astype(np.float32), + } + for cam in camera_keys: + frame[f"observation.images.{cam}"] = np.random.randint( + 0, 255, size=(IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8 + ) + frame["task"] = f"task_{ep_idx}" + dataset.add_frame(frame) + dataset.save_episode() + + dataset.finalize() + return dataset, root + + +def custom_validate(train_config: TrainPipelineConfig, policy_path: str, empty_cameras: int): + train_config.policy = PreTrainedConfig.from_pretrained(policy_path) + train_config.policy.pretrained_path = Path(policy_path) + # override empty_cameras and push_to_hub for testing + train_config.policy.empty_cameras = empty_cameras + train_config.policy.push_to_hub = False + if train_config.use_policy_training_preset: + train_config.optimizer = train_config.policy.get_optimizer_preset() + train_config.scheduler = train_config.policy.get_scheduler_preset() + return train_config + + +@pytest.mark.skip(reason="Skipping this test as it results OOM") +@pytest.mark.parametrize( + "camera_keys, empty_cameras, rename_map, expect_success", + [ + # case 1: dataset has fewer cameras than policy (3 instead of 4), but we specify empty_cameras=1 for smolvla, pi0, pi05 + (["camera1", "camera2", "camera3"], 1, {}, True), + # case 2: dataset has 2 cameras with different names, rename_mapping provided + ( + ["top", "side"], + 0, + { + "observation.images.top": "observation.images.camera1", + "observation.images.side": "observation.images.camera2", + }, + True, + ), + # case 3: dataset has 2 cameras, policy expects 3, names do not match, no empty_cameras + (["top", "side"], 0, {}, False), + # TODO: case 4: dataset has 2 cameras, policy expects 3, no rename_map, no empty_cameras, should raise for smolvla + # (["camera1", "camera2"], 0, {}, False), + ], +) +def test_train_with_camera_mismatch(camera_keys, empty_cameras, rename_map, expect_success, tmp_path): + """Tests that training works or fails depending on camera/feature alignment.""" + + _dataset, root = make_dummy_dataset(camera_keys, tmp_path) + pretrained_path = 
"lerobot/smolvla_base" + dataset_config = DatasetConfig(repo_id=DUMMY_REPO_ID, root=root) + policy_config = make_policy_config( + "smolvla", + optimizer_lr=0.01, + push_to_hub=False, + pretrained_path=pretrained_path, + device=DEVICE, + ) + policy_config.empty_cameras = empty_cameras + train_config = TrainPipelineConfig( + dataset=dataset_config, + policy=policy_config, + rename_map=rename_map, + output_dir=tmp_path / "_output", + steps=1, + ) + train_config = custom_validate(train_config, policy_path=pretrained_path, empty_cameras=empty_cameras) + # HACK: disable the internal CLI validation step for tests, we did it with custom_validate + train_config.validate = lambda: None + if expect_success: + train(train_config) + else: + with pytest.raises(ValueError): + train(train_config)