fix issues: checkpoints keys mismatch and 'task' tokenisation in smolvla (#1256)
Co-authored-by: danaaubakirova <d.aubakirova@alumni.edu.kz> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
This commit is contained in:
@@ -53,8 +53,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@@ -73,6 +76,98 @@ from lerobot.common.policies.utils import (
|
|||||||
)
|
)
|
||||||
from lerobot.common.utils.utils import get_safe_dtype
|
from lerobot.common.utils.utils import get_safe_dtype
|
||||||
|
|
||||||
|
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||||
|
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||||
|
|
||||||
|
|
||||||
|
def canonicalise(k: str) -> str:
|
||||||
|
"""
|
||||||
|
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
|
||||||
|
normalisation-buffer key.
|
||||||
|
"""
|
||||||
|
return _VARIANT_RE.sub(".buffer_", k)
|
||||||
|
|
||||||
|
|
||||||
|
def standardise_state_dict(
|
||||||
|
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
|
||||||
|
) -> tuple[dict[str, torch.Tensor], list[str]]:
|
||||||
|
"""
|
||||||
|
• Re-keys `checkpoint ` so that every entry matches the *reference* key set.
|
||||||
|
• If several variant keys collapse to the same canonical name we keep the
|
||||||
|
first one and log the collision.
|
||||||
|
• Returns the new dict + a list of entries that could not be matched.
|
||||||
|
"""
|
||||||
|
out, collisions, unmatched = {}, {}, []
|
||||||
|
|
||||||
|
for k, v in checkpoint.items():
|
||||||
|
canon = canonicalise(k)
|
||||||
|
if canon in ref_keys:
|
||||||
|
if canon in out: # duplicate after collapsing
|
||||||
|
collisions.setdefault(canon, []).append(k)
|
||||||
|
else:
|
||||||
|
out[canon] = v
|
||||||
|
else:
|
||||||
|
unmatched.append(k)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
for canon, variants in collisions.items():
|
||||||
|
print(f"[standardise_state_dict] '{canon}' ← {variants}")
|
||||||
|
if unmatched:
|
||||||
|
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
|
||||||
|
|
||||||
|
out.update({k: checkpoint[k] for k in unmatched})
|
||||||
|
return out, unmatched
|
||||||
|
|
||||||
|
|
||||||
|
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
||||||
|
"""
|
||||||
|
Renames keys in a checkpoint dictionary based on the given rename string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint (dict): The checkpoint dictionary.
|
||||||
|
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The modified checkpoint with renamed keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
|
||||||
|
|
||||||
|
new_checkpoint = {}
|
||||||
|
for k, v in checkpoint.items():
|
||||||
|
for old_key, new_key in rename_dict.items():
|
||||||
|
if old_key in k:
|
||||||
|
k = k.replace(old_key, new_key)
|
||||||
|
new_checkpoint[k] = v
|
||||||
|
return new_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def load_smolvla(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
filename: str | os.PathLike,
|
||||||
|
*,
|
||||||
|
device: str = "cpu",
|
||||||
|
checkpoint_keys_mapping: str = "",
|
||||||
|
) -> torch.nn.Module:
|
||||||
|
state_dict = safetensors.torch.load_file(filename, device=device)
|
||||||
|
|
||||||
|
# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
|
||||||
|
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
|
||||||
|
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
|
||||||
|
|
||||||
|
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
|
||||||
|
|
||||||
|
missing, unexpected = model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
if missing or unexpected:
|
||||||
|
raise RuntimeError(
|
||||||
|
"SmolVLA %d missing / %d unexpected keys",
|
||||||
|
len(missing),
|
||||||
|
len(unexpected),
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def create_sinusoidal_pos_embedding(
|
def create_sinusoidal_pos_embedding(
|
||||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||||
@@ -264,6 +359,23 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
|
||||||
|
@classmethod
|
||||||
|
def _load_as_safetensor(
|
||||||
|
cls,
|
||||||
|
model: "SmolVLAPolicy",
|
||||||
|
model_file: str,
|
||||||
|
map_location: str,
|
||||||
|
strict: bool,
|
||||||
|
):
|
||||||
|
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||||
|
return load_smolvla(
|
||||||
|
model,
|
||||||
|
model_file,
|
||||||
|
device=map_location,
|
||||||
|
checkpoint_keys_mapping="model._orig_mod.//model.",
|
||||||
|
)
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
@@ -387,10 +499,14 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
"""Tokenize the text input"""
|
"""Tokenize the text input"""
|
||||||
device = batch[OBS_STATE].device
|
device = batch[OBS_STATE].device
|
||||||
tasks = batch["task"]
|
tasks = batch["task"]
|
||||||
|
if isinstance(tasks, str):
|
||||||
|
tasks = [tasks]
|
||||||
|
|
||||||
if len(tasks) == 1:
|
if len(tasks) == 1:
|
||||||
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
||||||
|
|
||||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||||
|
|
||||||
tokenized_prompt = self.language_tokenizer.__call__(
|
tokenized_prompt = self.language_tokenizer.__call__(
|
||||||
tasks,
|
tasks,
|
||||||
padding=self.config.pad_language_to,
|
padding=self.config.pad_language_to,
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ intelrealsense = [
|
|||||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||||
]
|
]
|
||||||
pi0 = ["transformers>=4.48.0"]
|
pi0 = ["transformers>=4.48.0"]
|
||||||
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
|
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
|
||||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||||
stretch = [
|
stretch = [
|
||||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||||
|
|||||||
Reference in New Issue
Block a user