diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 2f3a65dbc..d17915c36 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -65,6 +65,8 @@ class TrainPipelineConfig(HubMixin): eval: EvalConfig = field(default_factory=EvalConfig) wandb: WandBConfig = field(default_factory=WandBConfig) checkpoint_path: Path | None = field(init=False, default=None) + # Rename map for the observation to override the image and state keys + rename_map: dict[str, str] = field(default_factory=dict) def validate(self) -> None: # HACK: We parse again the cli args here to get the pretrained paths if there was some. diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 9c67e317a..6e524f2ab 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -303,6 +303,7 @@ def make_policy( cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata | None = None, env_cfg: EnvConfig | None = None, + rename_map: dict[str, str] | None = None, ) -> PreTrainedPolicy: """ Instantiate a policy model. @@ -319,6 +320,8 @@ def make_policy( statistics for normalization layers. env_cfg: Environment configuration used to infer feature shapes and types. One of `ds_meta` or `env_cfg` must be provided. + rename_map: Optional mapping of dataset or environment feature keys to match + expected policy feature names (e.g., `"left"` → `"camera1"`). Returns: An instantiated and device-placed policy model. @@ -380,4 +383,21 @@ def make_policy( # policy = torch.compile(policy, mode="reduce-overhead") + if not rename_map: + expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys()) + provided_features = set(features.keys()) + if expected_features and provided_features != expected_features: + missing = expected_features - provided_features + extra = provided_features - expected_features + # TODO (jadechoghari): provide a dynamic rename map suggestion to the user. + raise ValueError( + f"Feature mismatch between dataset/environment and policy config.\n" + f"- Missing features: {sorted(missing) if missing else 'None'}\n" + f"- Extra features: {sorted(extra) if extra else 'None'}\n\n" + f"Please ensure your dataset and policy use consistent feature names.\n" + f"If your dataset uses different observation keys (e.g., cameras named differently), " + f"use the `--rename_map` argument, for example:\n" + f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", ' + f'"observation.images.top": "observation.images.camera2"}}\'' + ) return policy diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 8796e897e..754fd15fe 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -501,6 +501,7 @@ def eval_main(cfg: EvalPipelineConfig): policy = make_policy( cfg=cfg.policy, env_cfg=cfg.env, + rename_map=cfg.rename_map, ) policy.eval() diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 84eb81ad4..0cc6e037f 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -203,6 +203,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): policy = make_policy( cfg=cfg.policy, ds_meta=dataset.meta, + rename_map=cfg.rename_map, ) # Wait for all processes to finish policy creation before continuing @@ -224,6 +225,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): "norm_map": policy.config.normalization_mapping, }, } + processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = { + "rename_map": cfg.rename_map + } postprocessor_kwargs["postprocessor_overrides"] = { "unnormalizer_processor": { "stats": dataset.meta.stats,