diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py
index a34aa34f..e56946ac 100644
--- a/src/lerobot/policies/pi0/modeling_pi0.py
+++ b/src/lerobot/policies/pi0/modeling_pi0.py
@@ -66,7 +66,8 @@ from lerobot.policies.pi0.paligemma_with_expert import (
     PaliGemmaWithExpertModel,
 )
 from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.utils.utils import get_safe_dtype
+from lerobot.policies.utils import log_model_loading_keys
+from lerobot.utils.utils import get_safe_dtype, init_logging
 
 
 def create_sinusoidal_pos_embedding(
@@ -252,6 +253,99 @@ class PI0Policy(PreTrainedPolicy):
         """This should be called whenever the environment is reset."""
         self._action_queue = deque([], maxlen=self.config.n_action_steps)
 
+    @classmethod
+    def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
+        """
+        Transform state dict keys to match the expected model structure.
+
+        Transformations:
+        - model.paligemma_with_expert.paligemma.language_model.lm_head ->
+          model.paligemma_with_expert.paligemma.lm_head
+        - model.paligemma_with_expert.paligemma.language_model.model ->
+          model.paligemma_with_expert.paligemma.model.language_model
+        - model.paligemma_with_expert.paligemma.vision_tower ->
+          model.paligemma_with_expert.paligemma.model.vision_tower
+        - model.paligemma_with_expert.paligemma.multi_modal_projector ->
+          model.paligemma_with_expert.paligemma.model.multi_modal_projector
+
+        Also handles tied weights between lm_head.weight and
+        embed_tokens.weight.
+        """
+        import re
+
+        transformed_dict = {}
+
+        transformations = [
+            (
+                re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
+                ".paligemma_with_expert.paligemma.lm_head",
+            ),
+            (
+                re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
+                ".paligemma_with_expert.paligemma.model.language_model",
+            ),
+            (
+                re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
+                ".paligemma_with_expert.paligemma.model.vision_tower",
+            ),
+            (
+                re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
+                ".paligemma_with_expert.paligemma.model.multi_modal_projector",
+            ),
+        ]
+
+        for key, value in state_dict.items():
+            new_key = key
+            for pattern, replacement in transformations:
+                new_key = pattern.sub(replacement, new_key)
+            transformed_dict[new_key] = value
+
+        # Handle tied weights: lm_head.weight and embed_tokens.weight share memory
+        lm_head_key = None
+        embed_tokens_key = None
+
+        for key in transformed_dict:
+            if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
+                lm_head_key = key
+            elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
+                embed_tokens_key = key
+            if lm_head_key and embed_tokens_key:
+                break
+
+        if lm_head_key and not embed_tokens_key:
+            embed_tokens_key = lm_head_key.replace(
+                ".lm_head.weight", ".model.language_model.embed_tokens.weight"
+            )
+            transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
+        elif embed_tokens_key and not lm_head_key:
+            lm_head_key = embed_tokens_key.replace(
+                ".model.language_model.embed_tokens.weight", ".lm_head.weight"
+            )
+            transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
+
+        return transformed_dict
+
+    @classmethod
+    def _load_as_safetensor(
+        cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
+    ) -> "PI0Policy":
+        """Override to apply key transformations before loading."""
+        from safetensors.torch import load_file
+
+        init_logging()
+        # Load the state dict from the file safely
+        state_dict = load_file(model_file, device=map_location)
+
+        # Apply the key transformations
+        transformed_state_dict = cls._transform_state_dict_keys(state_dict)
+
+        # Load the transformed state dict
+        msg = model.load_state_dict(transformed_state_dict, strict=strict)
+
+        # Log any missing or unexpected keys
+        log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
+        return model
+
     def get_optim_params(self) -> dict:
         return self.parameters()
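For intuition, a minimal sketch of what _transform_state_dict_keys does to an old-style checkpoint; the keys and placeholder tensors below are illustrative only, not part of the patch:

    import torch

    from lerobot.policies.pi0.modeling_pi0 import PI0Policy

    old_style = {
        "model.paligemma_with_expert.paligemma.language_model.lm_head.weight": torch.zeros(1),
        "model.paligemma_with_expert.paligemma.vision_tower.conv.weight": torch.zeros(1),
    }
    new_style = PI0Policy._transform_state_dict_keys(old_style)
    # Remapped keys:
    #   model.paligemma_with_expert.paligemma.lm_head.weight
    #   model.paligemma_with_expert.paligemma.model.vision_tower.conv.weight
    # Tied-weight entry synthesized from lm_head.weight:
    #   model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight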
diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py
index d745c901..2f69309c 100644
--- a/src/lerobot/policies/pretrained.py
+++ b/src/lerobot/policies/pretrained.py
@@ -30,6 +30,7 @@ from torch import Tensor, nn
 
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.train import TrainPipelineConfig
+from lerobot.policies.utils import log_model_loading_keys
 from lerobot.utils.hub import HubMixin
 
 T = TypeVar("T", bound="PreTrainedPolicy")
@@ -128,18 +129,26 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
 
     @classmethod
     def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
-        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
-            load_model_as_safetensor(model, model_file, strict=strict)
-            if map_location != "cpu":
-                logging.warning(
-                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
-                    " This means that the model is loaded on 'cpu' first and then copied to the device."
-                    " This leads to a slower loading time."
-                    " Please update safetensors to version 0.4.3 or above for improved performance."
-                )
-                model.to(map_location)
-        else:
-            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
+        # Create base kwargs
+        kwargs = {"strict": strict}
+
+        # Add device parameter for newer versions that support it
+        if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
+            kwargs["device"] = map_location
+
+        # Load the model with appropriate kwargs
+        missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
+        log_model_loading_keys(missing_keys, unexpected_keys)
+
+        # For older versions, manually move to device if needed
+        if "device" not in kwargs and map_location != "cpu":
+            logging.warning(
+                "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
+                " This means that the model is loaded on 'cpu' first and then copied to the device."
+                " This leads to a slower loading time."
+                " Please update safetensors to version 0.4.3 or above for improved performance."
+            )
+            model.to(map_location)
         return model
 
     @abc.abstractmethod
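The version gate above reduces to two call shapes; a sketch, assuming load_model_as_safetensor is the safetensors.torch.load_model alias already imported in this module (so it returns the missing/unexpected key lists) and with an illustrative device string:

    # safetensors >= 0.4.3: weights are materialized directly on the target device
    missing_keys, unexpected_keys = load_model_as_safetensor(
        model, model_file, strict=False, device="cuda:0"
    )

    # safetensors < 0.4.3: no device kwarg, so load on CPU first, then move the module
    missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, strict=False)
    model.to("cuda:0")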
diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py
index 5659e872..5a3994cd 100644
--- a/src/lerobot/policies/utils.py
+++ b/src/lerobot/policies/utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from collections import deque
 
 import torch
@@ -71,3 +72,16 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
     with torch.inference_mode():
         output = module(dummy_input)
     return tuple(output.shape)
+
+
+def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None:
+    """Log missing and unexpected keys when loading a model.
+
+    Args:
+        missing_keys (list[str]): Keys that were expected but not found.
+        unexpected_keys (list[str]): Keys that were found but not expected.
+    """
+    if missing_keys:
+        logging.warning(f"Missing key(s) when loading model: {missing_keys}")
+    if unexpected_keys:
+        logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
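A quick way to exercise the new helper (the key name is hypothetical; assumes the root logger emits at WARNING level):

    import logging

    from lerobot.policies.utils import log_model_loading_keys

    logging.basicConfig(level=logging.WARNING)
    log_model_loading_keys(
        missing_keys=["model.some_layer.weight"],  # hypothetical key
        unexpected_keys=[],
    )
    # -> WARNING:root:Missing key(s) when loading model: ['model.some_layer.weight']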