Skip normalization parameters in load_smolvla (#1274)
This commit is contained in:
@@ -157,9 +157,13 @@ def load_smolvla(
|
|||||||
|
|
||||||
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
|
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
|
||||||
|
|
||||||
missing, unexpected = model.load_state_dict(state_dict)
|
# HACK(aliberts): do not overwrite the normalization parameters, since they should come from the dataset
|
||||||
|
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
|
||||||
|
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
|
||||||
|
|
||||||
if missing or unexpected:
|
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"SmolVLA %d missing / %d unexpected keys",
|
"SmolVLA %d missing / %d unexpected keys",
|
||||||
len(missing),
|
len(missing),
|
||||||
|
|||||||
Reference in New Issue
Block a user