fix(make_policy): rename mapping edge cases in training (#2332)
* fix bug * update fixes * add hf license * more fixes * add transformers * iterate on review * more fixes * more fixes * add a False test * reduce img size * reduce img size * skip the test * add * add style
This commit is contained in:
@@ -38,6 +38,7 @@ from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import validate_visual_features_consistency
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
@@ -420,20 +421,7 @@ def make_policy(
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
if not rename_map:
|
||||
expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys())
|
||||
provided_features = set(features.keys())
|
||||
if expected_features and provided_features != expected_features:
|
||||
missing = expected_features - provided_features
|
||||
extra = provided_features - expected_features
|
||||
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
|
||||
raise ValueError(
|
||||
f"Feature mismatch between dataset/environment and policy config.\n"
|
||||
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
|
||||
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
|
||||
f"Please ensure your dataset and policy use consistent feature names.\n"
|
||||
f"If your dataset uses different observation keys (e.g., cameras named differently), "
|
||||
f"use the `--rename_map` argument, for example:\n"
|
||||
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
|
||||
f'"observation.images.top": "observation.images.camera2"}}\''
|
||||
)
|
||||
validate_visual_features_consistency(cfg, features)
|
||||
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
|
||||
|
||||
return policy
|
||||
|
||||
@@ -22,6 +22,8 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.utils import build_dataset_frame
|
||||
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
@@ -198,3 +200,42 @@ def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict])
|
||||
f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
|
||||
}
|
||||
return act_processed_policy
|
||||
|
||||
|
||||
def raise_feature_mismatch_error(
|
||||
provided_features: set[str],
|
||||
expected_features: set[str],
|
||||
) -> None:
|
||||
"""
|
||||
Raises a standardized ValueError for feature mismatches between dataset/environment and policy config.
|
||||
"""
|
||||
missing = expected_features - provided_features
|
||||
extra = provided_features - expected_features
|
||||
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
|
||||
raise ValueError(
|
||||
f"Feature mismatch between dataset/environment and policy config.\n"
|
||||
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
|
||||
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
|
||||
f"Please ensure your dataset and policy use consistent feature names.\n"
|
||||
f"If your dataset uses different observation keys (e.g., cameras named differently), "
|
||||
f"use the `--rename_map` argument, for example:\n"
|
||||
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
|
||||
f'"observation.images.top": "observation.images.camera2"}}\''
|
||||
)
|
||||
|
||||
|
||||
def validate_visual_features_consistency(
|
||||
cfg: PreTrainedConfig,
|
||||
features: dict[str, PolicyFeature],
|
||||
) -> None:
|
||||
"""
|
||||
Validates visual feature consistency between a policy config and provided dataset/environment features.
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
|
||||
features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
|
||||
"""
|
||||
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
|
||||
if not provided_visuals.issubset(expected_visuals):
|
||||
raise_feature_mismatch_error(provided_visuals, expected_visuals)
|
||||
|
||||
157
tests/training/test_visual_validation.py
Normal file
157
tests/training/test_visual_validation.py
Normal file
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Visual Feature Consistency Tests
|
||||
|
||||
This module tests the `validate_visual_features_consistency` function,
|
||||
which ensures that visual features (camera observations) in a dataset/env
|
||||
match the expectations defined in a policy configuration.
|
||||
|
||||
The purpose of this check is to prevent mismatches between what a policy expects
|
||||
(e.g., `observation.images.camera1`, `camera2`, `camera3`) and what a dataset or
|
||||
environment actually provides (e.g., `observation.images.top`, `side`, or fewer cameras).
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.scripts.lerobot_train import train
|
||||
from lerobot.utils.utils import auto_select_torch_device
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
DUMMY_REPO_ID = "dummy/repo"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir(tmp_path):
|
||||
return tmp_path
|
||||
|
||||
|
||||
DUMMY_STATE_DIM = 6
|
||||
DUMMY_ACTION_DIM = 6
|
||||
IMAGE_SIZE = 8
|
||||
DEVICE = auto_select_torch_device()
|
||||
|
||||
|
||||
def make_dummy_dataset(camera_keys, tmp_path):
|
||||
"""Creates a minimal dummy dataset for testing rename_mapping logic."""
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (DUMMY_ACTION_DIM,), "names": None},
|
||||
"observation.state": {"dtype": "float32", "shape": (DUMMY_STATE_DIM,), "names": None},
|
||||
}
|
||||
for cam in camera_keys:
|
||||
features[f"observation.images.{cam}"] = {
|
||||
"dtype": "image",
|
||||
"shape": (IMAGE_SIZE, IMAGE_SIZE, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
}
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "_dataset",
|
||||
)
|
||||
root = tmp_path / "_dataset"
|
||||
for ep_idx in range(2):
|
||||
for _ in range(3):
|
||||
frame = {
|
||||
"action": np.random.randn(DUMMY_ACTION_DIM).astype(np.float32),
|
||||
"observation.state": np.random.randn(DUMMY_STATE_DIM).astype(np.float32),
|
||||
}
|
||||
for cam in camera_keys:
|
||||
frame[f"observation.images.{cam}"] = np.random.randint(
|
||||
0, 255, size=(IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8
|
||||
)
|
||||
frame["task"] = f"task_{ep_idx}"
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
dataset.finalize()
|
||||
return dataset, root
|
||||
|
||||
|
||||
def custom_validate(train_config: TrainPipelineConfig, policy_path: str, empty_cameras: int):
|
||||
train_config.policy = PreTrainedConfig.from_pretrained(policy_path)
|
||||
train_config.policy.pretrained_path = Path(policy_path)
|
||||
# override empty_cameras and push_to_hub for testing
|
||||
train_config.policy.empty_cameras = empty_cameras
|
||||
train_config.policy.push_to_hub = False
|
||||
if train_config.use_policy_training_preset:
|
||||
train_config.optimizer = train_config.policy.get_optimizer_preset()
|
||||
train_config.scheduler = train_config.policy.get_scheduler_preset()
|
||||
return train_config
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test as it results OOM")
|
||||
@pytest.mark.parametrize(
|
||||
"camera_keys, empty_cameras, rename_map, expect_success",
|
||||
[
|
||||
# case 1: dataset has fewer cameras than policy (3 instead of 4), but we specify empty_cameras=1 for smolvla, pi0, pi05
|
||||
(["camera1", "camera2", "camera3"], 1, {}, True),
|
||||
# case 2: dataset has 2 cameras with different names, rename_mapping provided
|
||||
(
|
||||
["top", "side"],
|
||||
0,
|
||||
{
|
||||
"observation.images.top": "observation.images.camera1",
|
||||
"observation.images.side": "observation.images.camera2",
|
||||
},
|
||||
True,
|
||||
),
|
||||
# case 3: dataset has 2 cameras, policy expects 3, names do not match, no empty_cameras
|
||||
(["top", "side"], 0, {}, False),
|
||||
# TODO: case 4: dataset has 2 cameras, policy expects 3, no rename_map, no empty_cameras, should raise for smolvla
|
||||
# (["camera1", "camera2"], 0, {}, False),
|
||||
],
|
||||
)
|
||||
def test_train_with_camera_mismatch(camera_keys, empty_cameras, rename_map, expect_success, tmp_path):
|
||||
"""Tests that training works or fails depending on camera/feature alignment."""
|
||||
|
||||
_dataset, root = make_dummy_dataset(camera_keys, tmp_path)
|
||||
pretrained_path = "lerobot/smolvla_base"
|
||||
dataset_config = DatasetConfig(repo_id=DUMMY_REPO_ID, root=root)
|
||||
policy_config = make_policy_config(
|
||||
"smolvla",
|
||||
optimizer_lr=0.01,
|
||||
push_to_hub=False,
|
||||
pretrained_path=pretrained_path,
|
||||
device=DEVICE,
|
||||
)
|
||||
policy_config.empty_cameras = empty_cameras
|
||||
train_config = TrainPipelineConfig(
|
||||
dataset=dataset_config,
|
||||
policy=policy_config,
|
||||
rename_map=rename_map,
|
||||
output_dir=tmp_path / "_output",
|
||||
steps=1,
|
||||
)
|
||||
train_config = custom_validate(train_config, policy_path=pretrained_path, empty_cameras=empty_cameras)
|
||||
# HACK: disable the internal CLI validation step for tests, we did it with custom_validate
|
||||
train_config.validate = lambda: None
|
||||
if expect_success:
|
||||
train(train_config)
|
||||
else:
|
||||
with pytest.raises(ValueError):
|
||||
train(train_config)
|
||||
Reference in New Issue
Block a user