Add test to make sure policy dataclass configs match yaml configs (#292)
This commit is contained in:
@@ -26,7 +26,11 @@ from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
|
||||
from lerobot.common.policies.factory import (
|
||||
_policy_cfg_from_hydra_cfg,
|
||||
get_policy_and_config_classes,
|
||||
make_policy,
|
||||
)
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
@@ -210,6 +214,23 @@ def test_policy_defaults(policy_name: str):
|
||||
policy_cls()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name",
|
||||
[
|
||||
("xarm", "tdmpc"),
|
||||
("pusht", "diffusion"),
|
||||
("aloha", "act"),
|
||||
],
|
||||
)
|
||||
def test_yaml_matches_dataclass(env_name: str, policy_name: str):
|
||||
"""Check that dataclass configs match their respective yaml configs."""
|
||||
hydra_cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"])
|
||||
_, policy_cfg_cls = get_policy_and_config_classes(policy_name)
|
||||
policy_cfg_from_hydra = _policy_cfg_from_hydra_cfg(policy_cfg_cls, hydra_cfg)
|
||||
policy_cfg_from_dataclass = policy_cfg_cls()
|
||||
assert policy_cfg_from_hydra == policy_cfg_from_dataclass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_save_and_load_pretrained(policy_name: str):
|
||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||
@@ -318,7 +339,10 @@ def test_normalize(insert_temporal_dim):
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, policy_name, extra_overrides, file_name_extra",
|
||||
[
|
||||
("xarm", "tdmpc", [], ""),
|
||||
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
|
||||
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
|
||||
# to test with `policy.use_mpc=false`.
|
||||
("xarm", "tdmpc", ["policy.use_mpc=false"], ""),
|
||||
(
|
||||
"pusht",
|
||||
"diffusion",
|
||||
@@ -342,7 +366,8 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
|
||||
include a report on what changed and how that affected the outputs.
|
||||
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
|
||||
add the policies you want to update the test artifacts for.
|
||||
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
|
||||
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
|
||||
should be updated.
|
||||
4. Check that this test now passes.
|
||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||
6. Remember to stage and commit the resulting changes to `tests/data`.
|
||||
|
||||
Reference in New Issue
Block a user