From d4f9807ed02b79a273d14145c756f4c11824c526 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 6 Mar 2025 14:23:00 +0100 Subject: [PATCH] feat(autopolicy): draft v0.4 --- .../policies/auto/configuration_auto.py | 205 ++++++++++-------- lerobot/configs/policies.py | 9 + 2 files changed, 120 insertions(+), 94 deletions(-) diff --git a/lerobot/common/policies/auto/configuration_auto.py b/lerobot/common/policies/auto/configuration_auto.py index 8bb52284..b4671d56 100644 --- a/lerobot/common/policies/auto/configuration_auto.py +++ b/lerobot/common/policies/auto/configuration_auto.py @@ -13,60 +13,65 @@ # limitations under the License. import importlib +import logging import os from collections import OrderedDict from pathlib import Path -from typing import Optional, Type, Union +from typing import Any, Dict, Optional, Type, Union from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.configs.policies import PreTrainedConfig +logger = logging.getLogger(__name__) + +# Constants +IMPORT_PATHS = ["lerobot.common.policies.{0}.configuration_{0}"] + +POLICY_IMPORT_PATHS = ["lerobot.common.policies.{0}.modeling_{0}"] + def policy_type_to_module_name(policy_type: str) -> str: - """Convert policy type to module name format.""" - # TODO(Steven): Deal with this + """ + Convert policy type to module name format. + + Args: + policy_type: The policy type identifier (e.g. 'lerobot/vqbet-pusht') + + Returns: + str: Normalized module name (e.g. 'vqbet') + + Examples: + >>> policy_type_to_module_name("lerobot/vqbet-pusht") + 'vqbet' + """ + # TODO(Steven): This is a temporary solution, we should have a more robust way to handle this return policy_type.replace("lerobot/", "").replace("-", "_").replace("_", "").replace("pusht", "") + class _LazyPolicyConfigMapping(OrderedDict): - """ - A dictionary that lazily load its values when they are requested. - """ - - def __init__(self, mapping): + def __init__(self, mapping: Dict[str, str]): self._mapping = mapping - self._extra_content = {} - self._modules = {} + self._extra_content: Dict[str, Any] = {} + self._modules: Dict[str, Any] = {} - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: if key in self._extra_content: return self._extra_content[key] if key not in self._mapping: - raise KeyError(key) + raise KeyError(f"Policy type '{key}' not found in mapping") value = self._mapping[key] module_name = policy_type_to_module_name(key) - # Try standard import path first - try: - if key not in self._modules: - print("Importing CONFIG: ",module_name) - self._modules[key] = importlib.import_module( - f"lerobot.common.policies.{module_name}.configuration_{module_name}" - ) - return getattr(self._modules[key], value) - except (ImportError, AttributeError): - # Try fallback paths - for import_path in [ - f"lerobot.policies.{module_name}", - f"lerobot.common.policies.{module_name}", - ]: - try: - print("Importing CONFIG: ",module_name) - self._modules[key] = importlib.import_module(import_path) - if hasattr(self._modules[key], value): - return getattr(self._modules[key], value) - except ImportError: - continue + for import_path in IMPORT_PATHS: + try: + if key not in self._modules: + self._modules[key] = importlib.import_module(import_path.format(module_name)) + logger.debug(f"Config module: {module_name} imported") + if hasattr(self._modules[key], value): + return getattr(self._modules[key], value) + except ImportError: + continue raise ImportError(f"Could not find configuration class {value} for policy type {key}") @@ -109,12 +114,12 @@ class _LazyPolicyMapping(OrderedDict): A dictionary that lazily loads its values when they are requested. """ - def __init__(self, mapping): + def __init__(self, mapping: Dict[str, str]): self._mapping = mapping - self._extra_content = {} - self._modules = {} - self._config_mapping = {} # Maps config classes to policy classes - self._initialized_types = set() # Track which types have been initialized + self._extra_content: Dict[str, Type[PreTrainedPolicy]] = {} + self._modules: Dict[str, Any] = {} + self._config_mapping: Dict[Type[PreTrainedConfig], Type[PreTrainedPolicy]] = {} + self._initialized_types: set[str] = set() def _lazy_init_for_type(self, policy_type: str) -> None: """Lazily initialize mappings for a policy type if not already done.""" @@ -124,41 +129,34 @@ class _LazyPolicyMapping(OrderedDict): self._config_mapping[config_class] = self[policy_type] self._initialized_types.add(policy_type) except (ImportError, AttributeError, KeyError) as e: - import logging - logging.warning( - f"Could not automatically map config for policy type {policy_type}: {str(e)}" - ) + logger.warning(f"Could not automatically map config for policy type {policy_type}: {str(e)}") - def __getitem__(self, key): + def __getitem__(self, key: str) -> Type[PreTrainedPolicy]: + """Get a policy class by key with lazy loading.""" if key in self._extra_content: return self._extra_content[key] if key not in self._mapping: - raise KeyError(key) + raise KeyError(f"Policy type '{key}' not found in mapping") value = self._mapping[key] module_name = policy_type_to_module_name(key) - try: - if key not in self._modules: - print("Importing POLICY: ", module_name) - self._modules[key] = importlib.import_module( - f"lerobot.common.policies.{module_name}.modeling_{module_name}" - ) - return getattr(self._modules[key], value) - except (ImportError, AttributeError): - for import_path in [ - f"lerobot.policies.{module_name}", - f"lerobot.common.policies.{module_name}", - ]: - try: - print("Importing POLICY: ",module_name) - self._modules[key] = importlib.import_module(import_path) - if hasattr(self._modules[key], value): - return getattr(self._modules[key], value) - except ImportError: - continue + for import_path in POLICY_IMPORT_PATHS: + try: + if key not in self._modules: + self._modules[key] = importlib.import_module(import_path.format(module_name)) + logger.debug( + f"Policy module: {module_name} imported from {import_path.format(module_name)}" + ) + if hasattr(self._modules[key], value): + return getattr(self._modules[key], value) + except ImportError: + continue - raise ImportError(f"Could not find policy class {value} for policy type {key}") + raise ImportError( + f"Could not find policy class {value} for policy type {key}. " + f"Tried paths: {[p.format(module_name) for p in POLICY_IMPORT_PATHS]}" + ) def register( self, @@ -166,26 +164,33 @@ class _LazyPolicyMapping(OrderedDict): value: Type[PreTrainedPolicy], config_class: Type[PreTrainedConfig], exist_ok: bool = False, - ): + ) -> None: """Register a new policy class with its configuration class.""" + if not isinstance(key, str): + raise TypeError(f"Key must be a string, got {type(key)}") + if not issubclass(value, PreTrainedPolicy): + raise TypeError(f"Value must be a PreTrainedPolicy subclass, got {type(value)}") + if not issubclass(config_class, PreTrainedConfig): + raise TypeError(f"Config class must be a PreTrainedConfig subclass, got {type(config_class)}") + if key in self._mapping and not exist_ok: raise ValueError(f"'{key}' is already used by a Policy, pick another name.") self._extra_content[key] = value self._config_mapping[config_class] = value - def get_policy_for_config(self, config_class: PreTrainedConfig) -> Type[PreTrainedPolicy]: + def get_policy_for_config(self, config_class: Type[PreTrainedConfig]) -> Type[PreTrainedPolicy]: """Get the policy class associated with a config class.""" # First check direct config class mapping - if type(config_class) in self._config_mapping: - return self._config_mapping[type(config_class)] + if config_class in self._config_mapping: + return self._config_mapping[config_class] # Try to find by policy type try: - policy_type = config_class.type + policy_type = config_class.get_type_str() # Check extra content first if policy_type in self._extra_content: return self._extra_content[policy_type] - + # Then check standard mapping if policy_type in self._mapping: self._lazy_init_for_type(policy_type) @@ -222,10 +227,7 @@ class AutoPolicyConfig: """ def __init__(self): - raise OSError( - "AutoPolicyConfig is designed to be instantiated " - "using the `AutoPolicyConfig.from_pretrained(TODO)` method." - ) + raise OSError("AutoPolicyConfig not meant to be instantiated directly") @classmethod def for_policy(cls, policy_type: str, *args, **kwargs) -> PreTrainedConfig: @@ -246,13 +248,12 @@ class AutoPolicyConfig: policy_type (`str`): The policy type like "act" or "pi0". config ([`PreTrainedConfig`]): The config to register. """ - # TODO(Steven): config.type doesn't work at this stage because it is not an instance, it the class definition - # if issubclass(config, PreTrainedConfig) and config.type != policy_type: - # raise ValueError( - # "The config you are passing has a `policy_type` attribute that is not consistent with the policy type " - # f"you passed (config has {config.type} and you passed {policy_type}. Fix one of those so they " - # "match!" - # ) + if issubclass(config, PreTrainedConfig) and config.get_type_str() != policy_type: + raise ValueError( + "The config you are passing has a `policy_type` attribute that is not consistent with the policy type " + f"you passed (config has {config.type} and you passed {policy_type}. Fix one of those so they " + "match!" + ) POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok) @classmethod @@ -308,15 +309,12 @@ class AutoPolicy: """ def __init__(self): - raise OSError( - "AutoPolicy is designed to be instantiated using the " - "`AutoPolicy.from_config()` or `AutoPolicy.from_pretrained()` methods." - ) + raise OSError("AutoPolicy not meant to be instantiated directly") @classmethod def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy: """Instantiate a policy from a configuration.""" - policy_class = POLICY_MAPPING.get_policy_for_config(config) + policy_class = POLICY_MAPPING.get_policy_for_config(type(config)) return policy_class(config, **kwargs) @classmethod @@ -356,16 +354,33 @@ class AutoPolicy: policy_class: The policy class to register exist_ok: Whether to allow overwriting existing registrations """ - POLICY_MAPPING.register(config_class.type, policy_class, config_class, exist_ok=exist_ok) + POLICY_MAPPING.register(config_class.get_type_str(), policy_class, config_class, exist_ok=exist_ok) def main(): - - # Simulates a build-in policy type being loaded - # Built-in policies work without explicit registration + """Test the AutoPolicy and AutoPolicyConfig functionality.""" - # config = AutoPolicyConfig.for_policy("vqbet") - config = AutoPolicyConfig.from_pretrained("lerobot/vqbet_pusht") + def test_error_cases(): + """Test error handling""" + try: + AutoPolicyConfig() + except OSError as e: + assert "not meant to be instantiated directly" in str(e) + try: + AutoPolicy() + except OSError as e: + assert "not meant to be instantiated directly" in str(e) + + # try: + # AutoPolicy.from_config("invalid_config") + # except ValueError as e: + # assert "Unrecognized policy identifier" in str(e) + + logging.basicConfig(level=logging.DEBUG) + + # Test built-in policy loading + # config = AutoPolicyConfig.from_pretrained("lerobot/vqbet_pusht") + config = AutoPolicyConfig.for_policy("vqbet") policy = AutoPolicy.from_config(config) from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig @@ -374,8 +389,7 @@ def main(): assert isinstance(config, VQBeTConfig) assert isinstance(policy, VQBeTPolicy) - # Simulates a new policy type being registered - # Only new policies need registration + # Test policy registration from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy @@ -387,6 +401,9 @@ def main(): assert isinstance(my_new_config, TDMPCConfig) assert isinstance(my_new_policy, TDMPCPolicy) + # Run error case tests + test_error_cases() + if __name__ == "__main__": main() diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py index 9b5a7c5c..266b159b 100644 --- a/lerobot/configs/policies.py +++ b/lerobot/configs/policies.py @@ -47,6 +47,15 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): def type(self) -> str: return self.get_choice_name(self.__class__) + # TODO(Steven): Find a better way to do deal with this + @classmethod + def get_type_str(cls) -> str: + """Get the policy type identifier for this configuration class.""" + class_name = cls.__name__.lower() + if class_name.endswith("config"): + return class_name[:-6] # Remove 'config' suffix + return class_name + @abc.abstractproperty def observation_delta_indices(self) -> list | None: raise NotImplementedError