forked from tangger/lerobot
fix(bug): Fix policy renaming ValueError during training (#2278)
* fixes * style * Update src/lerobot/policies/factory.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jade Choghari <chogharijade@gmail.com> * style * add review fixes --------- Signed-off-by: Jade Choghari <chogharijade@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -65,6 +65,8 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||||
checkpoint_path: Path | None = field(init=False, default=None)
|
checkpoint_path: Path | None = field(init=False, default=None)
|
||||||
|
# Rename map for the observation to override the image and state keys
|
||||||
|
rename_map: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||||
|
|||||||
@@ -303,6 +303,7 @@ def make_policy(
|
|||||||
cfg: PreTrainedConfig,
|
cfg: PreTrainedConfig,
|
||||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||||
env_cfg: EnvConfig | None = None,
|
env_cfg: EnvConfig | None = None,
|
||||||
|
rename_map: dict[str, str] | None = None,
|
||||||
) -> PreTrainedPolicy:
|
) -> PreTrainedPolicy:
|
||||||
"""
|
"""
|
||||||
Instantiate a policy model.
|
Instantiate a policy model.
|
||||||
@@ -319,6 +320,8 @@ def make_policy(
|
|||||||
statistics for normalization layers.
|
statistics for normalization layers.
|
||||||
env_cfg: Environment configuration used to infer feature shapes and types.
|
env_cfg: Environment configuration used to infer feature shapes and types.
|
||||||
One of `ds_meta` or `env_cfg` must be provided.
|
One of `ds_meta` or `env_cfg` must be provided.
|
||||||
|
rename_map: Optional mapping of dataset or environment feature keys to match
|
||||||
|
expected policy feature names (e.g., `"left"` → `"camera1"`).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instantiated and device-placed policy model.
|
An instantiated and device-placed policy model.
|
||||||
@@ -380,4 +383,21 @@ def make_policy(
|
|||||||
|
|
||||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||||
|
|
||||||
|
if not rename_map:
|
||||||
|
expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys())
|
||||||
|
provided_features = set(features.keys())
|
||||||
|
if expected_features and provided_features != expected_features:
|
||||||
|
missing = expected_features - provided_features
|
||||||
|
extra = provided_features - expected_features
|
||||||
|
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
|
||||||
|
raise ValueError(
|
||||||
|
f"Feature mismatch between dataset/environment and policy config.\n"
|
||||||
|
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
|
||||||
|
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
|
||||||
|
f"Please ensure your dataset and policy use consistent feature names.\n"
|
||||||
|
f"If your dataset uses different observation keys (e.g., cameras named differently), "
|
||||||
|
f"use the `--rename_map` argument, for example:\n"
|
||||||
|
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
|
||||||
|
f'"observation.images.top": "observation.images.camera2"}}\''
|
||||||
|
)
|
||||||
return policy
|
return policy
|
||||||
|
|||||||
@@ -501,6 +501,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
policy = make_policy(
|
policy = make_policy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
env_cfg=cfg.env,
|
env_cfg=cfg.env,
|
||||||
|
rename_map=cfg.rename_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|||||||
@@ -203,6 +203,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
policy = make_policy(
|
policy = make_policy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
ds_meta=dataset.meta,
|
ds_meta=dataset.meta,
|
||||||
|
rename_map=cfg.rename_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for all processes to finish policy creation before continuing
|
# Wait for all processes to finish policy creation before continuing
|
||||||
@@ -224,6 +225,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
"norm_map": policy.config.normalization_mapping,
|
"norm_map": policy.config.normalization_mapping,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = {
|
||||||
|
"rename_map": cfg.rename_map
|
||||||
|
}
|
||||||
postprocessor_kwargs["postprocessor_overrides"] = {
|
postprocessor_kwargs["postprocessor_overrides"] = {
|
||||||
"unnormalizer_processor": {
|
"unnormalizer_processor": {
|
||||||
"stats": dataset.meta.stats,
|
"stats": dataset.meta.stats,
|
||||||
|
|||||||
Reference in New Issue
Block a user