fix(make_policy): rename mapping edge cases in training (#2332)
* fix bug * update fixes * add hf license * more fixes * add transformers * iterate on review * more fixes * more fixes * add a False test * reduce img size * reduce img size * skip the test * add * add style
This commit is contained in:
@@ -38,6 +38,7 @@ from lerobot.policies.sac.configuration_sac import SACConfig
|
|||||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
|
from lerobot.policies.utils import validate_visual_features_consistency
|
||||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
@@ -420,20 +421,7 @@ def make_policy(
|
|||||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||||
|
|
||||||
if not rename_map:
|
if not rename_map:
|
||||||
expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys())
|
validate_visual_features_consistency(cfg, features)
|
||||||
provided_features = set(features.keys())
|
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
|
||||||
if expected_features and provided_features != expected_features:
|
|
||||||
missing = expected_features - provided_features
|
|
||||||
extra = provided_features - expected_features
|
|
||||||
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
|
|
||||||
raise ValueError(
|
|
||||||
f"Feature mismatch between dataset/environment and policy config.\n"
|
|
||||||
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
|
|
||||||
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
|
|
||||||
f"Please ensure your dataset and policy use consistent feature names.\n"
|
|
||||||
f"If your dataset uses different observation keys (e.g., cameras named differently), "
|
|
||||||
f"use the `--rename_map` argument, for example:\n"
|
|
||||||
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
|
|
||||||
f'"observation.images.top": "observation.images.camera2"}}\''
|
|
||||||
)
|
|
||||||
return policy
|
return policy
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.datasets.utils import build_dataset_frame
|
from lerobot.datasets.utils import build_dataset_frame
|
||||||
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
|
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
|
||||||
from lerobot.utils.constants import ACTION, OBS_STR
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
@@ -198,3 +200,42 @@ def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict])
|
|||||||
f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
|
f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
|
||||||
}
|
}
|
||||||
return act_processed_policy
|
return act_processed_policy
|
||||||
|
|
||||||
|
|
||||||
|
def raise_feature_mismatch_error(
    provided_features: set[str],
    expected_features: set[str],
) -> None:
    """
    Raise a standardized ValueError describing a feature-name mismatch.

    Args:
        provided_features: Feature names coming from the dataset/environment side.
        expected_features: Feature names coming from the policy config side.

    Raises:
        ValueError: Always. The message lists the expected names that were not
            provided ("Missing") and the provided names that were not expected
            ("Extra"), each sorted, or the literal text ``None`` when that group
            is empty, followed by a hint about the ``--rename_map`` argument.
    """
    absent = expected_features - provided_features
    surplus = provided_features - expected_features
    # An empty group is rendered as the text 'None' rather than as '[]'.
    missing_report = sorted(absent) if absent else "None"
    extra_report = sorted(surplus) if surplus else "None"
    # TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
    message = "".join(
        (
            "Feature mismatch between dataset/environment and policy config.\n",
            f"- Missing features: {missing_report}\n",
            f"- Extra features: {extra_report}\n\n",
            "Please ensure your dataset and policy use consistent feature names.\n",
            "If your dataset uses different observation keys (e.g., cameras named differently), ",
            "use the `--rename_map` argument, for example:\n",
            ' --rename_map=\'{"observation.images.left": "observation.images.camera1", ',
            '"observation.images.top": "observation.images.camera2"}\'',
        )
    )
    raise ValueError(message)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_visual_features_consistency(
    cfg: PreTrainedConfig,
    features: dict[str, PolicyFeature],
) -> None:
    """
    Check that every visual feature provided is one the policy config declares.

    Only entries whose ``type`` is ``FeatureType.VISUAL`` are compared; all other
    features are ignored. The check is one-directional: providing fewer visual
    features than the config declares is accepted, providing a visual key the
    config does not declare is rejected.

    Args:
        cfg (PreTrainedConfig): Policy configuration; its ``input_features`` mapping is read.
        features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.

    Raises:
        ValueError: Through ``raise_feature_mismatch_error`` when at least one provided
            visual key is absent from the config's visual input features.
    """

    def _visual_keys(feature_map):
        # Keep only the names of the entries tagged as visual.
        return {name for name, spec in feature_map.items() if spec.type == FeatureType.VISUAL}

    policy_visuals = _visual_keys(cfg.input_features)
    dataset_visuals = _visual_keys(features)
    if dataset_visuals <= policy_visuals:
        return
    raise_feature_mismatch_error(dataset_visuals, policy_visuals)
|
||||||
|
|||||||
157
tests/training/test_visual_validation.py
Normal file
157
tests/training/test_visual_validation.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Visual Feature Consistency Tests
|
||||||
|
|
||||||
|
This module tests the `validate_visual_features_consistency` function,
|
||||||
|
which ensures that visual features (camera observations) in a dataset/env
|
||||||
|
match the expectations defined in a policy configuration.
|
||||||
|
|
||||||
|
The purpose of this check is to prevent mismatches between what a policy expects
|
||||||
|
(e.g., `observation.images.camera1`, `camera2`, `camera3`) and what a dataset or
|
||||||
|
environment actually provides (e.g., `observation.images.top`, `side`, or fewer cameras).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.configs.default import DatasetConfig
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.policies.factory import make_policy_config
|
||||||
|
from lerobot.scripts.lerobot_train import train
|
||||||
|
from lerobot.utils.utils import auto_select_torch_device
|
||||||
|
|
||||||
|
pytest.importorskip("transformers")
|
||||||
|
|
||||||
|
DUMMY_REPO_ID = "dummy/repo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def temp_dir(tmp_path):
    """Expose pytest's ``tmp_path`` fixture value unchanged under the name ``temp_dir``."""
    scratch_dir = tmp_path
    return scratch_dir
|
||||||
|
|
||||||
|
|
||||||
|
DUMMY_STATE_DIM = 6
|
||||||
|
DUMMY_ACTION_DIM = 6
|
||||||
|
IMAGE_SIZE = 8
|
||||||
|
DEVICE = auto_select_torch_device()
|
||||||
|
|
||||||
|
|
||||||
|
def make_dummy_dataset(camera_keys, tmp_path):
    """Creates a minimal dummy dataset for testing rename_mapping logic.

    The dataset is created under ``tmp_path / "_dataset"`` with an ``action``
    feature, an ``observation.state`` feature and one
    ``observation.images.<cam>`` image feature per entry of ``camera_keys``.
    It is filled with 2 episodes of 3 random frames each.

    Args:
        camera_keys: Iterable of camera names (without the ``observation.images.`` prefix).
        tmp_path: Base directory (a ``pathlib.Path``-like object supporting ``/``).

    Returns:
        A ``(dataset, root)`` tuple: the created dataset object and its root directory.
    """
    # Single source of truth for the location: used for creation and returned to the caller.
    root = tmp_path / "_dataset"
    image_shape = (IMAGE_SIZE, IMAGE_SIZE, 3)
    # Materialise once so a one-shot iterable works for both the schema and the frames.
    image_keys = [f"observation.images.{cam}" for cam in camera_keys]

    features = {
        "action": {"dtype": "float32", "shape": (DUMMY_ACTION_DIM,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (DUMMY_STATE_DIM,), "names": None},
    }
    for image_key in image_keys:
        features[image_key] = {
            "dtype": "image",
            "shape": image_shape,
            "names": ["height", "width", "channel"],
        }

    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=30,
        features=features,
        root=root,
    )

    for ep_idx in range(2):
        for _ in range(3):
            frame = {
                "action": np.random.randn(DUMMY_ACTION_DIM).astype(np.float32),
                "observation.state": np.random.randn(DUMMY_STATE_DIM).astype(np.float32),
            }
            for image_key in image_keys:
                # randint's upper bound is exclusive: 256 covers the full uint8 range 0..255.
                frame[image_key] = np.random.randint(0, 256, size=image_shape, dtype=np.uint8)
            frame["task"] = f"task_{ep_idx}"
            dataset.add_frame(frame)
        dataset.save_episode()

    dataset.finalize()
    return dataset, root
|
||||||
|
|
||||||
|
|
||||||
|
def custom_validate(train_config: TrainPipelineConfig, policy_path: str, empty_cameras: int):
    """Replace the policy of ``train_config`` with one loaded from ``policy_path``.

    Performs the setup the test needs in place of the config's own validation:
    loads the policy config, records where it came from, applies test overrides,
    and (when ``use_policy_training_preset`` is set) takes the optimizer and
    scheduler presets from the loaded policy.

    Args:
        train_config: Training configuration, modified in place.
        policy_path: Passed to ``PreTrainedConfig.from_pretrained`` and stored as ``pretrained_path``.
        empty_cameras: Value assigned to the loaded policy's ``empty_cameras``.

    Returns:
        The same ``train_config`` object.
    """
    train_config.policy = PreTrainedConfig.from_pretrained(policy_path)
    loaded_policy = train_config.policy
    loaded_policy.pretrained_path = Path(policy_path)

    # override empty_cameras and push_to_hub for testing
    loaded_policy.empty_cameras = empty_cameras
    loaded_policy.push_to_hub = False

    if not train_config.use_policy_training_preset:
        return train_config

    train_config.optimizer = loaded_policy.get_optimizer_preset()
    train_config.scheduler = loaded_policy.get_scheduler_preset()
    return train_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Skipping this test as it results OOM")
@pytest.mark.parametrize(
    "camera_keys, empty_cameras, rename_map, expect_success",
    [
        # case 1: dataset has fewer cameras than policy (3 instead of 4), but we specify empty_cameras=1 for smolvla, pi0, pi05
        (["camera1", "camera2", "camera3"], 1, {}, True),
        # case 2: dataset has 2 cameras with different names, rename_mapping provided
        (
            ["top", "side"],
            0,
            {
                "observation.images.top": "observation.images.camera1",
                "observation.images.side": "observation.images.camera2",
            },
            True,
        ),
        # case 3: dataset has 2 cameras, policy expects 3, names do not match, no empty_cameras
        (["top", "side"], 0, {}, False),
        # TODO: case 4: dataset has 2 cameras, policy expects 3, no rename_map, no empty_cameras, should raise for smolvla
        # (["camera1", "camera2"], 0, {}, False),
    ],
)
def test_train_with_camera_mismatch(camera_keys, empty_cameras, rename_map, expect_success, tmp_path):
    """Run one training step and check it succeeds or raises ValueError as parametrized.

    A dummy dataset with the given cameras is written under ``tmp_path``; a
    smolvla training config is built around it with the given ``rename_map``
    and ``empty_cameras``; then ``train`` is invoked.
    """
    checkpoint = "lerobot/smolvla_base"
    _unused_dataset, dataset_root = make_dummy_dataset(camera_keys, tmp_path)

    data_cfg = DatasetConfig(repo_id=DUMMY_REPO_ID, root=dataset_root)
    smolvla_cfg = make_policy_config(
        "smolvla",
        optimizer_lr=0.01,
        push_to_hub=False,
        pretrained_path=checkpoint,
        device=DEVICE,
    )
    smolvla_cfg.empty_cameras = empty_cameras

    pipeline_cfg = TrainPipelineConfig(
        dataset=data_cfg,
        policy=smolvla_cfg,
        rename_map=rename_map,
        output_dir=tmp_path / "_output",
        steps=1,
    )
    pipeline_cfg = custom_validate(pipeline_cfg, policy_path=checkpoint, empty_cameras=empty_cameras)
    # HACK: disable the internal CLI validation step for tests, we did it with custom_validate
    pipeline_cfg.validate = lambda: None

    if not expect_success:
        # Mismatched cameras must surface as a ValueError from the training entry point.
        with pytest.raises(ValueError):
            train(pipeline_cfg)
        return

    train(pipeline_cfg)
|
||||||
Reference in New Issue
Block a user