Feat: Improve hub integration (#1382)
* feat(policies): Initial setup to push policies to hub with tags and model card * feat: add dataset that is used to train * Add model template summary * fix: Update link model_card template * fix: remove print * fix: change import name * fix: add model summary in template * fix: minor text * fix: comments Lucain * fix: feedback steven * fix: restructure push to hub * fix: remove unneeded changes * fix: import * fix: import 2 * Add MANIFEST.in * fix: feedback pr * Fix tests * tests: Add smolvla end-to-end test * Fix: smolvla test * fix test name * fix policy tests * Add push to hub false policy tests * Do push to hub cleaner * fix(ci): add push_to_hub false in tests --------- Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -32,7 +32,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
train_cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||
policy=make_policy_config(policy_name, **policy_kwargs),
|
||||
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
||||
)
|
||||
train_cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
|
||||
@@ -338,8 +338,9 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
|
||||
env=make_env_config(env_name),
|
||||
policy=make_policy_config(policy_name),
|
||||
policy=make_policy_config(policy_name, push_to_hub=False),
|
||||
)
|
||||
cfg.validate()
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
|
||||
@@ -142,9 +142,10 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
train_cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||
policy=make_policy_config(policy_name, **policy_kwargs),
|
||||
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
||||
env=make_env_config(env_name, **env_kwargs),
|
||||
)
|
||||
train_cfg.validate()
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(train_cfg)
|
||||
@@ -213,7 +214,7 @@ def test_act_backbone_lr():
|
||||
cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
|
||||
)
|
||||
cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
@@ -415,6 +416,7 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
https://github.com/huggingface/lerobot/pull/1127.
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
|
||||
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
|
||||
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
|
||||
|
||||
Reference in New Issue
Block a user