forked from tangger/lerobot
Fix pi0 checkpoint state map (#1415)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -66,7 +66,8 @@ from lerobot.policies.pi0.paligemma_with_expert import (
|
|||||||
PaliGemmaWithExpertModel,
|
PaliGemmaWithExpertModel,
|
||||||
)
|
)
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.utils.utils import get_safe_dtype
|
from lerobot.policies.utils import log_model_loading_keys
|
||||||
|
from lerobot.utils.utils import get_safe_dtype, init_logging
|
||||||
|
|
||||||
|
|
||||||
def create_sinusoidal_pos_embedding(
|
def create_sinusoidal_pos_embedding(
|
||||||
@@ -252,6 +253,99 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
"""This should be called whenever the environment is reset."""
|
"""This should be called whenever the environment is reset."""
|
||||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Transform state dict keys to match expected model structure.
|
||||||
|
|
||||||
|
Transformations:
|
||||||
|
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
|
||||||
|
model.paligemma_with_expert.paligemma.lm_head
|
||||||
|
- model.paligemma_with_expert.paligemma.language_model.model ->
|
||||||
|
model.paligemma_with_expert.paligemma.model.language_model
|
||||||
|
- model.paligemma_with_expert.paligemma.vision_tower ->
|
||||||
|
model.paligemma_with_expert.paligemma.model.vision_tower
|
||||||
|
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
|
||||||
|
model.paligemma_with_expert.paligemma.model.multi_modal_projector
|
||||||
|
|
||||||
|
Also handles tied weights between lm_head.weight and
|
||||||
|
embed_tokens.weight.
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
transformed_dict = {}
|
||||||
|
|
||||||
|
transformations = [
|
||||||
|
(
|
||||||
|
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
|
||||||
|
".paligemma_with_expert.paligemma.lm_head",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
|
||||||
|
".paligemma_with_expert.paligemma.model.language_model",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
|
||||||
|
".paligemma_with_expert.paligemma.model.vision_tower",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
|
||||||
|
".paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
new_key = key
|
||||||
|
for pattern, replacement in transformations:
|
||||||
|
new_key = pattern.sub(replacement, new_key)
|
||||||
|
transformed_dict[new_key] = value
|
||||||
|
|
||||||
|
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
|
||||||
|
lm_head_key = None
|
||||||
|
embed_tokens_key = None
|
||||||
|
|
||||||
|
for key in transformed_dict:
|
||||||
|
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
|
||||||
|
lm_head_key = key
|
||||||
|
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
|
||||||
|
embed_tokens_key = key
|
||||||
|
if lm_head_key and embed_tokens_key:
|
||||||
|
break
|
||||||
|
|
||||||
|
if lm_head_key and not embed_tokens_key:
|
||||||
|
embed_tokens_key = lm_head_key.replace(
|
||||||
|
".lm_head.weight", ".model.language_model.embed_tokens.weight"
|
||||||
|
)
|
||||||
|
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
|
||||||
|
elif embed_tokens_key and not lm_head_key:
|
||||||
|
lm_head_key = embed_tokens_key.replace(
|
||||||
|
".model.language_model.embed_tokens.weight", ".lm_head.weight"
|
||||||
|
)
|
||||||
|
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
|
||||||
|
|
||||||
|
return transformed_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_as_safetensor(
|
||||||
|
cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
|
||||||
|
) -> "PI0Policy":
|
||||||
|
"""Override to apply key transformations before loading."""
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
init_logging()
|
||||||
|
# Load the state dict from file safely
|
||||||
|
state_dict = load_file(model_file, device=map_location)
|
||||||
|
|
||||||
|
# Apply key transformations
|
||||||
|
transformed_state_dict = cls._transform_state_dict_keys(state_dict)
|
||||||
|
|
||||||
|
# Load the transformed state dict
|
||||||
|
msg = model.load_state_dict(transformed_state_dict, strict=strict)
|
||||||
|
|
||||||
|
# Log message
|
||||||
|
log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
|
||||||
|
return model
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
|
from lerobot.policies.utils import log_model_loading_keys
|
||||||
from lerobot.utils.hub import HubMixin
|
from lerobot.utils.hub import HubMixin
|
||||||
|
|
||||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||||
@@ -128,18 +129,26 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||||
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
|
# Create base kwargs
|
||||||
load_model_as_safetensor(model, model_file, strict=strict)
|
kwargs = {"strict": strict}
|
||||||
if map_location != "cpu":
|
|
||||||
logging.warning(
|
# Add device parameter for newer versions that support it
|
||||||
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
|
if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
|
||||||
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
kwargs["device"] = map_location
|
||||||
" This leads to a slower loading time."
|
|
||||||
" Please update safetensors to version 0.4.3 or above for improved performance."
|
# Load the model with appropriate kwargs
|
||||||
)
|
missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
|
||||||
model.to(map_location)
|
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||||
else:
|
|
||||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
# For older versions, manually move to device if needed
|
||||||
|
if "device" not in kwargs and map_location != "cpu":
|
||||||
|
logging.warning(
|
||||||
|
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
|
||||||
|
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
||||||
|
" This leads to a slower loading time."
|
||||||
|
" Please update safetensors to version 0.4.3 or above for improved performance."
|
||||||
|
)
|
||||||
|
model.to(map_location)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -71,3 +72,16 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
|
|||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
output = module(dummy_input)
|
output = module(dummy_input)
|
||||||
return tuple(output.shape)
|
return tuple(output.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None:
|
||||||
|
"""Log missing and unexpected keys when loading a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
missing_keys (list[str]): Keys that were expected but not found.
|
||||||
|
unexpected_keys (list[str]): Keys that were found but not expected.
|
||||||
|
"""
|
||||||
|
if missing_keys:
|
||||||
|
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
|
||||||
|
if unexpected_keys:
|
||||||
|
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
|
||||||
|
|||||||
Reference in New Issue
Block a user