Fix pi0 checkpoint state map (#1415)

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Yushun Xiang
2025-07-30 23:38:32 +08:00
committed by GitHub
parent 67196c9d53
commit 71eff183ff
3 changed files with 130 additions and 13 deletions

View File

@@ -66,7 +66,8 @@ from lerobot.policies.pi0.paligemma_with_expert import (
PaliGemmaWithExpertModel, PaliGemmaWithExpertModel,
) )
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.utils import get_safe_dtype from lerobot.policies.utils import log_model_loading_keys
from lerobot.utils.utils import get_safe_dtype, init_logging
def create_sinusoidal_pos_embedding( def create_sinusoidal_pos_embedding(
@@ -252,6 +253,99 @@ class PI0Policy(PreTrainedPolicy):
"""This should be called whenever the environment is reset.""" """This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps) self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
"""
Transform state dict keys to match expected model structure.
Transformations:
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
model.paligemma_with_expert.paligemma.lm_head
- model.paligemma_with_expert.paligemma.language_model.model ->
model.paligemma_with_expert.paligemma.model.language_model
- model.paligemma_with_expert.paligemma.vision_tower ->
model.paligemma_with_expert.paligemma.model.vision_tower
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
model.paligemma_with_expert.paligemma.model.multi_modal_projector
Also handles tied weights between lm_head.weight and
embed_tokens.weight.
"""
import re
transformed_dict = {}
transformations = [
(
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
".paligemma_with_expert.paligemma.lm_head",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
".paligemma_with_expert.paligemma.model.language_model",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
".paligemma_with_expert.paligemma.model.vision_tower",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
".paligemma_with_expert.paligemma.model.multi_modal_projector",
),
]
for key, value in state_dict.items():
new_key = key
for pattern, replacement in transformations:
new_key = pattern.sub(replacement, new_key)
transformed_dict[new_key] = value
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
lm_head_key = None
embed_tokens_key = None
for key in transformed_dict:
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
lm_head_key = key
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
embed_tokens_key = key
if lm_head_key and embed_tokens_key:
break
if lm_head_key and not embed_tokens_key:
embed_tokens_key = lm_head_key.replace(
".lm_head.weight", ".model.language_model.embed_tokens.weight"
)
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
elif embed_tokens_key and not lm_head_key:
lm_head_key = embed_tokens_key.replace(
".model.language_model.embed_tokens.weight", ".lm_head.weight"
)
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
return transformed_dict
@classmethod
def _load_as_safetensor(
    cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
) -> "PI0Policy":
    """Load a safetensors checkpoint into `model`, remapping legacy keys first.

    The state dict read from `model_file` is passed through
    `_transform_state_dict_keys` before being applied, and any missing or
    unexpected keys reported by `load_state_dict` are logged.
    """
    from safetensors.torch import load_file

    init_logging()
    # Read the raw tensors directly onto the target device.
    raw_state_dict = load_file(model_file, device=map_location)
    # Rewrite legacy checkpoint keys to the current module layout.
    remapped_state_dict = cls._transform_state_dict_keys(raw_state_dict)
    load_result = model.load_state_dict(remapped_state_dict, strict=strict)
    # Surface any key mismatches in the logs instead of failing silently.
    log_model_loading_keys(load_result.missing_keys, load_result.unexpected_keys)
    return model
def get_optim_params(self) -> dict: def get_optim_params(self) -> dict:
return self.parameters() return self.parameters()

View File

@@ -30,6 +30,7 @@ from torch import Tensor, nn
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from lerobot.policies.utils import log_model_loading_keys
from lerobot.utils.hub import HubMixin from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="PreTrainedPolicy") T = TypeVar("T", bound="PreTrainedPolicy")
@@ -128,9 +129,19 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
    """Load safetensors weights from `model_file` into `model`.

    On safetensors >= 0.4.3 the weights are loaded directly onto
    `map_location`; on older versions they are loaded on CPU first and then
    the module is moved to the target device (with a warning). Missing and
    unexpected keys reported by the loader are logged.
    """
    # Create base kwargs
    kwargs = {"strict": strict}

    # Add device parameter for newer versions that support it (native device
    # placement was added to safetensors in 0.4.3).
    if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
        kwargs["device"] = map_location

    # Load the model with appropriate kwargs
    missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
    log_model_loading_keys(missing_keys, unexpected_keys)

    # For older versions, manually move to device if needed
    if "device" not in kwargs and map_location != "cpu":
        logging.warning(
            "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
            " This means that the model is loaded on 'cpu' first and then copied to the device."
            " Please update safetensors to version 0.4.3 or above for improved performance."
        )
        model.to(map_location)

    return model
@abc.abstractmethod @abc.abstractmethod

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from collections import deque from collections import deque
import torch import torch
@@ -71,3 +72,16 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
with torch.inference_mode(): with torch.inference_mode():
output = module(dummy_input) output = module(dummy_input)
return tuple(output.shape) return tuple(output.shape)
def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None:
    """Log missing and unexpected keys when loading a model.

    Args:
        missing_keys (list[str]): Keys that were expected but not found.
        unexpected_keys (list[str]): Keys that were found but not expected.
    """
    # Lazy %-style args: the message is only formatted if the warning is
    # actually emitted (message text is unchanged from the f-string form).
    if missing_keys:
        logging.warning("Missing key(s) when loading model: %s", missing_keys)
    if unexpected_keys:
        logging.warning("Unexpected key(s) when loading model: %s", unexpected_keys)