From 5c87365cc160617c45dc5d1bbb3788de010271a7 Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Fri, 13 Jun 2025 11:06:45 +0200 Subject: [PATCH] Skip normalization parameters in load_smolvla (#1274) --- lerobot/common/policies/smolvla/modeling_smolvla.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py index a6745880b..5e0a9622e 100644 --- a/lerobot/common/policies/smolvla/modeling_smolvla.py +++ b/lerobot/common/policies/smolvla/modeling_smolvla.py @@ -157,9 +157,13 @@ def load_smolvla( state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) - missing, unexpected = model.load_state_dict(state_dict) + # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset + norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") + state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} - if missing or unexpected: + missing, unexpected = model.load_state_dict(state_dict, strict=False) + + if not all(key.startswith(norm_keys) for key in missing) or unexpected: raise RuntimeError( "SmolVLA %d missing / %d unexpected keys", len(missing),