feat(autopolicy): draft v0.2

This commit is contained in:
Steven Palma
2025-03-05 23:12:55 +01:00
parent 78df84f758
commit 8f98672ecc

View File

@@ -16,21 +16,17 @@ import importlib
import os
from collections import OrderedDict
from pathlib import Path
from typing import Union
from typing import Optional, Type, Union
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.configs.policies import PreTrainedConfig
# Built-in policy types, each paired with the name of its configuration class.
POLICY_CONFIG_NAMES_MAPPING = OrderedDict([("act", "ACTConfig")])
def policy_type_to_module_name(policy_type: str) -> str:
    """Return ``policy_type`` with every hyphen replaced by an underscore.

    Args:
        policy_type: Policy type identifier, possibly containing hyphens.

    Returns:
        The identifier in module-name form.
    """
    hyphen_separated_parts = policy_type.split("-")
    return "_".join(hyphen_separated_parts)
class _LazyPolicyConfigMapping(OrderedDict):
"""
A dictionary that lazily load its values when they are requested.
@@ -96,9 +92,100 @@ class _LazyPolicyConfigMapping(OrderedDict):
self._extra_content[key] = value
# Built-in policy types, each paired with the name of its configuration class.
POLICY_CONFIG_NAMES_MAPPING = OrderedDict([("act", "ACTConfig")])

# Lazily-loading mapping built on top of the name table above.
POLICY_CONFIG_MAPPING = _LazyPolicyConfigMapping(POLICY_CONFIG_NAMES_MAPPING)
class _LazyPolicyMapping(OrderedDict):
    """
    A dictionary that lazily loads its values when they are requested.

    Keys are policy type strings; values are policy classes. Built-in entries come from
    ``mapping`` (policy type -> policy class *name*) and are resolved by importing the policy
    module on access. Entries added through `register()` are stored as ready-made classes.

    Note: the underlying ``OrderedDict`` storage is never populated, so only `__getitem__`,
    `__contains__`, `register()` and `get_policy_for_config()` reflect the real content.
    """

    def __init__(self, mapping):
        """
        Args:
            mapping: Maps each built-in policy type to the name of its policy class.
        """
        import logging

        super().__init__()
        self._mapping = mapping
        self._extra_content = {}  # Policy type -> policy class, filled by `register()`
        self._modules = {}  # Policy type -> module that was found to expose the policy class
        self._config_mapping = {}  # Maps config classes to policy classes

        # Automatically set up mappings for built-in policies using POLICY_CONFIG_MAPPING.
        # NOTE: `self[policy_type]` imports the policy module, so built-in policies are
        # resolved when the mapping is constructed, not on first use.
        for policy_type in self._mapping:
            try:
                config_class = POLICY_CONFIG_MAPPING[policy_type]
                self._config_mapping[config_class] = self[policy_type]
            except (ImportError, AttributeError, KeyError) as e:
                # Best effort: a policy that cannot be resolved must not break construction.
                logging.warning(
                    f"Could not automatically map config for policy type {policy_type}: {str(e)}"
                )

    def __contains__(self, key):
        """Return True if ``key`` is a built-in or a registered policy type."""
        # The inherited check looks at the dict storage, which is always empty here.
        return key in self._mapping or key in self._extra_content

    def __getitem__(self, key):
        """
        Return the policy class for policy type ``key``.

        Raises:
            KeyError: If ``key`` is neither registered nor a built-in policy type.
            ImportError: If no candidate module exposes the expected policy class.
        """
        if key in self._extra_content:
            return self._extra_content[key]
        if key not in self._mapping:
            raise KeyError(key)

        class_name = self._mapping[key]
        module_name = policy_type_to_module_name(key)

        cached_module = self._modules.get(key)
        if cached_module is not None and hasattr(cached_module, class_name):
            return getattr(cached_module, class_name)

        # Preferred location first, then the fallbacks, in the original lookup order.
        candidate_paths = (
            f"lerobot.common.policies.{module_name}.modeling_{module_name}",
            f"lerobot.policies.{module_name}",
            f"lerobot.common.policies.{module_name}",
        )
        for import_path in candidate_paths:
            try:
                module = importlib.import_module(import_path)
            except ImportError:
                continue
            if hasattr(module, class_name):
                # Cache only a module that really provides the class, so a failed
                # candidate can never be served from the cache on a later call.
                self._modules[key] = module
                return getattr(module, class_name)

        raise ImportError(f"Could not find policy class {class_name} for policy type {key}")

    def register(
        self,
        key: str,
        value: "Type[PreTrainedPolicy]",
        config_class: "Type[PreTrainedConfig]",
        exist_ok: bool = False,
    ):
        """
        Register a new policy class with its configuration class.

        Args:
            key: Policy type under which the policy class is stored.
            value: The policy class.
            config_class: The configuration class that should resolve to ``value``.
            exist_ok: If False, refuse a ``key`` that collides with a built-in policy type.

        Raises:
            ValueError: If ``key`` is a built-in policy type and ``exist_ok`` is False.
        """
        # Only built-in names are protected; re-registering a custom key overwrites it.
        if key in self._mapping and not exist_ok:
            raise ValueError(f"'{key}' is already used by a Policy, pick another name.")
        self._extra_content[key] = value
        self._config_mapping[config_class] = value

    def get_policy_for_config(self, config_class: "Type[PreTrainedConfig]") -> "Type[PreTrainedPolicy]":
        """
        Get the policy class associated with a config class.

        The explicit config-class mapping is consulted first; otherwise the config class'
        ``type`` attribute is used as a policy type key.

        Raises:
            ValueError: If no policy class can be found for ``config_class``.
        """
        if config_class in self._config_mapping:
            return self._config_mapping[config_class]

        # Try to find by policy type. Anything other than a string (including a missing
        # attribute) cannot be a valid key, so it falls through to the ValueError below.
        policy_type = getattr(config_class, "type", None)
        if isinstance(policy_type, str) and policy_type in self:
            return self[policy_type]

        raise ValueError(f"No policy class found for config class {config_class.__name__}")
# Built-in policy types, each paired with the name of its policy class.
POLICY_NAMES_MAPPING = OrderedDict([("act", "ACTPolicy")])

# Lazily-loading mapping built on top of the name table above.
POLICY_MAPPING = _LazyPolicyMapping(POLICY_NAMES_MAPPING)
class AutoPolicyConfig:
"""
Factory class for automatically loading policy configurations.
@@ -134,12 +221,12 @@ class AutoPolicyConfig:
policy_type (`str`): The policy type like "act" or "pi0".
config ([`PreTrainedConfig`]): The config to register.
"""
if issubclass(config, PreTrainedConfig) and config.policy_type != policy_type:
raise ValueError(
"The config you are passing has a `policy_type` attribute that is not consistent with the policy type "
f"you passed (config has {config.policy_type} and you passed {policy_type}. Fix one of those so they "
"match!"
)
# if issubclass(config, PreTrainedConfig) and config.policy_type != policy_type:
# raise ValueError(
# "The config you are passing has a `policy_type` attribute that is not consistent with the policy type "
# f"you passed (config has {config.policy_type} and you passed {policy_type}. Fix one of those so they "
# "match!"
# )
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
@classmethod
@@ -184,24 +271,94 @@ class AutoPolicyConfig:
return config_class.from_pretrained(pretrained_policy_config_name_or_path, **kwargs)
class AutoPolicy:
    """
    Factory class that allows instantiating policy models from configurations.

    This class provides methods to:
    - Load pre-trained policies from configurations
    - Register new policy types dynamically
    - Create policy instances for specific configurations

    It is never instantiated itself; use the class/static methods below.
    """

    def __init__(self):
        raise OSError(
            "AutoPolicy is designed to be instantiated using the "
            "`AutoPolicy.from_config()` or `AutoPolicy.from_pretrained()` methods."
        )

    @classmethod
    def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy:
        """
        Instantiate a policy from a configuration.

        Args:
            config: Configuration whose class selects the policy class.
            **kwargs: Additional arguments passed to the policy constructor.

        Returns:
            A new instance of the policy class mapped to ``config``'s class.
        """
        policy_class = POLICY_MAPPING.get_policy_for_config(config.__class__)
        return policy_class(config, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_policy_name_or_path: Union[str, Path],
        *,
        config: Optional[Union[PreTrainedConfig, str, Path]] = None,
        **kwargs,
    ) -> PreTrainedPolicy:
        """
        Instantiate a pre-trained policy from a configuration.

        Args:
            pretrained_policy_name_or_path: Path to pretrained weights or model identifier
            config: Optional configuration for the policy. When omitted, it is loaded from
                ``pretrained_policy_name_or_path``; when given as a string or `Path`, it is
                loaded from that location instead.
            **kwargs: Additional arguments to pass to from_pretrained()

        Returns:
            The policy returned by the resolved policy class' ``from_pretrained``.
        """
        if config is None:
            config = AutoPolicyConfig.from_pretrained(pretrained_policy_name_or_path)
        elif isinstance(config, (str, Path)):
            # A location was given instead of a config object: load the config from it.
            config = AutoPolicyConfig.from_pretrained(config)

        policy_class = POLICY_MAPPING.get_policy_for_config(config.__class__)
        return policy_class.from_pretrained(pretrained_policy_name_or_path, config=config, **kwargs)

    @staticmethod
    def register(
        config_class: Type[PreTrainedConfig], policy_class: Type[PreTrainedPolicy], exist_ok: bool = False
    ):
        """
        Register a new policy class for a configuration class.

        Args:
            config_class: The configuration class
            policy_class: The policy class to register
            exist_ok: Whether to allow overwriting existing registrations
        """
        # NOTE(review): `type` is read from the class, not from an instance. Confirm that it
        # resolves to the policy-type string at class level; otherwise the key stored in
        # POLICY_MAPPING will not be the policy type.
        POLICY_MAPPING.register(config_class.type, policy_class, config_class, exist_ok=exist_ok)
def main():
    """Smoke-test the auto classes with one built-in and two explicitly registered policy types.

    Only the configuration side is asserted for now; the policy-instantiation checks are
    kept as commented sketches until the TODO below is resolved.
    """
    # TODO: Pass the needed arguments to the policies

    # Simulates a built-in policy type being loaded
    # Built-in policies work without explicit registration
    my_config = AutoPolicyConfig.for_policy("act")
    from lerobot.common.policies.act.configuration_act import ACTConfig

    assert isinstance(my_config, ACTConfig)
    # my_policy = AutoPolicy.from_config(my_config)
    # from lerobot.common.policies.act.modeling_act import ACTPolicy
    # assert isinstance(my_policy, ACTPolicy)

    # Simulates a new policy type being registered
    from lerobot.common.policies.pi0.configuration_pi0 import PI0Config

    AutoPolicyConfig.register("pi0", PI0Config)
    my_new_config = AutoPolicyConfig.for_policy("pi0")
    assert isinstance(my_new_config, PI0Config)
    # from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
    # AutoPolicy.register(PI0Config, PI0Policy)

    # Only new policies need registration
    from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
    from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

    AutoPolicyConfig.register("diffusion", DiffusionConfig)
    AutoPolicy.register(DiffusionConfig, DiffusionPolicy)
    my_new_config = AutoPolicyConfig.for_policy("diffusion")
    assert isinstance(my_new_config, DiffusionConfig)
    # my_new_policy = AutoPolicy.from_config(my_new_config)
    # assert isinstance(my_new_policy, DiffusionPolicy)
# Run the smoke test once, and only when the file is executed as a script.
if __name__ == "__main__":
    main()