diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py index a6745880..5e0a9622 100644 --- a/lerobot/common/policies/smolvla/modeling_smolvla.py +++ b/lerobot/common/policies/smolvla/modeling_smolvla.py @@ -157,9 +157,13 @@ def load_smolvla( state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) - missing, unexpected = model.load_state_dict(state_dict) + # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset + norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") + state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} - if missing or unexpected: + missing, unexpected = model.load_state_dict(state_dict, strict=False) + + if not all(key.startswith(norm_keys) for key in missing) or unexpected: raise RuntimeError( "SmolVLA %d missing / %d unexpected keys", len(missing),