Add test to make sure policy dataclass configs match yaml configs (#292)

This commit is contained in:
Alexander Soare
2024-06-26 09:09:40 +01:00
committed by GitHub
parent 7d1542cae1
commit 342f429f1c
5 changed files with 54 additions and 13 deletions

View File

@@ -108,16 +108,23 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
if __name__ == "__main__":
env_policies = [
# ("xarm", "tdmpc", []),
# ("xarm", "tdmpc", ["policy.use_mpc=false"], ""),
# (
# "pusht",
# "diffusion",
# ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
# [
# "policy.n_action_steps=8",
# "policy.num_inference_steps=10",
# "policy.down_dims=[128, 256, 512]",
# ],
# "",
# ),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
# ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
]
if len(env_policies) == 0:
raise RuntimeError("No policies were provided!")
for env, policy, extra_overrides, file_name_extra in env_policies:
save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra

View File

@@ -26,7 +26,11 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
from lerobot.common.policies.factory import (
_policy_cfg_from_hydra_cfg,
get_policy_and_config_classes,
make_policy,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
@@ -210,6 +214,23 @@ def test_policy_defaults(policy_name: str):
policy_cls()
@pytest.mark.parametrize(
"env_name,policy_name",
[
("xarm", "tdmpc"),
("pusht", "diffusion"),
("aloha", "act"),
],
)
def test_yaml_matches_dataclass(env_name: str, policy_name: str):
"""Check that dataclass configs match their respective yaml configs."""
hydra_cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"])
_, policy_cfg_cls = get_policy_and_config_classes(policy_name)
policy_cfg_from_hydra = _policy_cfg_from_hydra_cfg(policy_cfg_cls, hydra_cfg)
policy_cfg_from_dataclass = policy_cfg_cls()
assert policy_cfg_from_hydra == policy_cfg_from_dataclass
@pytest.mark.parametrize("policy_name", available_policies)
def test_save_and_load_pretrained(policy_name: str):
policy_cls, _ = get_policy_and_config_classes(policy_name)
@@ -318,7 +339,10 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize(
"env_name, policy_name, extra_overrides, file_name_extra",
[
("xarm", "tdmpc", [], ""),
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
# to test with `policy.use_mpc=false`.
("xarm", "tdmpc", ["policy.use_mpc=false"], ""),
(
"pusht",
"diffusion",
@@ -342,7 +366,8 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
add the policies you want to update the test artifacts for.
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
should be updated.
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`.