fix(bug): Fix policy renaming ValueError during training (#2278)

* fixes

* style

* Update src/lerobot/policies/factory.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* style

* add review fixes

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Jade Choghari
2025-10-21 16:00:46 +02:00
committed by GitHub
parent 63cd2111ad
commit a024d33750
4 changed files with 27 additions and 0 deletions

View File

@@ -65,6 +65,8 @@ class TrainPipelineConfig(HubMixin):
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
checkpoint_path: Path | None = field(init=False, default=None)
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some.

View File

@@ -303,6 +303,7 @@ def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
rename_map: dict[str, str] | None = None,
) -> PreTrainedPolicy:
"""
Instantiate a policy model.
@@ -319,6 +320,8 @@ def make_policy(
statistics for normalization layers.
env_cfg: Environment configuration used to infer feature shapes and types.
One of `ds_meta` or `env_cfg` must be provided.
rename_map: Optional mapping of dataset or environment feature keys to match
expected policy feature names (e.g., `"left"` → `"camera1"`).
Returns:
An instantiated and device-placed policy model.
@@ -380,4 +383,21 @@ def make_policy(
# policy = torch.compile(policy, mode="reduce-overhead")
if not rename_map:
expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys())
provided_features = set(features.keys())
if expected_features and provided_features != expected_features:
missing = expected_features - provided_features
extra = provided_features - expected_features
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
raise ValueError(
f"Feature mismatch between dataset/environment and policy config.\n"
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
f"Please ensure your dataset and policy use consistent feature names.\n"
f"If your dataset uses different observation keys (e.g., cameras named differently), "
f"use the `--rename_map` argument, for example:\n"
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
f'"observation.images.top": "observation.images.camera2"}}\''
)
return policy

View File

@@ -501,6 +501,7 @@ def eval_main(cfg: EvalPipelineConfig):
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
rename_map=cfg.rename_map,
)
policy.eval()

View File

@@ -203,6 +203,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
rename_map=cfg.rename_map,
)
# Wait for all processes to finish policy creation before continuing
@@ -224,6 +225,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
"norm_map": policy.config.normalization_mapping,
},
}
processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = {
"rename_map": cfg.rename_map
}
postprocessor_kwargs["postprocessor_overrides"] = {
"unnormalizer_processor": {
"stats": dataset.meta.stats,