Feat: Improve hub integration (#1382)

* feat(policies): Initial setup to push policies to hub with tags and model card

* feat: add dataset that is used to train

* Add model template summary

* fix: Update link model_card template

* fix: remove print

* fix: change import name

* fix: add model summary in template

* fix: minor text

* fix: address Lucain's comments

* fix: address Steven's feedback

* fix: restructure push to hub

* fix: remove unneeded changes

* fix: import

* fix: import 2

* Add MANIFEST.in

* fix: feedback pr

* Fix tests

* tests: Add smolvla end-to-end test

* Fix: smolvla test

* fix test name

* fix policy tests

* Add push to hub false policy tests

* Make push to hub cleaner

* fix(ci): add push_to_hub false in tests

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
Pepijn
2025-06-26 14:36:16 +02:00
committed by GitHub
parent a989c79558
commit 0b2285d1ec
13 changed files with 206 additions and 101 deletions

View File

@@ -14,12 +14,14 @@
import abc
import logging
import os
from importlib.resources import files
from pathlib import Path
from typing import Type, TypeVar
from tempfile import TemporaryDirectory
from typing import List, Type, TypeVar
import packaging
import safetensors
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor
@@ -28,20 +30,10 @@ from torch import Tensor, nn
from lerobot.common.utils.hub import HubMixin
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
# Type variable bound (by forward-reference string) to PreTrainedPolicy.
T = TypeVar("T", bound="PreTrainedPolicy")
# Model-card template text with a front-matter section delimited by `---`.
# It exposes two template placeholders: `card_data` (rendered inside the
# front matter) and `docs_url` (falls back to "[More Information Needed]").
# NOTE(review): no live code in this view references this constant — confirm
# whether it is still needed.
DEFAULT_POLICY_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---
This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot):
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
@@ -150,16 +142,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return model
# def generate_model_card(self, *args, **kwargs) -> ModelCard:
# card = ModelCard.from_template(
# card_data=self._hub_mixin_info.model_card_data,
# template_str=self._hub_mixin_info.model_card_template,
# repo_url=self._hub_mixin_info.repo_url,
# docs_url=self._hub_mixin_info.docs_url,
# **kwargs,
# )
# return card
@abc.abstractmethod
def get_optim_params(self) -> dict:
"""
@@ -197,3 +179,56 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
with caching.
"""
raise NotImplementedError
def push_model_to_hub(
    self,
    cfg: TrainPipelineConfig,
):
    """Upload this policy, its model card and the training config to the Hub.

    The target repository is taken from ``self.config.repo_id`` and is created
    (with ``exist_ok=True``) using the visibility in ``self.config.private``.
    Everything is first written into a temporary directory and then uploaded
    with a single ``upload_folder`` call.

    Args:
        cfg: Training pipeline configuration. ``cfg.dataset.repo_id`` is passed
            to the model card, and ``cfg.save_pretrained`` writes the train
            config next to the weights.
    """
    hub = HfApi()
    created_repo = hub.create_repo(
        repo_id=self.config.repo_id,
        private=self.config.private,
        exist_ok=True,
    )
    repo_id = created_repo.repo_id

    # Stage every artifact locally so the push is a single commit.
    with TemporaryDirectory(ignore_cleanup_errors=True) as staging_root:
        staging_dir = Path(staging_root) / repo_id

        # Calls _save_pretrained and stores model tensors.
        self.save_pretrained(staging_dir)

        model_card = self.generate_model_card(
            cfg.dataset.repo_id,
            self.config.type,
            self.config.license,
            self.config.tags,
        )
        readme_path = staging_dir / "README.md"
        model_card.save(str(readme_path))

        # Calls _save_pretrained and stores train config.
        cfg.save_pretrained(staging_dir)

        upload_options = {
            "repo_id": repo_id,
            "repo_type": "model",
            "folder_path": staging_dir,
            "commit_message": "Upload policy weights, train config and readme",
            "allow_patterns": ["*.safetensors", "*.json", "*.yaml", "*.md"],
            "ignore_patterns": ["*.tmp", "*.log"],
        }
        commit_info = hub.upload_folder(**upload_options)

        logging.info(f"Model pushed to {commit_info.repo_url.url}")
def generate_model_card(
    self, dataset_repo_id: str, model_type: str, license: str | None, tags: List[str] | None
) -> ModelCard:
    """Build and validate the Hub model card for this policy.

    Args:
        dataset_repo_id: Repo id of the dataset the policy was trained on;
            stored in the card's ``datasets`` metadata.
        model_type: Policy type name; used as the card's ``model_name`` and
            added to the tags.
        license: License identifier for the card. Falls back to
            ``"apache-2.0"`` when ``None`` or empty.
        tags: Extra tags to attach. ``"robotics"``, ``"lerobot"`` and
            ``model_type`` are always added; duplicates are removed.

    Returns:
        The ``ModelCard`` rendered from the ``lerobot_modelcard_template.md``
        resource of the ``lerobot.templates`` package, after ``validate()``
        has been called on it.
    """
    # Only smolvla declares a base model; every other type leaves it unset.
    base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None

    # Deduplicate, then sort: plain set iteration order depends on string hash
    # randomization, which would make the generated README metadata differ
    # from run to run for identical inputs.
    merged_tags = sorted({*(tags or []), "robotics", "lerobot", model_type})

    card_data = ModelCardData(
        license=license or "apache-2.0",
        library_name="lerobot",
        pipeline_tag="robotics",
        tags=merged_tags,
        model_name=model_type,
        datasets=dataset_repo_id,
        base_model=base_model,
    )

    # Read the template with an explicit encoding so the result does not
    # depend on the platform's default locale encoding.
    template_card = (
        files("lerobot.templates")
        .joinpath("lerobot_modelcard_template.md")
        .read_text(encoding="utf-8")
    )
    card = ModelCard.from_template(card_data, template_str=template_card)
    # NOTE(review): per huggingface_hub docs, validate() checks the card
    # against the Hub and raises on rejection — confirm network access is
    # acceptable for all callers.
    card.validate()
    return card