Add regression tests (#119)

- Add `tests/scripts/save_policy_to_safetensor.py` to generate test artifacts
- Add `test_backward_compatibility` to test generated outputs from the policies against artifacts
This commit is contained in:
Simon Alibert
2024-05-04 16:20:30 +02:00
committed by GitHub
parent 19812ca470
commit c77633c38c
15 changed files with 236 additions and 43 deletions

View File

@@ -1,8 +1,10 @@
import inspect
from pathlib import Path
import pytest
import torch
from huggingface_hub import PyTorchModelHubMixin
from safetensors.torch import load_file
from lerobot import available_policies
from lerobot.common.datasets.factory import make_dataset
@@ -13,7 +15,8 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
from tests.scripts.save_policy_to_safetensor import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel
@pytest.mark.parametrize("policy_name", available_policies)
@@ -228,3 +231,37 @@ def test_normalize(insert_temporal_dim):
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
new_unnormalize.load_state_dict(unnormalize.state_dict())
unnormalize(output_batch)
def _assert_matches_saved(computed, saved, *, rtol, atol, group):
    """Check freshly computed tensors against a saved artifact, key by key.

    Only the keys present in ``saved`` are compared. The comparison is
    ``torch.isclose`` reduced with ``.all()``, so a single out-of-tolerance
    element fails the assertion.

    Args:
        computed: Mapping of freshly computed tensors, indexed by the saved keys.
        saved: Mapping of reference tensors loaded from a ``.safetensors`` file.
        rtol: Relative tolerance forwarded to ``torch.isclose``.
        atol: Absolute tolerance forwarded to ``torch.isclose``.
        group: Human-readable artifact name, used only in the failure message.

    Raises:
        AssertionError: If any compared tensor is outside the tolerances.
        KeyError: If ``computed`` lacks a key that ``saved`` contains.
    """
    for key, saved_tensor in saved.items():
        # The message names the artifact group and key so that a failure can be
        # traced without re-running the test under a debugger.
        assert torch.isclose(computed[key], saved_tensor, rtol=rtol, atol=atol).all(), (
            f"{group}['{key}'] differs from the saved artifact (rtol={rtol}, atol={atol})"
        )


@pytest.mark.parametrize(
    "env_name, policy_name, extra_overrides",
    [
        # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
        (
            "pusht",
            "diffusion",
            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
        ),
        ("aloha", "act", ["policy.n_action_steps=10"]),
    ],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
def test_backward_compatibility(env_name, policy_name, extra_overrides):
    """Compare the values returned by ``get_policy_stats`` with saved artifacts.

    The reference tensors are read from
    ``tests/data/save_policy_to_safetensors/<env_name>_<policy_name>/``; the
    path is relative, so it is resolved against the current working directory.

    Args:
        env_name: Environment name; first half of the artifact directory name.
        policy_name: Policy name; second half of the artifact directory name.
        extra_overrides: List of override strings passed through to
            ``get_policy_stats``.
    """
    env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
    saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
    saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
    saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
    saved_actions = load_file(env_policy_dir / "actions.safetensors")

    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)

    _assert_matches_saved(output_dict, saved_output_dict, rtol=0.1, atol=1e-7, group="output_dict")
    _assert_matches_saved(grad_stats, saved_grad_stats, rtol=0.1, atol=1e-7, group="grad_stats")
    # NOTE(review): param stats use a far looser relative tolerance (rtol=50)
    # than the other groups; the reason is not visible here - confirm it is intended.
    _assert_matches_saved(param_stats, saved_param_stats, rtol=50, atol=1e-7, group="param_stats")
    _assert_matches_saved(actions, saved_actions, rtol=0.1, atol=1e-7, group="actions")