From caadc887ad2aa61f6429790a46f87f7a223b4ebb Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 6 Mar 2025 12:05:46 +0100 Subject: [PATCH] fix(autopolicy): draft v0.3 --- .../policies/auto/configuration_auto.py | 94 ++++++++++++------- 1 file changed, 61 insertions(+), 33 deletions(-) diff --git a/lerobot/common/policies/auto/configuration_auto.py b/lerobot/common/policies/auto/configuration_auto.py index 649d6e19..8bb52284 100644 --- a/lerobot/common/policies/auto/configuration_auto.py +++ b/lerobot/common/policies/auto/configuration_auto.py @@ -24,8 +24,8 @@ from lerobot.configs.policies import PreTrainedConfig def policy_type_to_module_name(policy_type: str) -> str: """Convert policy type to module name format.""" - return policy_type.replace("-", "_") - + # TODO(Steven): Deal with this + return policy_type.replace("lerobot/", "").replace("-", "_").replace("_", "").replace("pusht", "") class _LazyPolicyConfigMapping(OrderedDict): """ @@ -49,6 +49,7 @@ class _LazyPolicyConfigMapping(OrderedDict): # Try standard import path first try: if key not in self._modules: + print("Importing CONFIG: ",module_name) self._modules[key] = importlib.import_module( f"lerobot.common.policies.{module_name}.configuration_{module_name}" ) @@ -60,6 +61,7 @@ class _LazyPolicyConfigMapping(OrderedDict): f"lerobot.common.policies.{module_name}", ]: try: + print("Importing CONFIG: ",module_name) self._modules[key] = importlib.import_module(import_path) if hasattr(self._modules[key], value): return getattr(self._modules[key], value) @@ -94,7 +96,8 @@ class _LazyPolicyConfigMapping(OrderedDict): POLICY_CONFIG_NAMES_MAPPING = OrderedDict( [ - ("act", "ACTConfig"), + ("vqbet", "VQBeTConfig"), + ("lerobot/vqbet_pusht", "VQBeTConfig"), ] ) @@ -111,16 +114,20 @@ class _LazyPolicyMapping(OrderedDict): self._extra_content = {} self._modules = {} self._config_mapping = {} # Maps config classes to policy classes + self._initialized_types = set() # Track which types have been initialized - # Automatically set up mappings for built-in policies using POLICY_CONFIG_MAPPING - for policy_type in self._mapping: + def _lazy_init_for_type(self, policy_type: str) -> None: + """Lazily initialize mappings for a policy type if not already done.""" + if policy_type not in self._initialized_types: try: config_class = POLICY_CONFIG_MAPPING[policy_type] self._config_mapping[config_class] = self[policy_type] + self._initialized_types.add(policy_type) except (ImportError, AttributeError, KeyError) as e: import logging - - logging.warning(f"Could not automatically map config for policy type {policy_type}: {str(e)}") + logging.warning( + f"Could not automatically map config for policy type {policy_type}: {str(e)}" + ) def __getitem__(self, key): if key in self._extra_content: @@ -133,6 +140,7 @@ class _LazyPolicyMapping(OrderedDict): try: if key not in self._modules: + print("Importing POLICY: ", module_name) self._modules[key] = importlib.import_module( f"lerobot.common.policies.{module_name}.modeling_{module_name}" ) @@ -143,6 +151,7 @@ class _LazyPolicyMapping(OrderedDict): f"lerobot.common.policies.{module_name}", ]: try: + print("Importing POLICY: ",module_name) self._modules[key] = importlib.import_module(import_path) if hasattr(self._modules[key], value): return getattr(self._modules[key], value) @@ -164,22 +173,38 @@ class _LazyPolicyMapping(OrderedDict): self._extra_content[key] = value self._config_mapping[config_class] = value - def get_policy_for_config(self, config_class: Type[PreTrainedConfig]) -> Type[PreTrainedPolicy]: + def get_policy_for_config(self, config_class: PreTrainedConfig) -> Type[PreTrainedPolicy]: """Get the policy class associated with a config class.""" - if config_class in self._config_mapping: - return self._config_mapping[config_class] + # First check direct config class mapping + if type(config_class) in self._config_mapping: + return self._config_mapping[type(config_class)] # Try to find by policy type - policy_type = config_class.type - if policy_type in self: - return self[policy_type] + try: + policy_type = config_class.type + # Check extra content first + if policy_type in self._extra_content: + return self._extra_content[policy_type] + + # Then check standard mapping + if policy_type in self._mapping: + self._lazy_init_for_type(policy_type) + if config_class in self._config_mapping: + return self._config_mapping[config_class] + return self[policy_type] + except AttributeError: + pass - raise ValueError(f"No policy class found for config class {config_class.__name__}") + raise ValueError( + f"No policy class found for config class {config_class.__name__}. " + f"Available types: {list(self._mapping.keys()) + list(self._extra_content.keys())}" + ) POLICY_NAMES_MAPPING = OrderedDict( [ - ("act", "ACTPolicy"), + ("vqbet", "VQBeTPolicy"), + ("lerobot/vqbet_pusht", "VQBeTPolicy"), ] ) @@ -221,10 +246,11 @@ class AutoPolicyConfig: policy_type (`str`): The policy type like "act" or "pi0". config ([`PreTrainedConfig`]): The config to register. """ - # if issubclass(config, PreTrainedConfig) and config.policy_type != policy_type: + # TODO(Steven): config.type doesn't work at this stage because it is not an instance, it the class definition + # if issubclass(config, PreTrainedConfig) and config.type != policy_type: # raise ValueError( # "The config you are passing has a `policy_type` attribute that is not consistent with the policy type " - # f"you passed (config has {config.policy_type} and you passed {policy_type}. Fix one of those so they " + # f"you passed (config has {config.type} and you passed {policy_type}. Fix one of those so they " # "match!" # ) POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok) @@ -290,7 +316,7 @@ class AutoPolicy: @classmethod def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy: """Instantiate a policy from a configuration.""" - policy_class = POLICY_MAPPING.get_policy_for_config(config.__class__) + policy_class = POLICY_MAPPING.get_policy_for_config(config) return policy_class(config, **kwargs) @classmethod @@ -315,7 +341,7 @@ class AutoPolicy: if isinstance(config, str): config = AutoPolicyConfig.from_pretrained(config) - policy_class = POLICY_MAPPING.get_policy_for_config(config.__class__) + policy_class = POLICY_MAPPING.get_policy_for_config(config) return policy_class.from_pretrained(pretrained_policy_name_or_path, config=config, **kwargs) @staticmethod @@ -334,30 +360,32 @@ class AutoPolicy: def main(): - #TODO: Pass the needed arguments to the policies # Simulates a build-in policy type being loaded # Built-in policies work without explicit registration - my_config = AutoPolicyConfig.for_policy("act") - # my_policy = AutoPolicy.from_config(my_config) - from lerobot.common.policies.act.configuration_act import ACTConfig + # config = AutoPolicyConfig.for_policy("vqbet") + config = AutoPolicyConfig.from_pretrained("lerobot/vqbet_pusht") + policy = AutoPolicy.from_config(config) - assert isinstance(my_config, ACTConfig) - # assert isinstance(my_policy, ACTPolicy) + from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig + from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy + + assert isinstance(config, VQBeTConfig) + assert isinstance(policy, VQBeTPolicy) # Simulates a new policy type being registered # Only new policies need registration - from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig - from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy + from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig + from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy - AutoPolicyConfig.register("diffusion", DiffusionConfig) - AutoPolicy.register(DiffusionConfig, DiffusionPolicy) + AutoPolicyConfig.register("tdmpc", TDMPCConfig) + AutoPolicy.register(TDMPCConfig, TDMPCPolicy) - my_new_config = AutoPolicyConfig.for_policy("diffusion") - # my_new_policy = AutoPolicy.from_config(my_new_config) - assert isinstance(my_new_config, DiffusionConfig) - # assert isinstance(my_new_policy, DiffusionPolicy) + my_new_config = AutoPolicyConfig.for_policy("tdmpc") + my_new_policy = AutoPolicy.from_config(my_new_config) + assert isinstance(my_new_config, TDMPCConfig) + assert isinstance(my_new_policy, TDMPCPolicy) if __name__ == "__main__":