From d0521189b1a5e8b98d345fb9390896502e377621 Mon Sep 17 00:00:00 2001 From: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com> Date: Wed, 11 Jun 2025 16:56:55 +0200 Subject: [PATCH] fix issues: checkpoints keys mismatch and 'task' tokenisation in smolvla (#1256) Co-authored-by: danaaubakirova Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Simon Alibert --- .../policies/smolvla/modeling_smolvla.py | 116 ++++++++++++++++++ pyproject.toml | 2 +- 2 files changed, 117 insertions(+), 1 deletion(-) diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py index 6ac2d3e7e..a6745880b 100644 --- a/lerobot/common/policies/smolvla/modeling_smolvla.py +++ b/lerobot/common/policies/smolvla/modeling_smolvla.py @@ -53,8 +53,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") """ import math +import os +import re from collections import deque +import safetensors import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn @@ -73,6 +76,98 @@ from lerobot.common.policies.utils import ( ) from lerobot.common.utils.utils import get_safe_dtype +# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker +_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") + + +def canonicalise(k: str) -> str: + """ + Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a + normalisation-buffer key. + """ + return _VARIANT_RE.sub(".buffer_", k) + + +def standardise_state_dict( + checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True +) -> tuple[dict[str, torch.Tensor], list[str]]: + """ + • Re-keys `checkpoint ` so that every entry matches the *reference* key set. + • If several variant keys collapse to the same canonical name we keep the + first one and log the collision. 
+ • Returns the new dict + a list of entries that could not be matched. + """ + out, collisions, unmatched = {}, {}, [] + + for k, v in checkpoint.items(): + canon = canonicalise(k) + if canon in ref_keys: + if canon in out: # duplicate after collapsing + collisions.setdefault(canon, []).append(k) + else: + out[canon] = v + else: + unmatched.append(k) + + if verbose: + for canon, variants in collisions.items(): + print(f"[standardise_state_dict] '{canon}' ← {variants}") + if unmatched: + print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") + + out.update({k: checkpoint[k] for k in unmatched}) + return out, unmatched + + +def rename_checkpoint_keys(checkpoint: dict, rename_str: str): + """ + Renames keys in a checkpoint dictionary based on the given rename string. + + Args: + checkpoint (dict): The checkpoint dictionary. + rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". + + Returns: + dict: The modified checkpoint with renamed keys. + """ + + rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) + + new_checkpoint = {} + for k, v in checkpoint.items(): + for old_key, new_key in rename_dict.items(): + if old_key in k: + k = k.replace(old_key, new_key) + new_checkpoint[k] = v + return new_checkpoint + + +def load_smolvla( + model: torch.nn.Module, + filename: str | os.PathLike, + *, + device: str = "cpu", + checkpoint_keys_mapping: str = "", +) -> torch.nn.Module: + state_dict = safetensors.torch.load_file(filename, device=device) + + # Optional user-supplied renames (e.g. 
"model._orig_mod.//model.")
+    if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
+        state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
+
+    state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
+
+    missing, unexpected = model.load_state_dict(state_dict, strict=False)
+
+    if missing or unexpected:
+        raise RuntimeError(
+            f"SmolVLA load: {len(missing)} missing and "
+            f"{len(unexpected)} unexpected keys "
+            "after checkpoint key standardisation."
+        )
+
+    return model
+
 
 def create_sinusoidal_pos_embedding(
     time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
@@ -264,6 +359,23 @@ class SmolVLAPolicy(PreTrainedPolicy):
             ACTION: deque(maxlen=self.config.n_action_steps),
         }
 
+    # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
+    @classmethod
+    def _load_as_safetensor(
+        cls,
+        model: "SmolVLAPolicy",
+        model_file: str,
+        map_location: str,
+        strict: bool,
+    ):
+        # NOTE: no strict safetensors pre-load here; load_smolvla must remap keys first.
+        return load_smolvla(
+            model,
+            model_file,
+            device=map_location,
+            checkpoint_keys_mapping="model._orig_mod.//model.",
+        )
+
     def get_optim_params(self) -> dict:
         return self.parameters()
 
@@ -387,10 +499,14 @@ class SmolVLAPolicy(PreTrainedPolicy):
         """Tokenize the text input"""
         device = batch[OBS_STATE].device
         tasks = batch["task"]
+        if isinstance(tasks, str):
+            tasks = [tasks]
+
         if len(tasks) == 1:
             tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
 
         tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
+
         tokenized_prompt = self.language_tokenizer.__call__(
             tasks,
             padding=self.config.pad_language_to,
diff --git a/pyproject.toml b/pyproject.toml
index 2ce5d049b..a99b1b16c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -90,7 +90,7 @@ intelrealsense = [
     "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
 ]
 pi0 = ["transformers>=4.48.0"]
-smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
+smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] stretch = [ "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",