From 342f429f1c321a2b4501c3007b1dacba7244b469 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 26 Jun 2024 09:09:40 +0100 Subject: [PATCH] Add test to make sure policy dataclass configs match yaml configs (#292) --- lerobot/common/policies/factory.py | 19 +++++++++---- lerobot/configs/policy/diffusion.yaml | 2 +- lerobot/configs/policy/tdmpc.yaml | 2 +- tests/scripts/save_policy_to_safetensors.py | 13 +++++++-- tests/test_policies.py | 31 +++++++++++++++++++-- 5 files changed, 54 insertions(+), 13 deletions(-) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 124e8c68..5cb2fd52 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -28,9 +28,15 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): logging.warning( f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" ) + + # OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid + # issues with mutable defaults. This filter changes all lists to tuples. + def list_to_tuple(item): + return tuple(item) if isinstance(item, list) else item + policy_cfg = policy_cfg_class( **{ - k: v + k: list_to_tuple(v) for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items() if k in expected_kwargs } @@ -80,7 +86,9 @@ def make_policy( policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`. """ if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None): - raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.") + raise ValueError( + "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided." 
+ ) @@ -91,9 +99,10 @@ def make_policy( else: # Load a pretrained policy and override the config if needed (for example, if there are inference-time # hyperparameters that we want to vary). - # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained - # weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should - # make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274. + # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with + # pretrained weights which are then loaded into a fresh policy with the desired config. This PR in + # huggingface_hub should make it possible to avoid the hack: + # https://github.com/huggingface/huggingface_hub/pull/2274. policy = policy_cls(policy_cfg) policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()) diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index b04ecf1b..95cc75b6 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -99,7 +99,7 @@ policy: clip_sample_range: 1.0 # Inference - num_inference_steps: 100 + num_inference_steps: null # if not provided, defaults to `num_train_timesteps` # Loss computation do_mask_loss_for_padding: false diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 09326ab4..4e55ddf7 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -54,7 +54,7 @@ policy: discount: 0.9 # Inference. 
- use_mpc: false + use_mpc: true cem_iterations: 6 max_std: 2.0 min_std: 0.05 diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index 5fead55a..67308bb3 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -108,16 +108,23 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override if __name__ == "__main__": env_policies = [ - # ("xarm", "tdmpc", []), + # ("xarm", "tdmpc", ["policy.use_mpc=false"], ""), # ( # "pusht", # "diffusion", - # ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], + # [ + # "policy.n_action_steps=8", + # "policy.num_inference_steps=10", + # "policy.down_dims=[128, 256, 512]", + # ], + # "", # ), - ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"), + # ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"), # ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]), # ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]), ] + if len(env_policies) == 0: + raise RuntimeError("No policies were provided!") for env, policy, extra_overrides, file_name_extra in env_policies: save_policy_to_safetensors( "tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra diff --git a/tests/test_policies.py b/tests/test_policies.py index 490c25cc..bc9c34ff 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -26,7 +26,11 @@ from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import preprocess_observation -from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy +from lerobot.common.policies.factory import ( + _policy_cfg_from_hydra_cfg, + get_policy_and_config_classes, + make_policy, +) 
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config @@ -210,6 +214,23 @@ def test_policy_defaults(policy_name: str): policy_cls() +@pytest.mark.parametrize( + "env_name,policy_name", + [ + ("xarm", "tdmpc"), + ("pusht", "diffusion"), + ("aloha", "act"), + ], +) +def test_yaml_matches_dataclass(env_name: str, policy_name: str): + """Check that dataclass configs match their respective yaml configs.""" + hydra_cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"]) + _, policy_cfg_cls = get_policy_and_config_classes(policy_name) + policy_cfg_from_hydra = _policy_cfg_from_hydra_cfg(policy_cfg_cls, hydra_cfg) + policy_cfg_from_dataclass = policy_cfg_cls() + assert policy_cfg_from_hydra == policy_cfg_from_dataclass + + @pytest.mark.parametrize("policy_name", available_policies) def test_save_and_load_pretrained(policy_name: str): policy_cls, _ = get_policy_and_config_classes(policy_name) @@ -318,7 +339,10 @@ def test_normalize(insert_temporal_dim): @pytest.mark.parametrize( "env_name, policy_name, extra_overrides, file_name_extra", [ - ("xarm", "tdmpc", [], ""), + # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it + # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override + # to test with `policy.use_mpc=false`. + ("xarm", "tdmpc", ["policy.use_mpc=false"], ""), ( "pusht", "diffusion", @@ -342,7 +366,8 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam include a report on what changed and how that affected the outputs. 2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and add the policies you want to update the test artifacts for. - 3. Run `python tests/scripts/save_policy_to_safetensors.py`. 
The test artifact should be updated. + 3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact + should be updated. 4. Check that this test now passes. 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state. 6. Remember to stage and commit the resulting changes to `tests/data`.