Co-authored-by: Remi <remi.cadene@huggingface.co> Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
183 lines
7.1 KiB
Python
183 lines
7.1 KiB
Python
import abc
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Type, TypeVar
|
|
|
|
import packaging
|
|
import safetensors
|
|
from huggingface_hub import hf_hub_download
|
|
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
|
from huggingface_hub.errors import HfHubHTTPError
|
|
from safetensors.torch import load_model as load_model_as_safetensor
|
|
from safetensors.torch import save_model as save_model_as_safetensor
|
|
from torch import Tensor, nn
|
|
|
|
from lerobot.common.utils.hub import HubMixin
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
|
|
# Type variable bound to PreTrainedPolicy so that classmethods such as `from_pretrained` can be typed as
# returning the concrete subclass they are invoked on.
T = TypeVar("T", bound="PreTrainedPolicy")

# Model-card template (Jinja syntax: `{{ ... }}`) describing a policy pushed to the Hub.
# The YAML metadata (`{{ card_data }}`) must sit between the two `---` fences at the very top of the card
# (leading whitespace aside); any other text before the first fence prevents it from being read as card
# metadata, so keep this header exactly in this shape.
DEFAULT_POLICY_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---

This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot):
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""
|
|
|
|
|
|
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
    """
    Base class for policy models.

    Every subclass must define the class attributes ``config_class`` and ``name`` (checked in
    ``__init_subclass__``) and implement the abstract methods ``get_optim_params``, ``reset``,
    ``forward`` and ``select_action``.
    """

    # Config class associated with this policy; must be set by every subclass.
    config_class: type[PreTrainedConfig]
    # Identifier of this policy; must be set (non-empty) by every subclass.
    name: str

    def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
        """
        Store the policy configuration.

        Args:
            config: Configuration of the policy; must be a `PreTrainedConfig` instance.
            *inputs, **kwargs: Accepted for subclass compatibility; not used by this base initializer.

        Raises:
            ValueError: If `config` is not a `PreTrainedConfig` instance.
        """
        super().__init__()
        if not isinstance(config, PreTrainedConfig):
            raise ValueError(
                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
                "`PreTrainedConfig`. To create a model from a pretrained model use "
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.config = config

    def __init_subclass__(cls, **kwargs):
        """
        Reject subclasses that do not define a truthy ``config_class`` and ``name``.

        Raises:
            TypeError: If either class attribute is missing or falsy.
        """
        super().__init_subclass__(**kwargs)
        if not getattr(cls, "config_class", None):
            raise TypeError(f"Class {cls.__name__} must define 'config_class'")
        if not getattr(cls, "name", None):
            raise TypeError(f"Class {cls.__name__} must define 'name'")

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Write the policy config and its weights (single safetensors file) into `save_directory`.

        If the instance exposes a `module` attribute, the weights of that object are saved instead of
        `self`'s (presumably to unwrap a wrapper model — TODO confirm which wrappers set it).
        """
        self.config._save_pretrained(save_directory)
        model_to_save = self.module if hasattr(self, "module") else self
        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: PreTrainedConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        map_location: str = "cpu",
        strict: bool = False,
        **kwargs,
    ) -> T:
        """
        Instantiate a policy and load its weights from a local directory or a Hugging Face Hub repo.

        The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are
        deactivated). To train it, you should first set it back in training mode with `policy.train()`.

        Args:
            pretrained_name_or_path: Local directory containing the safetensors weights file, or a Hub
                repo id.
            config: Policy config to use. When None, it is loaded with `PreTrainedConfig.from_pretrained`
                from the same location.
            force_download, resume_download, proxies, token, cache_dir, local_files_only, revision:
                Forwarded to `PreTrainedConfig.from_pretrained` (only when `config` is None) and, for a
                non-local path, to `hf_hub_download`.
            map_location: Device the weights are loaded onto; the returned policy is also moved there.
            strict: Forwarded to safetensors' `load_model`.
            **kwargs: Forwarded to `PreTrainedConfig.from_pretrained` (when `config` is None) and to the
                policy constructor.

        Returns:
            The policy with its weights loaded, on `map_location`, in eval mode.

        Raises:
            FileNotFoundError: If downloading the weights file from the Hub fails with an HTTP error.
        """
        if config is None:
            config = PreTrainedConfig.from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )
        model_id = str(pretrained_name_or_path)
        instance = cls(config, **kwargs)

        if os.path.isdir(model_id):
            # Library code should not print to stdout; route through logging instead.
            logging.info("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
        else:
            # Only the download can raise HfHubHTTPError, so keep the `try` limited to it.
            try:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
                ) from e

        policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
        policy.to(map_location)
        policy.eval()
        return policy

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        """
        Load the safetensors weights in `model_file` into `model` in place and return `model`.

        With safetensors >= 0.4.3 the weights are loaded directly onto `map_location`. Older versions
        load them on CPU first, then move the model to `map_location` (with a warning when it is not
        "cpu").
        """
        # `import packaging` alone does not guarantee that the `packaging.version` submodule is loaded,
        # so import it explicitly to avoid an AttributeError on the version comparison below.
        from packaging import version as pkg_version

        if pkg_version.parse(safetensors.__version__) < pkg_version.parse("0.4.3"):
            load_model_as_safetensor(model, model_file, strict=strict)
            if map_location != "cpu":
                logging.warning(
                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
                    " This means that the model is loaded on 'cpu' first and then copied to the device."
                    " This leads to a slower loading time."
                    " Please update safetensors to version 0.4.3 or above for improved performance."
                )
                model.to(map_location)
        else:
            load_model_as_safetensor(model, model_file, strict=strict, device=map_location)
        return model

    # def generate_model_card(self, *args, **kwargs) -> ModelCard:
    #     card = ModelCard.from_template(
    #         card_data=self._hub_mixin_info.model_card_data,
    #         template_str=self._hub_mixin_info.model_card_template,
    #         repo_url=self._hub_mixin_info.repo_url,
    #         docs_url=self._hub_mixin_info.docs_url,
    #         **kwargs,
    #     )
    #     return card

    @abc.abstractmethod
    def get_optim_params(self) -> dict:
        """
        Returns the policy-specific parameters dict to be passed on to the optimizer.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def reset(self):
        """To be called whenever the environment is reset.

        Does things like clearing caches.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, batch: dict[str, Tensor]) -> dict:
        """Run the batch through the model and compute the loss for training or validation.

        Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
        other items should be logging-friendly, native Python types.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Return one action to run in the environment (potentially in batch mode).

        When the model uses a history of observations, or outputs a sequence of actions, this method deals
        with caching.
        """
        raise NotImplementedError