Add test to make sure policy dataclass configs match yaml configs (#292)

This commit is contained in:
Alexander Soare
2024-06-26 09:09:40 +01:00
committed by GitHub
parent 7d1542cae1
commit 342f429f1c
5 changed files with 54 additions and 13 deletions

View File

@@ -26,7 +26,11 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
from lerobot.common.policies.factory import (
_policy_cfg_from_hydra_cfg,
get_policy_and_config_classes,
make_policy,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
@@ -210,6 +214,23 @@ def test_policy_defaults(policy_name: str):
policy_cls()
@pytest.mark.parametrize(
"env_name,policy_name",
[
("xarm", "tdmpc"),
("pusht", "diffusion"),
("aloha", "act"),
],
)
def test_yaml_matches_dataclass(env_name: str, policy_name: str):
"""Check that dataclass configs match their respective yaml configs."""
hydra_cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"])
_, policy_cfg_cls = get_policy_and_config_classes(policy_name)
policy_cfg_from_hydra = _policy_cfg_from_hydra_cfg(policy_cfg_cls, hydra_cfg)
policy_cfg_from_dataclass = policy_cfg_cls()
assert policy_cfg_from_hydra == policy_cfg_from_dataclass
@pytest.mark.parametrize("policy_name", available_policies)
def test_save_and_load_pretrained(policy_name: str):
policy_cls, _ = get_policy_and_config_classes(policy_name)
@@ -318,7 +339,10 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize(
"env_name, policy_name, extra_overrides, file_name_extra",
[
("xarm", "tdmpc", [], ""),
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
# to test with `policy.use_mpc=false`.
("xarm", "tdmpc", ["policy.use_mpc=false"], ""),
(
"pusht",
"diffusion",
@@ -342,7 +366,8 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
add the policies you want to update the test artifacts for.
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
should be updated.
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`.