diff --git a/lerobot/common/policies/auto/configuration_auto.py b/lerobot/common/policies/auto/configuration_auto.py index 90fcdbda6..aed109cc6 100644 --- a/lerobot/common/policies/auto/configuration_auto.py +++ b/lerobot/common/policies/auto/configuration_auto.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import OrderedDict import importlib -from lerobot.configs.policies import PreTrainedConfig -from typing import Type, Union, Dict, Optional +import os +from collections import OrderedDict from pathlib import Path +from typing import Union + +from lerobot.configs.policies import PreTrainedConfig POLICY_CONFIG_NAMES_MAPPING = OrderedDict( [ @@ -24,31 +26,10 @@ POLICY_CONFIG_NAMES_MAPPING = OrderedDict( ] ) -POLICY_NAMES_MAPPING = OrderedDict( - [ - ("act", "ACTPolicy"), - ] -) -def model_type_to_module_name(model_type: str) -> str: - """Convert model type to module name format.""" - return model_type.replace("-", "_") - -def find_policy_type_from_config(config_dict: dict) -> str: - """Find the policy type from a config dictionary.""" - if "type" in config_dict: - return config_dict["type"] - - # Fallback: try to infer from class name - if "policy_class" in config_dict: - for policy_type, class_name in POLICY_CONFIG_NAMES_MAPPING.items(): - if class_name in config_dict["policy_class"]: - return policy_type - - raise ValueError( - "Could not determine policy type from config. " - "Config must contain either 'type' or 'policy_class' field." - ) +def policy_type_to_module_name(policy_type: str) -> str: + """Convert policy type to module name format.""" + return policy_type.replace("-", "_") class _LazyPolicyConfigMapping(OrderedDict): """ @@ -65,10 +46,10 @@ class _LazyPolicyConfigMapping(OrderedDict): return self._extra_content[key] if key not in self._mapping: raise KeyError(key) - + value = self._mapping[key] - module_name = model_type_to_module_name(key) - + module_name = policy_type_to_module_name(key) + # Try standard import path first try: if key not in self._modules: @@ -88,19 +69,17 @@ class _LazyPolicyConfigMapping(OrderedDict): return getattr(self._modules[key], value) except ImportError: continue - - raise ImportError( - f"Could not find configuration class {value} for policy type {key}" - ) + + raise ImportError(f"Could not find configuration class {value} for policy type {key}") def keys(self): return list(self._mapping.keys()) + list(self._extra_content.keys()) def values(self): - return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()) + return [self[k] for k in self._mapping] + list(self._extra_content.values()) def items(self): - return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()) + return [(k, self[k]) for k in self._mapping] + list(self._extra_content.items()) def __iter__(self): return iter(list(self._mapping.keys()) + list(self._extra_content.keys())) @@ -112,35 +91,32 @@ class _LazyPolicyConfigMapping(OrderedDict): """ Register a new configuration in this mapping. """ - if key in self._mapping.keys() and not exist_ok: + if key in self._mapping and not exist_ok: raise ValueError(f"'{key}' is already used by a Policy Config, pick another name.") self._extra_content[key] = value + POLICY_CONFIG_MAPPING = _LazyPolicyConfigMapping(POLICY_CONFIG_NAMES_MAPPING) + class AutoPolicyConfig: """ Factory class for automatically loading policy configurations. - + This class provides methods to: - Load pre-trained policy configurations from local files or the Hub - Register new policy types dynamically - Create policy configurations for specific policy types """ - + def __init__(self): - raise EnvironmentError( + raise OSError( "AutoPolicyConfig is designed to be instantiated " "using the `AutoPolicyConfig.from_pretrained(TODO)` method." ) - + @classmethod - def for_policy( - cls, - policy_type: str, - *args, - **kwargs - ) -> PreTrainedConfig: + def for_policy(cls, policy_type: str, *args, **kwargs) -> PreTrainedConfig: """Create a new configuration instance for the specified policy type.""" if policy_type in POLICY_CONFIG_MAPPING: config_class = POLICY_CONFIG_MAPPING[policy_type] @@ -148,7 +124,7 @@ class AutoPolicyConfig: raise ValueError( f"Unrecognized policy identifier: {policy_type}. Should contain one of {', '.join(POLICY_CONFIG_MAPPING.keys())}" ) - + @staticmethod def register(policy_type, config, exist_ok=False): """ @@ -166,27 +142,9 @@ class AutoPolicyConfig: ) POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok) - @staticmethod - def register_policy_type(policy_type: str, config_class: str, policy_class: str) -> None: - """ - Register a new policy type with its configuration and policy classes. - - Args: - policy_type (`str`): The policy type identifier (e.g., 'act') - config_class (`str`): Name of the configuration class - policy_class (`str`): Name of the policy class - """ - if policy_type in POLICY_CONFIG_NAMES_MAPPING: - raise ValueError(f"Policy type {policy_type} is already registered") - - POLICY_CONFIG_NAMES_MAPPING[policy_type] = config_class - POLICY_NAMES_MAPPING[policy_type] = policy_class - @classmethod def from_pretrained( - cls, - pretrained_policy_config_name_or_path: Union[str, Path], - **kwargs + cls, pretrained_policy_config_name_or_path: Union[str, Path], **kwargs ) -> PreTrainedConfig: """ Instantiate a PreTrainedConfig from a pre-trained policy configuration. @@ -194,9 +152,9 @@ class AutoPolicyConfig: Args: pretrained_policy_config_name_or_path (`str` or `Path`): Can be either: - - A string with the `policy_type` of a pre-trained policy configuration listed on + - A string with the `policy_type` of a pre-trained policy configuration listed on the Hub or locally (e.g., 'act') - - A path to a `directory` containing a configuration file saved + - A path to a `directory` containing a configuration file saved using [`~PreTrainedConfig.save_pretrained`]. - A path or url to a saved configuration JSON `file`. **kwargs: Additional kwargs passed to PreTrainedConfig.from_pretrained() @@ -215,13 +173,35 @@ class AutoPolicyConfig: else: # Assume it's a policy_type identifier policy_type = pretrained_policy_config_name_or_path - + if policy_type not in POLICY_CONFIG_MAPPING: raise ValueError( f"Unrecognized policy type {policy_type}. " f"Should be one of {', '.join(POLICY_CONFIG_MAPPING.keys())}" ) - + config_class = POLICY_CONFIG_MAPPING[policy_type] return config_class.from_pretrained(pretrained_policy_config_name_or_path, **kwargs) + +def main(): + # Simulates a standard policy type being loaded + my_config = AutoPolicyConfig.for_policy("act") + from lerobot.common.policies.act.configuration_act import ACTConfig + assert isinstance(my_config,ACTConfig) + # my_policy = AutoPolicy.from_config(my_config) + # from lerobot.common.policies.act.modeling_act import ACTPolicy + # assert isinstance(my_policy,ACTPolicy) + + # Simulates a new policy type being registered + from lerobot.common.policies.pi0.configuration_pi0 import PI0Config + AutoPolicyConfig.register("pi0", PI0Config) + my_new_config = AutoPolicyConfig.for_policy("pi0") + assert isinstance(my_new_config,PI0Config) + # from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy + # AutoPolicy.register(PI0Config,PI0Policy) + # my_new_policy = AutoPolicy.from_config(my_new_config) + # assert isinstance(my_new_policy,PI0Policy) + +if __name__ == "__main__": + main() \ No newline at end of file