Skip normalization parameters in load_smolvla (#1274)
This commit is contained in:
@@ -157,9 +157,13 @@ def load_smolvla(
|
||||
|
||||
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
|
||||
|
||||
missing, unexpected = model.load_state_dict(state_dict)
|
||||
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
|
||||
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
|
||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
|
||||
|
||||
if missing or unexpected:
|
||||
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
|
||||
raise RuntimeError(
|
||||
"SmolVLA %d missing / %d unexpected keys",
|
||||
len(missing),
|
||||
|
||||
Reference in New Issue
Block a user