Add test to make sure policy dataclass configs match yaml configs (#292)

This commit is contained in:
Alexander Soare
2024-06-26 09:09:40 +01:00
committed by GitHub
parent 7d1542cae1
commit 342f429f1c
5 changed files with 54 additions and 13 deletions

View File

@@ -28,9 +28,15 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
logging.warning(
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
)
# OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid
# issues with mutable defaults. This filter changes all lists to tuples.
def list_to_tuple(item):
return tuple(item) if isinstance(item, list) else item
policy_cfg = policy_cfg_class(
**{
k: v
k: list_to_tuple(v)
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
if k in expected_kwargs
}
@@ -80,7 +86,9 @@ def make_policy(
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
"""
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.")
raise ValueError(
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
)
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
@@ -91,9 +99,10 @@ def make_policy(
else:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained
# weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should
# make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274.
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
# pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
# huggingface_hub should make it possible to avoid the hack:
# https://github.com/huggingface/huggingface_hub/pull/2274.
policy = policy_cls(policy_cfg)
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())

View File

@@ -99,7 +99,7 @@ policy:
clip_sample_range: 1.0
# Inference
num_inference_steps: 100
num_inference_steps: null # if not provided, defaults to `num_train_timesteps`
# Loss computation
do_mask_loss_for_padding: false

View File

@@ -54,7 +54,7 @@ policy:
discount: 0.9
# Inference.
use_mpc: false
use_mpc: true
cem_iterations: 6
max_std: 2.0
min_std: 0.05