#!/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import logging from omegaconf import DictConfig, OmegaConf from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_safe_torch_device def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): expected_kwargs = set(inspect.signature(policy_cfg_class).parameters) if not set(hydra_cfg.policy).issuperset(expected_kwargs): logging.warning( f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" ) policy_cfg = policy_cfg_class( **{ k: v for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items() if k in expected_kwargs } ) return policy_cfg def get_policy_and_config_classes(name: str) -> tuple[Policy, object]: """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" if name == "tdmpc": from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy return TDMPCPolicy, TDMPCConfig elif name == "diffusion": from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy return DiffusionPolicy, DiffusionConfig elif name == "act": from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.act.modeling_act import ACTPolicy return ACTPolicy, ACTConfig else: raise NotImplementedError(f"Policy with name {name} is not implemented.") def make_policy( hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None ) -> Policy: """Make an instance of a policy class. Args: hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is provided, only `hydra_cfg.policy.name` is used while everything else is ignored. pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a directory containing weights saved using `Policy.save_pretrained`. Note that providing this argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`. dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must be provided when initializing a new policy, and must not be provided when loading a pretrained policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`. """ if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None): raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.") policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name) policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) if pretrained_policy_name_or_path is None: # Make a fresh policy. policy = policy_cls(policy_cfg, dataset_stats) else: # Load a pretrained policy and override the config if needed (for example, if there are inference-time # hyperparameters that we want to vary). # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained # weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should # make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274. policy = policy_cls(policy_cfg) policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()) policy.to(get_safe_torch_device(hydra_cfg.device)) return policy