feat(autopolicy): draft v0.2
This commit is contained in:
@@ -16,21 +16,17 @@ import importlib
|
|||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Optional, Type, Union
|
||||||
|
|
||||||
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
|
||||||
POLICY_CONFIG_NAMES_MAPPING = OrderedDict(
|
|
||||||
[
|
|
||||||
("act", "ACTConfig"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def policy_type_to_module_name(policy_type: str) -> str:
|
def policy_type_to_module_name(policy_type: str) -> str:
|
||||||
"""Convert policy type to module name format."""
|
"""Convert policy type to module name format."""
|
||||||
return policy_type.replace("-", "_")
|
return policy_type.replace("-", "_")
|
||||||
|
|
||||||
|
|
||||||
class _LazyPolicyConfigMapping(OrderedDict):
|
class _LazyPolicyConfigMapping(OrderedDict):
|
||||||
"""
|
"""
|
||||||
A dictionary that lazily load its values when they are requested.
|
A dictionary that lazily load its values when they are requested.
|
||||||
@@ -96,9 +92,100 @@ class _LazyPolicyConfigMapping(OrderedDict):
|
|||||||
self._extra_content[key] = value
|
self._extra_content[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
POLICY_CONFIG_NAMES_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
("act", "ACTConfig"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
POLICY_CONFIG_MAPPING = _LazyPolicyConfigMapping(POLICY_CONFIG_NAMES_MAPPING)
|
POLICY_CONFIG_MAPPING = _LazyPolicyConfigMapping(POLICY_CONFIG_NAMES_MAPPING)
|
||||||
|
|
||||||
|
|
||||||
|
class _LazyPolicyMapping(OrderedDict):
|
||||||
|
"""
|
||||||
|
A dictionary that lazily loads its values when they are requested.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mapping):
|
||||||
|
self._mapping = mapping
|
||||||
|
self._extra_content = {}
|
||||||
|
self._modules = {}
|
||||||
|
self._config_mapping = {} # Maps config classes to policy classes
|
||||||
|
|
||||||
|
# Automatically set up mappings for built-in policies using POLICY_CONFIG_MAPPING
|
||||||
|
for policy_type in self._mapping:
|
||||||
|
try:
|
||||||
|
config_class = POLICY_CONFIG_MAPPING[policy_type]
|
||||||
|
self._config_mapping[config_class] = self[policy_type]
|
||||||
|
except (ImportError, AttributeError, KeyError) as e:
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.warning(f"Could not automatically map config for policy type {policy_type}: {str(e)}")
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
if key in self._extra_content:
|
||||||
|
return self._extra_content[key]
|
||||||
|
if key not in self._mapping:
|
||||||
|
raise KeyError(key)
|
||||||
|
|
||||||
|
value = self._mapping[key]
|
||||||
|
module_name = policy_type_to_module_name(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if key not in self._modules:
|
||||||
|
self._modules[key] = importlib.import_module(
|
||||||
|
f"lerobot.common.policies.{module_name}.modeling_{module_name}"
|
||||||
|
)
|
||||||
|
return getattr(self._modules[key], value)
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
for import_path in [
|
||||||
|
f"lerobot.policies.{module_name}",
|
||||||
|
f"lerobot.common.policies.{module_name}",
|
||||||
|
]:
|
||||||
|
try:
|
||||||
|
self._modules[key] = importlib.import_module(import_path)
|
||||||
|
if hasattr(self._modules[key], value):
|
||||||
|
return getattr(self._modules[key], value)
|
||||||
|
except ImportError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise ImportError(f"Could not find policy class {value} for policy type {key}")
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Type[PreTrainedPolicy],
|
||||||
|
config_class: Type[PreTrainedConfig],
|
||||||
|
exist_ok: bool = False,
|
||||||
|
):
|
||||||
|
"""Register a new policy class with its configuration class."""
|
||||||
|
if key in self._mapping and not exist_ok:
|
||||||
|
raise ValueError(f"'{key}' is already used by a Policy, pick another name.")
|
||||||
|
self._extra_content[key] = value
|
||||||
|
self._config_mapping[config_class] = value
|
||||||
|
|
||||||
|
def get_policy_for_config(self, config_class: Type[PreTrainedConfig]) -> Type[PreTrainedPolicy]:
|
||||||
|
"""Get the policy class associated with a config class."""
|
||||||
|
if config_class in self._config_mapping:
|
||||||
|
return self._config_mapping[config_class]
|
||||||
|
|
||||||
|
# Try to find by policy type
|
||||||
|
policy_type = config_class.type
|
||||||
|
if policy_type in self:
|
||||||
|
return self[policy_type]
|
||||||
|
|
||||||
|
raise ValueError(f"No policy class found for config class {config_class.__name__}")
|
||||||
|
|
||||||
|
|
||||||
|
POLICY_NAMES_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
("act", "ACTPolicy"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
POLICY_MAPPING = _LazyPolicyMapping(POLICY_NAMES_MAPPING)
|
||||||
|
|
||||||
|
|
||||||
class AutoPolicyConfig:
|
class AutoPolicyConfig:
|
||||||
"""
|
"""
|
||||||
Factory class for automatically loading policy configurations.
|
Factory class for automatically loading policy configurations.
|
||||||
@@ -134,12 +221,12 @@ class AutoPolicyConfig:
|
|||||||
policy_type (`str`): The policy type like "act" or "pi0".
|
policy_type (`str`): The policy type like "act" or "pi0".
|
||||||
config ([`PreTrainedConfig`]): The config to register.
|
config ([`PreTrainedConfig`]): The config to register.
|
||||||
"""
|
"""
|
||||||
if issubclass(config, PreTrainedConfig) and config.policy_type != policy_type:
|
# if issubclass(config, PreTrainedConfig) and config.policy_type != policy_type:
|
||||||
raise ValueError(
|
# raise ValueError(
|
||||||
"The config you are passing has a `policy_type` attribute that is not consistent with the policy type "
|
# "The config you are passing has a `policy_type` attribute that is not consistent with the policy type "
|
||||||
f"you passed (config has {config.policy_type} and you passed {policy_type}. Fix one of those so they "
|
# f"you passed (config has {config.policy_type} and you passed {policy_type}. Fix one of those so they "
|
||||||
"match!"
|
# "match!"
|
||||||
)
|
# )
|
||||||
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
|
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -184,24 +271,94 @@ class AutoPolicyConfig:
|
|||||||
return config_class.from_pretrained(pretrained_policy_config_name_or_path, **kwargs)
|
return config_class.from_pretrained(pretrained_policy_config_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoPolicy:
|
||||||
|
"""
|
||||||
|
Factory class that allows instantiating policy models from configurations.
|
||||||
|
|
||||||
|
This class provides methods to:
|
||||||
|
- Load pre-trained policies from configurations
|
||||||
|
- Register new policy types dynamically
|
||||||
|
- Create policy instances for specific configurations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise OSError(
|
||||||
|
"AutoPolicy is designed to be instantiated using the "
|
||||||
|
"`AutoPolicy.from_config()` or `AutoPolicy.from_pretrained()` methods."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy:
|
||||||
|
"""Instantiate a policy from a configuration."""
|
||||||
|
policy_class = POLICY_MAPPING.get_policy_for_config(config.__class__)
|
||||||
|
return policy_class(config, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
pretrained_policy_name_or_path: Union[str, Path],
|
||||||
|
*,
|
||||||
|
config: Optional[PreTrainedConfig] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> PreTrainedPolicy:
|
||||||
|
"""
|
||||||
|
Instantiate a pre-trained policy from a configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained_policy_name_or_path: Path to pretrained weights or model identifier
|
||||||
|
config: Optional configuration for the policy
|
||||||
|
**kwargs: Additional arguments to pass to from_pretrained()
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = AutoPolicyConfig.from_pretrained(pretrained_policy_name_or_path)
|
||||||
|
|
||||||
|
if isinstance(config, str):
|
||||||
|
config = AutoPolicyConfig.from_pretrained(config)
|
||||||
|
|
||||||
|
policy_class = POLICY_MAPPING.get_policy_for_config(config.__class__)
|
||||||
|
return policy_class.from_pretrained(pretrained_policy_name_or_path, config=config, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def register(
|
||||||
|
config_class: Type[PreTrainedConfig], policy_class: Type[PreTrainedPolicy], exist_ok: bool = False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Register a new policy class for a configuration class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_class: The configuration class
|
||||||
|
policy_class: The policy class to register
|
||||||
|
exist_ok: Whether to allow overwriting existing registrations
|
||||||
|
"""
|
||||||
|
POLICY_MAPPING.register(config_class.type, policy_class, config_class, exist_ok=exist_ok)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Simulates a standard policy type being loaded
|
#TODO: Pass the needed arguments to the policies
|
||||||
|
|
||||||
|
# Simulates a build-in policy type being loaded
|
||||||
|
# Built-in policies work without explicit registration
|
||||||
my_config = AutoPolicyConfig.for_policy("act")
|
my_config = AutoPolicyConfig.for_policy("act")
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
|
||||||
assert isinstance(my_config,ACTConfig)
|
|
||||||
# my_policy = AutoPolicy.from_config(my_config)
|
# my_policy = AutoPolicy.from_config(my_config)
|
||||||
# from lerobot.common.policies.act.modeling_act import ACTPolicy
|
|
||||||
# assert isinstance(my_policy,ACTPolicy)
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
|
|
||||||
|
assert isinstance(my_config, ACTConfig)
|
||||||
|
# assert isinstance(my_policy, ACTPolicy)
|
||||||
|
|
||||||
# Simulates a new policy type being registered
|
# Simulates a new policy type being registered
|
||||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
# Only new policies need registration
|
||||||
AutoPolicyConfig.register("pi0", PI0Config)
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
my_new_config = AutoPolicyConfig.for_policy("pi0")
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
assert isinstance(my_new_config,PI0Config)
|
|
||||||
# from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
AutoPolicyConfig.register("diffusion", DiffusionConfig)
|
||||||
# AutoPolicy.register(PI0Config,PI0Policy)
|
AutoPolicy.register(DiffusionConfig, DiffusionPolicy)
|
||||||
|
|
||||||
|
my_new_config = AutoPolicyConfig.for_policy("diffusion")
|
||||||
# my_new_policy = AutoPolicy.from_config(my_new_config)
|
# my_new_policy = AutoPolicy.from_config(my_new_config)
|
||||||
# assert isinstance(my_new_policy,PI0Policy)
|
assert isinstance(my_new_config, DiffusionConfig)
|
||||||
|
# assert isinstance(my_new_policy, DiffusionPolicy)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
Reference in New Issue
Block a user