Use PyTorchModelHubMixin to save models as safetensors (#125)
Co-authored-by: Remi <re.cadene@gmail.com>
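For context on the pattern this commit adopts: huggingface_hub's PyTorchModelHubMixin lets a torch.nn.Module gain save_pretrained/from_pretrained, writing weights to model.safetensors. A minimal sketch follows; ToyPolicy, hidden_dim, and the output path are illustrative, not lerobot code.

    # Minimal sketch of the PyTorchModelHubMixin pattern (illustrative names).
    import torch
    from huggingface_hub import PyTorchModelHubMixin


    class ToyPolicy(torch.nn.Module, PyTorchModelHubMixin):
        def __init__(self, hidden_dim: int = 32):
            super().__init__()
            self.net = torch.nn.Linear(hidden_dim, hidden_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)


    policy = ToyPolicy(hidden_dim=64)
    # save_pretrained() writes model.safetensors; depending on the
    # huggingface_hub version, init kwargs may also land in config.json.
    policy.save_pretrained("outputs/toy_policy")
    reloaded = ToyPolicy.from_pretrained("outputs/toy_policy")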
tests/test_examples.py

@@ -46,7 +46,7 @@ def test_examples_3_and_2():
     # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
     exec(file_contents, {})
 
-    for file_name in ["model.pt", "config.yaml"]:
+    for file_name in ["model.safetensors", "config.json", "config.yaml"]:
         assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
 
     path = "examples/2_evaluate_pretrained_policy.py"
@@ -58,15 +58,15 @@ def test_examples_3_and_2():
     file_contents = _find_and_replace(
         file_contents,
         [
+            ('pretrained_policy_name = "lerobot/diffusion_policy_pusht_image"', ""),
+            ("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""),
+            (
+                '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
+                'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
+            ),
             ('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
             ('"eval.batch_size=10"', '"eval.batch_size=1"'),
             ('"device=cuda"', '"device=cpu"'),
-            (
-                '# folder = Path("outputs/train/example_pusht_diffusion")',
-                'folder = Path("outputs/train/example_pusht_diffusion")',
-            ),
-            ('hub_id = "lerobot/diffusion_policy_pusht_image"', ""),
-            ("folder = Path(snapshot_download(hub_id)", ""),
         ],
     )
 
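The _find_and_replace helper used above is defined elsewhere in tests/test_examples.py and is not shown in this diff. Judging from the call sites, it presumably looks something like this sketch (the actual implementation may differ):

    # Plausible sketch of the helper, inferred from its call sites above.
    def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
        for find, replace in finds_and_replaces:
            assert find in text, f"{find!r} not found in text"
            text = text.replace(find, replace)
        return text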
tests/test_policies.py

@@ -1,19 +1,33 @@
 import inspect
 
 import pytest
 import torch
+from huggingface_hub import PyTorchModelHubMixin
 
 from lerobot import available_policies
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
+from lerobot.common.policies.act.modeling_act import ACTPolicy
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
 
 
+@pytest.mark.parametrize("policy_name", available_policies)
+def test_get_policy_and_config_classes(policy_name: str):
+    """Check that the correct policy and config classes are returned."""
+    if policy_name == "tdmpc":
+        with pytest.raises(NotImplementedError):
+            get_policy_and_config_classes(policy_name)
+        return
+    policy_cls, config_cls = get_policy_and_config_classes(policy_name)
+    assert policy_cls.name == policy_name
+    assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
+
+
 # TODO(aliberts): refactor using lerobot/__init__.py variables
 @pytest.mark.parametrize(
     "env_name,policy_name,extra_overrides",
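The new test_get_policy_and_config_classes exercises a factory that maps a policy name to its policy and config classes. The real implementation lives in lerobot/common/policies/factory.py; a hypothetical sketch consistent with the test's expectations (including the NotImplementedError for "tdmpc") might look like this:

    # Hypothetical sketch only; module paths mirror the imports above, but the
    # real factory in lerobot/common/policies/factory.py may differ.
    def get_policy_and_config_classes(name: str) -> tuple[type, type]:
        if name == "act":
            from lerobot.common.policies.act.configuration_act import ACTConfig
            from lerobot.common.policies.act.modeling_act import ACTPolicy

            return ACTPolicy, ACTConfig
        if name == "diffusion":
            from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
            from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

            return DiffusionPolicy, DiffusionConfig
        # The tests expect "tdmpc" (and any other unhandled name) to raise.
        raise NotImplementedError(f"Policy '{name}' is not available through this factory.")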
@@ -44,7 +58,8 @@ def test_policy(env_name, policy_name, extra_overrides):
     """
     Tests:
         - Making the policy object.
-        - Checking that the policy follows the correct protocol.
+        - Checking that the policy follows the correct protocol and subclasses nn.Module
+            and PyTorchModelHubMixin.
         - Updating the policy.
         - Using the policy to select actions at inference time.
         - Test the action can be applied to the policy
@@ -61,11 +76,13 @@ def test_policy(env_name, policy_name, extra_overrides):
 
     # Check that we can make the policy object.
     dataset = make_dataset(cfg)
-    policy = make_policy(cfg, dataset_stats=dataset.stats)
+    policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
     # Check that the policy follows the required protocol.
     assert isinstance(
         policy, Policy
     ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
+    assert isinstance(policy, torch.nn.Module)
+    assert isinstance(policy, PyTorchModelHubMixin)
 
     # Check that we run select_actions and get the appropriate output.
     env = make_env(cfg, num_parallel_envs=2)
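The isinstance(policy, Policy) assertion above only works because Policy is a typing.Protocol decorated with @runtime_checkable. A minimal sketch of that arrangement; the member names are assumptions based on how the tests use policies, and the real protocol is lerobot.common.policies.policy_protocol.Policy:

    # Minimal sketch of a runtime-checkable protocol (illustrative members).
    from typing import Protocol, runtime_checkable


    @runtime_checkable
    class Policy(Protocol):
        name: str

        def reset(self):
            """Called when the environment is reset."""
            ...

        def select_action(self, batch):
            """Return an action given environment observations."""
            ...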
@@ -108,29 +125,33 @@ def test_policy(env_name, policy_name, extra_overrides):
     # Test step through policy
     env.step(action)
 
     # Test load state_dict
     if policy_name != "tdmpc":
         # TODO(rcadene, alexander-soare): make it work for tdmpc
         new_policy = make_policy(cfg)
         new_policy.load_state_dict(policy.state_dict())
 
 
+@pytest.mark.parametrize("policy_name", available_policies)
+def test_policy_defaults(policy_name: str):
+    """Check that the policy can be instantiated with defaults."""
+    if policy_name == "tdmpc":
+        with pytest.raises(NotImplementedError):
+            get_policy_and_config_classes(policy_name)
+        return
+    policy_cls, _ = get_policy_and_config_classes(policy_name)
+    policy_cls()
+
+
-@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ACTPolicy])
-def test_policy_defaults(policy_cls):
-    kwargs = {}
-    # TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
-    if policy_cls is DiffusionPolicy:
-        kwargs = {"lr_scheduler_num_training_steps": 1}
-    policy_cls(**kwargs)
+@pytest.mark.parametrize("policy_name", available_policies)
+def test_save_and_load_pretrained(policy_name: str):
+    if policy_name == "tdmpc":
+        with pytest.raises(NotImplementedError):
+            get_policy_and_config_classes(policy_name)
+        return
+    policy_cls, _ = get_policy_and_config_classes(policy_name)
+    policy: Policy = policy_cls()
+    save_dir = f"/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
+    policy.save_pretrained(save_dir)
+    policy_ = policy_cls.from_pretrained(save_dir)
+    assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
 
 
-@pytest.mark.parametrize(
-    "insert_temporal_dim",
-    [
-        False,
-        True,
-    ],
-)
+@pytest.mark.parametrize("insert_temporal_dim", [False, True])
 def test_normalize(insert_temporal_dim):
     """
     Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
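As a follow-up to the round-trip test above, a saved checkpoint can also be inspected directly with the safetensors library. The path below is an assumption derived from the test's save_dir for ACTPolicy:

    # Optional, illustrative check: load the raw tensors written by save_pretrained().
    from safetensors.torch import load_file

    state = load_file("/tmp/test_save_and_load_pretrained_ACTPolicy/model.safetensors")
    print(sorted(state.keys()))  # parameter names as stored in the checkpoint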