forked from tangger/lerobot
fix issues: checkpoints keys mismatch and 'task' tokenisation in smolvla (#1256)
Co-authored-by: danaaubakirova <d.aubakirova@alumni.edu.kz> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
This commit is contained in:
@@ -53,8 +53,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections import deque
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
@@ -73,6 +76,98 @@ from lerobot.common.policies.utils import (
|
||||
)
|
||||
from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||
|
||||
|
||||
def canonicalise(k: str) -> str:
|
||||
"""
|
||||
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
|
||||
normalisation-buffer key.
|
||||
"""
|
||||
return _VARIANT_RE.sub(".buffer_", k)
|
||||
|
||||
|
||||
def standardise_state_dict(
|
||||
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
|
||||
) -> tuple[dict[str, torch.Tensor], list[str]]:
|
||||
"""
|
||||
• Re-keys `checkpoint ` so that every entry matches the *reference* key set.
|
||||
• If several variant keys collapse to the same canonical name we keep the
|
||||
first one and log the collision.
|
||||
• Returns the new dict + a list of entries that could not be matched.
|
||||
"""
|
||||
out, collisions, unmatched = {}, {}, []
|
||||
|
||||
for k, v in checkpoint.items():
|
||||
canon = canonicalise(k)
|
||||
if canon in ref_keys:
|
||||
if canon in out: # duplicate after collapsing
|
||||
collisions.setdefault(canon, []).append(k)
|
||||
else:
|
||||
out[canon] = v
|
||||
else:
|
||||
unmatched.append(k)
|
||||
|
||||
if verbose:
|
||||
for canon, variants in collisions.items():
|
||||
print(f"[standardise_state_dict] '{canon}' ← {variants}")
|
||||
if unmatched:
|
||||
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
|
||||
|
||||
out.update({k: checkpoint[k] for k in unmatched})
|
||||
return out, unmatched
|
||||
|
||||
|
||||
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
||||
"""
|
||||
Renames keys in a checkpoint dictionary based on the given rename string.
|
||||
|
||||
Args:
|
||||
checkpoint (dict): The checkpoint dictionary.
|
||||
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
|
||||
|
||||
Returns:
|
||||
dict: The modified checkpoint with renamed keys.
|
||||
"""
|
||||
|
||||
rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
|
||||
|
||||
new_checkpoint = {}
|
||||
for k, v in checkpoint.items():
|
||||
for old_key, new_key in rename_dict.items():
|
||||
if old_key in k:
|
||||
k = k.replace(old_key, new_key)
|
||||
new_checkpoint[k] = v
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def load_smolvla(
|
||||
model: torch.nn.Module,
|
||||
filename: str | os.PathLike,
|
||||
*,
|
||||
device: str = "cpu",
|
||||
checkpoint_keys_mapping: str = "",
|
||||
) -> torch.nn.Module:
|
||||
state_dict = safetensors.torch.load_file(filename, device=device)
|
||||
|
||||
# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
|
||||
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
|
||||
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
|
||||
|
||||
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
|
||||
|
||||
missing, unexpected = model.load_state_dict(state_dict)
|
||||
|
||||
if missing or unexpected:
|
||||
raise RuntimeError(
|
||||
"SmolVLA %d missing / %d unexpected keys",
|
||||
len(missing),
|
||||
len(unexpected),
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
@@ -264,6 +359,23 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
|
||||
@classmethod
|
||||
def _load_as_safetensor(
|
||||
cls,
|
||||
model: "SmolVLAPolicy",
|
||||
model_file: str,
|
||||
map_location: str,
|
||||
strict: bool,
|
||||
):
|
||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
return load_smolvla(
|
||||
model,
|
||||
model_file,
|
||||
device=map_location,
|
||||
checkpoint_keys_mapping="model._orig_mod.//model.",
|
||||
)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@@ -387,10 +499,14 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_STATE].device
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
if len(tasks) == 1:
|
||||
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
||||
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding=self.config.pad_language_to,
|
||||
|
||||
@@ -90,7 +90,7 @@ intelrealsense = [
|
||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||
]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
|
||||
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
stretch = [
|
||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||
|
||||
Reference in New Issue
Block a user