From 8f98672eccb5dcbfd91453248537f040999c0965 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 5 Mar 2025 23:12:55 +0100 Subject: [PATCH] feat(autopolicy): draft v0.2 --- .../policies/auto/configuration_auto.py | 209 +++++++++++++++--- 1 file changed, 183 insertions(+), 26 deletions(-) diff --git a/lerobot/common/policies/auto/configuration_auto.py b/lerobot/common/policies/auto/configuration_auto.py index aed109cc6..649d6e197 100644 --- a/lerobot/common/policies/auto/configuration_auto.py +++ b/lerobot/common/policies/auto/configuration_auto.py @@ -16,21 +16,17 @@ import importlib import os from collections import OrderedDict from pathlib import Path -from typing import Union +from typing import Optional, Type, Union +from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.configs.policies import PreTrainedConfig -POLICY_CONFIG_NAMES_MAPPING = OrderedDict( - [ - ("act", "ACTConfig"), - ] -) - def policy_type_to_module_name(policy_type: str) -> str: """Convert policy type to module name format.""" return policy_type.replace("-", "_") + class _LazyPolicyConfigMapping(OrderedDict): """ A dictionary that lazily load its values when they are requested. @@ -96,9 +92,100 @@ class _LazyPolicyConfigMapping(OrderedDict): self._extra_content[key] = value +POLICY_CONFIG_NAMES_MAPPING = OrderedDict( + [ + ("act", "ACTConfig"), + ] +) + POLICY_CONFIG_MAPPING = _LazyPolicyConfigMapping(POLICY_CONFIG_NAMES_MAPPING) +class _LazyPolicyMapping(OrderedDict): + """ + A dictionary that lazily loads its values when they are requested. + """ + + def __init__(self, mapping): + self._mapping = mapping + self._extra_content = {} + self._modules = {} + self._config_mapping = {} # Maps config classes to policy classes + + # Automatically set up mappings for built-in policies using POLICY_CONFIG_MAPPING + for policy_type in self._mapping: + try: + config_class = POLICY_CONFIG_MAPPING[policy_type] + self._config_mapping[config_class] = self[policy_type] + except (ImportError, AttributeError, KeyError) as e: + import logging + + logging.warning(f"Could not automatically map config for policy type {policy_type}: {str(e)}") + + def __getitem__(self, key): + if key in self._extra_content: + return self._extra_content[key] + if key not in self._mapping: + raise KeyError(key) + + value = self._mapping[key] + module_name = policy_type_to_module_name(key) + + try: + if key not in self._modules: + self._modules[key] = importlib.import_module( + f"lerobot.common.policies.{module_name}.modeling_{module_name}" + ) + return getattr(self._modules[key], value) + except (ImportError, AttributeError): + for import_path in [ + f"lerobot.policies.{module_name}", + f"lerobot.common.policies.{module_name}", + ]: + try: + self._modules[key] = importlib.import_module(import_path) + if hasattr(self._modules[key], value): + return getattr(self._modules[key], value) + except ImportError: + continue + + raise ImportError(f"Could not find policy class {value} for policy type {key}") + + def register( + self, + key: str, + value: Type[PreTrainedPolicy], + config_class: Type[PreTrainedConfig], + exist_ok: bool = False, + ): + """Register a new policy class with its configuration class.""" + if key in self._mapping and not exist_ok: + raise ValueError(f"'{key}' is already used by a Policy, pick another name.") + self._extra_content[key] = value + self._config_mapping[config_class] = value + + def get_policy_for_config(self, config_class: Type[PreTrainedConfig]) -> Type[PreTrainedPolicy]: + """Get the policy class associated with a config class.""" + if config_class in self._config_mapping: + return self._config_mapping[config_class] + + # Try to find by policy type + policy_type = config_class.type + if policy_type in self: + return self[policy_type] + + raise ValueError(f"No policy class found for config class {config_class.__name__}") + + +POLICY_NAMES_MAPPING = OrderedDict( + [ + ("act", "ACTPolicy"), + ] +) + +POLICY_MAPPING = _LazyPolicyMapping(POLICY_NAMES_MAPPING) + + class AutoPolicyConfig: """ Factory class for automatically loading policy configurations. @@ -134,12 +221,12 @@ class AutoPolicyConfig: policy_type (`str`): The policy type like "act" or "pi0". config ([`PreTrainedConfig`]): The config to register. """ - if issubclass(config, PreTrainedConfig) and config.policy_type != policy_type: - raise ValueError( - "The config you are passing has a `policy_type` attribute that is not consistent with the policy type " - f"you passed (config has {config.policy_type} and you passed {policy_type}. Fix one of those so they " - "match!" - ) + # if issubclass(config, PreTrainedConfig) and config.policy_type != policy_type: + # raise ValueError( + # "The config you are passing has a `policy_type` attribute that is not consistent with the policy type " + # f"you passed (config has {config.policy_type} and you passed {policy_type}. Fix one of those so they " + # "match!" + # ) POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok) @classmethod @@ -184,24 +271,94 @@ class AutoPolicyConfig: return config_class.from_pretrained(pretrained_policy_config_name_or_path, **kwargs) +class AutoPolicy: + """ + Factory class that allows instantiating policy models from configurations. + + This class provides methods to: + - Load pre-trained policies from configurations + - Register new policy types dynamically + - Create policy instances for specific configurations + """ + + def __init__(self): + raise OSError( + "AutoPolicy is designed to be instantiated using the " + "`AutoPolicy.from_config()` or `AutoPolicy.from_pretrained()` methods." + ) + + @classmethod + def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy: + """Instantiate a policy from a configuration.""" + policy_class = POLICY_MAPPING.get_policy_for_config(config.__class__) + return policy_class(config, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_policy_name_or_path: Union[str, Path], + *, + config: Optional[PreTrainedConfig] = None, + **kwargs, + ) -> PreTrainedPolicy: + """ + Instantiate a pre-trained policy from a configuration. + + Args: + pretrained_policy_name_or_path: Path to pretrained weights or model identifier + config: Optional configuration for the policy + **kwargs: Additional arguments to pass to from_pretrained() + """ + if config is None: + config = AutoPolicyConfig.from_pretrained(pretrained_policy_name_or_path) + + if isinstance(config, str): + config = AutoPolicyConfig.from_pretrained(config) + + policy_class = POLICY_MAPPING.get_policy_for_config(config.__class__) + return policy_class.from_pretrained(pretrained_policy_name_or_path, config=config, **kwargs) + + @staticmethod + def register( + config_class: Type[PreTrainedConfig], policy_class: Type[PreTrainedPolicy], exist_ok: bool = False + ): + """ + Register a new policy class for a configuration class. + + Args: + config_class: The configuration class + policy_class: The policy class to register + exist_ok: Whether to allow overwriting existing registrations + """ + POLICY_MAPPING.register(config_class.type, policy_class, config_class, exist_ok=exist_ok) + + def main(): - # Simulates a standard policy type being loaded + #TODO: Pass the needed arguments to the policies + + # Simulates a build-in policy type being loaded + # Built-in policies work without explicit registration my_config = AutoPolicyConfig.for_policy("act") - from lerobot.common.policies.act.configuration_act import ACTConfig - assert isinstance(my_config,ACTConfig) # my_policy = AutoPolicy.from_config(my_config) - # from lerobot.common.policies.act.modeling_act import ACTPolicy - # assert isinstance(my_policy,ACTPolicy) + + from lerobot.common.policies.act.configuration_act import ACTConfig + + assert isinstance(my_config, ACTConfig) + # assert isinstance(my_policy, ACTPolicy) # Simulates a new policy type being registered - from lerobot.common.policies.pi0.configuration_pi0 import PI0Config - AutoPolicyConfig.register("pi0", PI0Config) - my_new_config = AutoPolicyConfig.for_policy("pi0") - assert isinstance(my_new_config,PI0Config) - # from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy - # AutoPolicy.register(PI0Config,PI0Policy) + # Only new policies need registration + from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig + from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy + + AutoPolicyConfig.register("diffusion", DiffusionConfig) + AutoPolicy.register(DiffusionConfig, DiffusionPolicy) + + my_new_config = AutoPolicyConfig.for_policy("diffusion") # my_new_policy = AutoPolicy.from_config(my_new_config) - # assert isinstance(my_new_policy,PI0Policy) + assert isinstance(my_new_config, DiffusionConfig) + # assert isinstance(my_new_policy, DiffusionPolicy) + if __name__ == "__main__": - main() \ No newline at end of file + main()