Skip normalization parameters in load_smolvla (#1274)

This commit is contained in:
Simon Alibert
2025-06-13 11:06:45 +02:00
committed by GitHub
parent edfebd522c
commit 5c87365cc1

View File

@@ -157,9 +157,13 @@ def load_smolvla(
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
missing, unexpected = model.load_state_dict(state_dict)
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
if missing or unexpected:
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
raise RuntimeError(
"SmolVLA %d missing / %d unexpected keys",
len(missing),