Use PyTorchModelHubMixin to save models as safetensors (#125)

Co-authored-by: Remi <re.cadene@gmail.com>
Author: Alexander Soare
Date: 2024-05-01 16:17:18 +01:00
Committed by: GitHub
Parent: 01d5490d44
Commit: a4891095e4
18 changed files with 556 additions and 527 deletions
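
For context, this commit adopts the pattern provided by huggingface_hub's PyTorchModelHubMixin: a policy that subclasses both torch.nn.Module and the mixin inherits save_pretrained/from_pretrained, and recent huggingface_hub releases serialize the weights to model.safetensors with the constructor arguments captured in config.json. The snippet below is a minimal sketch of that pattern under those assumptions; ToyPolicy is a hypothetical stand-in, not lerobot's actual policy code.

import torch
from huggingface_hub import PyTorchModelHubMixin


class ToyPolicy(torch.nn.Module, PyTorchModelHubMixin):
    # Hypothetical stand-in for a policy class; not part of lerobot.
    def __init__(self, hidden_dim: int = 32):
        super().__init__()
        self.net = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


policy = ToyPolicy(hidden_dim=64)
# Writes model.safetensors (weights) and config.json (init kwargs) to the folder,
# assuming a huggingface_hub version that defaults to safetensors serialization.
policy.save_pretrained("outputs/toy_policy")
# Rebuilds the policy from config.json and loads the safetensors weights.
reloaded = ToyPolicy.from_pretrained("outputs/toy_policy")

The diff below updates the examples and policy tests accordingly: trained checkpoints are now expected to contain model.safetensors and config.json rather than model.pt.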

@@ -46,7 +46,7 @@ def test_examples_3_and_2():
     # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
     exec(file_contents, {})
-    for file_name in ["model.pt", "config.yaml"]:
+    for file_name in ["model.safetensors", "config.json", "config.yaml"]:
         assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
     path = "examples/2_evaluate_pretrained_policy.py"
@@ -58,15 +58,15 @@ def test_examples_3_and_2():
     file_contents = _find_and_replace(
         file_contents,
         [
+            ('pretrained_policy_name = "lerobot/diffusion_policy_pusht_image"', ""),
+            ("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""),
+            (
+                '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
+                'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
+            ),
             ('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
             ('"eval.batch_size=10"', '"eval.batch_size=1"'),
             ('"device=cuda"', '"device=cpu"'),
-            (
-                '# folder = Path("outputs/train/example_pusht_diffusion")',
-                'folder = Path("outputs/train/example_pusht_diffusion")',
-            ),
-            ('hub_id = "lerobot/diffusion_policy_pusht_image"', ""),
-            ("folder = Path(snapshot_download(hub_id)", ""),
         ],
     )

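The remaining hunks update the policy tests. The replacements above mirror the change to the evaluation example: instead of downloading a raw checkpoint folder and loading a torch state dict, a policy is restored through the mixin's from_pretrained, either from the Hub or from a local training output. A rough sketch of that loading flow, assuming the Hub checkpoint has been re-exported in the new safetensors format:

from pathlib import Path

from huggingface_hub import snapshot_download

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Either fetch a pretrained checkpoint from the Hub...
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_policy_pusht_image"))
# ...or, as the test above switches to, point at a local training output:
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")

# from_pretrained (provided by PyTorchModelHubMixin) reads config.json and model.safetensors.
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()
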
@@ -1,19 +1,33 @@
+import inspect
+
 import pytest
 import torch
+from huggingface_hub import PyTorchModelHubMixin

+from lerobot import available_policies
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.policies.act.modeling_act import ACTPolicy
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env


+@pytest.mark.parametrize("policy_name", available_policies)
+def test_get_policy_and_config_classes(policy_name: str):
+    """Check that the correct policy and config classes are returned."""
+    if policy_name == "tdmpc":
+        with pytest.raises(NotImplementedError):
+            get_policy_and_config_classes(policy_name)
+        return
+    policy_cls, config_cls = get_policy_and_config_classes(policy_name)
+    assert policy_cls.name == policy_name
+    assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)


 # TODO(aliberts): refactor using lerobot/__init__.py variables
 @pytest.mark.parametrize(
     "env_name,policy_name,extra_overrides",
@@ -44,7 +58,8 @@ def test_policy(env_name, policy_name, extra_overrides):
     """
     Tests:
         - Making the policy object.
-        - Checking that the policy follows the correct protocol.
+        - Checking that the policy follows the correct protocol and subclasses nn.Module
+          and PyTorchModelHubMixin.
        - Updating the policy.
        - Using the policy to select actions at inference time.
        - Test the action can be applied to the policy
@@ -61,11 +76,13 @@ def test_policy(env_name, policy_name, extra_overrides):
     # Check that we can make the policy object.
     dataset = make_dataset(cfg)
-    policy = make_policy(cfg, dataset_stats=dataset.stats)
+    policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
     # Check that the policy follows the required protocol.
     assert isinstance(
         policy, Policy
     ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
+    assert isinstance(policy, torch.nn.Module)
+    assert isinstance(policy, PyTorchModelHubMixin)

     # Check that we run select_actions and get the appropriate output.
     env = make_env(cfg, num_parallel_envs=2)
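
A side note on the assertions above: isinstance(policy, Policy) works because lerobot's Policy protocol is runtime-checkable, while the two new checks pin down the concrete base classes that provide state_dict handling and save_pretrained/from_pretrained. The toy sketch below illustrates the same pattern; ToyProtocol and ToyPolicy are hypothetical stand-ins, not lerobot classes.

from typing import Protocol, runtime_checkable

import torch
from huggingface_hub import PyTorchModelHubMixin


@runtime_checkable
class ToyProtocol(Protocol):
    # Structural interface: isinstance() only checks that this method exists.
    def reset(self) -> None: ...


class ToyPolicy(torch.nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(1))

    def reset(self) -> None:
        pass


policy = ToyPolicy()
assert isinstance(policy, ToyProtocol)           # structural (protocol) check
assert isinstance(policy, torch.nn.Module)       # parameters, state_dict, etc.
assert isinstance(policy, PyTorchModelHubMixin)  # save_pretrained / from_pretrained
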
@@ -108,29 +125,33 @@ def test_policy(env_name, policy_name, extra_overrides):
     # Test step through policy
     env.step(action)

-    # Test load state_dict
-    if policy_name != "tdmpc":
-        # TODO(rcadene, alexander-soare): make it work for tdmpc
-        new_policy = make_policy(cfg)
-        new_policy.load_state_dict(policy.state_dict())

+@pytest.mark.parametrize("policy_name", available_policies)
+def test_policy_defaults(policy_name: str):
+    """Check that the policy can be instantiated with defaults."""
+    if policy_name == "tdmpc":
+        with pytest.raises(NotImplementedError):
+            get_policy_and_config_classes(policy_name)
+        return
+    policy_cls, _ = get_policy_and_config_classes(policy_name)
+    policy_cls()

-@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ACTPolicy])
-def test_policy_defaults(policy_cls):
-    kwargs = {}
-    # TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
-    if policy_cls is DiffusionPolicy:
-        kwargs = {"lr_scheduler_num_training_steps": 1}
-    policy_cls(**kwargs)

+@pytest.mark.parametrize("policy_name", available_policies)
+def test_save_and_load_pretrained(policy_name: str):
+    if policy_name == "tdmpc":
+        with pytest.raises(NotImplementedError):
+            get_policy_and_config_classes(policy_name)
+        return
+    policy_cls, _ = get_policy_and_config_classes(policy_name)
+    policy: Policy = policy_cls()
+    save_dir = f"/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
+    policy.save_pretrained(save_dir)
+    policy_ = policy_cls.from_pretrained(save_dir)
+    assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))

-@pytest.mark.parametrize(
-    "insert_temporal_dim",
-    [
-        False,
-        True,
-    ],
-)
+@pytest.mark.parametrize("insert_temporal_dim", [False, True])
 def test_normalize(insert_temporal_dim):
     """
     Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise