fix(autopolicy): draft v0.3

This commit is contained in:
Steven Palma
2025-03-06 12:05:46 +01:00
parent 8f98672ecc
commit caadc887ad

View File

@@ -24,8 +24,8 @@ from lerobot.configs.policies import PreTrainedConfig
def policy_type_to_module_name(policy_type: str) -> str: def policy_type_to_module_name(policy_type: str) -> str:
"""Convert policy type to module name format.""" """Convert policy type to module name format."""
return policy_type.replace("-", "_") # TODO(Steven): Deal with this
return policy_type.replace("lerobot/", "").replace("-", "_").replace("_", "").replace("pusht", "")
class _LazyPolicyConfigMapping(OrderedDict): class _LazyPolicyConfigMapping(OrderedDict):
""" """
@@ -49,6 +49,7 @@ class _LazyPolicyConfigMapping(OrderedDict):
# Try standard import path first # Try standard import path first
try: try:
if key not in self._modules: if key not in self._modules:
print("Importing CONFIG: ",module_name)
self._modules[key] = importlib.import_module( self._modules[key] = importlib.import_module(
f"lerobot.common.policies.{module_name}.configuration_{module_name}" f"lerobot.common.policies.{module_name}.configuration_{module_name}"
) )
@@ -60,6 +61,7 @@ class _LazyPolicyConfigMapping(OrderedDict):
f"lerobot.common.policies.{module_name}", f"lerobot.common.policies.{module_name}",
]: ]:
try: try:
print("Importing CONFIG: ",module_name)
self._modules[key] = importlib.import_module(import_path) self._modules[key] = importlib.import_module(import_path)
if hasattr(self._modules[key], value): if hasattr(self._modules[key], value):
return getattr(self._modules[key], value) return getattr(self._modules[key], value)
@@ -94,7 +96,8 @@ class _LazyPolicyConfigMapping(OrderedDict):
POLICY_CONFIG_NAMES_MAPPING = OrderedDict( POLICY_CONFIG_NAMES_MAPPING = OrderedDict(
[ [
("act", "ACTConfig"), ("vqbet", "VQBeTConfig"),
("lerobot/vqbet_pusht", "VQBeTConfig"),
] ]
) )
@@ -111,16 +114,20 @@ class _LazyPolicyMapping(OrderedDict):
self._extra_content = {} self._extra_content = {}
self._modules = {} self._modules = {}
self._config_mapping = {} # Maps config classes to policy classes self._config_mapping = {} # Maps config classes to policy classes
self._initialized_types = set() # Track which types have been initialized
# Automatically set up mappings for built-in policies using POLICY_CONFIG_MAPPING def _lazy_init_for_type(self, policy_type: str) -> None:
for policy_type in self._mapping: """Lazily initialize mappings for a policy type if not already done."""
if policy_type not in self._initialized_types:
try: try:
config_class = POLICY_CONFIG_MAPPING[policy_type] config_class = POLICY_CONFIG_MAPPING[policy_type]
self._config_mapping[config_class] = self[policy_type] self._config_mapping[config_class] = self[policy_type]
self._initialized_types.add(policy_type)
except (ImportError, AttributeError, KeyError) as e: except (ImportError, AttributeError, KeyError) as e:
import logging import logging
logging.warning(
logging.warning(f"Could not automatically map config for policy type {policy_type}: {str(e)}") f"Could not automatically map config for policy type {policy_type}: {str(e)}"
)
def __getitem__(self, key): def __getitem__(self, key):
if key in self._extra_content: if key in self._extra_content:
@@ -133,6 +140,7 @@ class _LazyPolicyMapping(OrderedDict):
try: try:
if key not in self._modules: if key not in self._modules:
print("Importing POLICY: ", module_name)
self._modules[key] = importlib.import_module( self._modules[key] = importlib.import_module(
f"lerobot.common.policies.{module_name}.modeling_{module_name}" f"lerobot.common.policies.{module_name}.modeling_{module_name}"
) )
@@ -143,6 +151,7 @@ class _LazyPolicyMapping(OrderedDict):
f"lerobot.common.policies.{module_name}", f"lerobot.common.policies.{module_name}",
]: ]:
try: try:
print("Importing POLICY: ",module_name)
self._modules[key] = importlib.import_module(import_path) self._modules[key] = importlib.import_module(import_path)
if hasattr(self._modules[key], value): if hasattr(self._modules[key], value):
return getattr(self._modules[key], value) return getattr(self._modules[key], value)
@@ -164,22 +173,38 @@ class _LazyPolicyMapping(OrderedDict):
self._extra_content[key] = value self._extra_content[key] = value
self._config_mapping[config_class] = value self._config_mapping[config_class] = value
def get_policy_for_config(self, config_class: Type[PreTrainedConfig]) -> Type[PreTrainedPolicy]: def get_policy_for_config(self, config_class: PreTrainedConfig) -> Type[PreTrainedPolicy]:
"""Get the policy class associated with a config class.""" """Get the policy class associated with a config class."""
if config_class in self._config_mapping: # First check direct config class mapping
return self._config_mapping[config_class] if type(config_class) in self._config_mapping:
return self._config_mapping[type(config_class)]
# Try to find by policy type # Try to find by policy type
policy_type = config_class.type try:
if policy_type in self: policy_type = config_class.type
return self[policy_type] # Check extra content first
if policy_type in self._extra_content:
return self._extra_content[policy_type]
raise ValueError(f"No policy class found for config class {config_class.__name__}") # Then check standard mapping
if policy_type in self._mapping:
self._lazy_init_for_type(policy_type)
if config_class in self._config_mapping:
return self._config_mapping[config_class]
return self[policy_type]
except AttributeError:
pass
raise ValueError(
f"No policy class found for config class {config_class.__name__}. "
f"Available types: {list(self._mapping.keys()) + list(self._extra_content.keys())}"
)
POLICY_NAMES_MAPPING = OrderedDict( POLICY_NAMES_MAPPING = OrderedDict(
[ [
("act", "ACTPolicy"), ("vqbet", "VQBeTPolicy"),
("lerobot/vqbet_pusht", "VQBeTPolicy"),
] ]
) )
@@ -221,10 +246,11 @@ class AutoPolicyConfig:
policy_type (`str`): The policy type like "act" or "pi0". policy_type (`str`): The policy type like "act" or "pi0".
config ([`PreTrainedConfig`]): The config to register. config ([`PreTrainedConfig`]): The config to register.
""" """
# if issubclass(config, PreTrainedConfig) and config.policy_type != policy_type: # TODO(Steven): config.type doesn't work at this stage because it is not an instance, it the class definition
# if issubclass(config, PreTrainedConfig) and config.type != policy_type:
# raise ValueError( # raise ValueError(
# "The config you are passing has a `policy_type` attribute that is not consistent with the policy type " # "The config you are passing has a `policy_type` attribute that is not consistent with the policy type "
# f"you passed (config has {config.policy_type} and you passed {policy_type}. Fix one of those so they " # f"you passed (config has {config.type} and you passed {policy_type}. Fix one of those so they "
# "match!" # "match!"
# ) # )
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok) POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
@@ -290,7 +316,7 @@ class AutoPolicy:
@classmethod @classmethod
def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy: def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy:
"""Instantiate a policy from a configuration.""" """Instantiate a policy from a configuration."""
policy_class = POLICY_MAPPING.get_policy_for_config(config.__class__) policy_class = POLICY_MAPPING.get_policy_for_config(config)
return policy_class(config, **kwargs) return policy_class(config, **kwargs)
@classmethod @classmethod
@@ -315,7 +341,7 @@ class AutoPolicy:
if isinstance(config, str): if isinstance(config, str):
config = AutoPolicyConfig.from_pretrained(config) config = AutoPolicyConfig.from_pretrained(config)
policy_class = POLICY_MAPPING.get_policy_for_config(config.__class__) policy_class = POLICY_MAPPING.get_policy_for_config(config)
return policy_class.from_pretrained(pretrained_policy_name_or_path, config=config, **kwargs) return policy_class.from_pretrained(pretrained_policy_name_or_path, config=config, **kwargs)
@staticmethod @staticmethod
@@ -334,30 +360,32 @@ class AutoPolicy:
def main(): def main():
#TODO: Pass the needed arguments to the policies
# Simulates a build-in policy type being loaded # Simulates a build-in policy type being loaded
# Built-in policies work without explicit registration # Built-in policies work without explicit registration
my_config = AutoPolicyConfig.for_policy("act")
# my_policy = AutoPolicy.from_config(my_config)
from lerobot.common.policies.act.configuration_act import ACTConfig # config = AutoPolicyConfig.for_policy("vqbet")
config = AutoPolicyConfig.from_pretrained("lerobot/vqbet_pusht")
policy = AutoPolicy.from_config(config)
assert isinstance(my_config, ACTConfig) from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
# assert isinstance(my_policy, ACTPolicy) from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
assert isinstance(config, VQBeTConfig)
assert isinstance(policy, VQBeTPolicy)
# Simulates a new policy type being registered # Simulates a new policy type being registered
# Only new policies need registration # Only new policies need registration
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
AutoPolicyConfig.register("diffusion", DiffusionConfig) AutoPolicyConfig.register("tdmpc", TDMPCConfig)
AutoPolicy.register(DiffusionConfig, DiffusionPolicy) AutoPolicy.register(TDMPCConfig, TDMPCPolicy)
my_new_config = AutoPolicyConfig.for_policy("diffusion") my_new_config = AutoPolicyConfig.for_policy("tdmpc")
# my_new_policy = AutoPolicy.from_config(my_new_config) my_new_policy = AutoPolicy.from_config(my_new_config)
assert isinstance(my_new_config, DiffusionConfig) assert isinstance(my_new_config, TDMPCConfig)
# assert isinstance(my_new_policy, DiffusionPolicy) assert isinstance(my_new_policy, TDMPCPolicy)
if __name__ == "__main__": if __name__ == "__main__":