chore(autopolicyconfig): test main + format

This commit is contained in:
Steven Palma
2025-03-05 21:52:39 +01:00
parent 85099f45f4
commit 78df84f758

View File

@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import OrderedDict
import importlib import importlib
from lerobot.configs.policies import PreTrainedConfig import os
from typing import Type, Union, Dict, Optional from collections import OrderedDict
from pathlib import Path from pathlib import Path
from typing import Union
from lerobot.configs.policies import PreTrainedConfig
POLICY_CONFIG_NAMES_MAPPING = OrderedDict( POLICY_CONFIG_NAMES_MAPPING = OrderedDict(
[ [
@@ -24,31 +26,10 @@ POLICY_CONFIG_NAMES_MAPPING = OrderedDict(
] ]
) )
POLICY_NAMES_MAPPING = OrderedDict(
[
("act", "ACTPolicy"),
]
)
def model_type_to_module_name(model_type: str) -> str: def policy_type_to_module_name(policy_type: str) -> str:
"""Convert model type to module name format.""" """Convert policy type to module name format."""
return model_type.replace("-", "_") return policy_type.replace("-", "_")
def find_policy_type_from_config(config_dict: dict) -> str:
"""Find the policy type from a config dictionary."""
if "type" in config_dict:
return config_dict["type"]
# Fallback: try to infer from class name
if "policy_class" in config_dict:
for policy_type, class_name in POLICY_CONFIG_NAMES_MAPPING.items():
if class_name in config_dict["policy_class"]:
return policy_type
raise ValueError(
"Could not determine policy type from config. "
"Config must contain either 'type' or 'policy_class' field."
)
class _LazyPolicyConfigMapping(OrderedDict): class _LazyPolicyConfigMapping(OrderedDict):
""" """
@@ -67,7 +48,7 @@ class _LazyPolicyConfigMapping(OrderedDict):
raise KeyError(key) raise KeyError(key)
value = self._mapping[key] value = self._mapping[key]
module_name = model_type_to_module_name(key) module_name = policy_type_to_module_name(key)
# Try standard import path first # Try standard import path first
try: try:
@@ -89,18 +70,16 @@ class _LazyPolicyConfigMapping(OrderedDict):
except ImportError: except ImportError:
continue continue
raise ImportError( raise ImportError(f"Could not find configuration class {value} for policy type {key}")
f"Could not find configuration class {value} for policy type {key}"
)
def keys(self): def keys(self):
return list(self._mapping.keys()) + list(self._extra_content.keys()) return list(self._mapping.keys()) + list(self._extra_content.keys())
def values(self): def values(self):
return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()) return [self[k] for k in self._mapping] + list(self._extra_content.values())
def items(self): def items(self):
return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()) return [(k, self[k]) for k in self._mapping] + list(self._extra_content.items())
def __iter__(self): def __iter__(self):
return iter(list(self._mapping.keys()) + list(self._extra_content.keys())) return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
@@ -112,12 +91,14 @@ class _LazyPolicyConfigMapping(OrderedDict):
""" """
Register a new configuration in this mapping. Register a new configuration in this mapping.
""" """
if key in self._mapping.keys() and not exist_ok: if key in self._mapping and not exist_ok:
raise ValueError(f"'{key}' is already used by a Policy Config, pick another name.") raise ValueError(f"'{key}' is already used by a Policy Config, pick another name.")
self._extra_content[key] = value self._extra_content[key] = value
POLICY_CONFIG_MAPPING = _LazyPolicyConfigMapping(POLICY_CONFIG_NAMES_MAPPING) POLICY_CONFIG_MAPPING = _LazyPolicyConfigMapping(POLICY_CONFIG_NAMES_MAPPING)
class AutoPolicyConfig: class AutoPolicyConfig:
""" """
Factory class for automatically loading policy configurations. Factory class for automatically loading policy configurations.
@@ -129,18 +110,13 @@ class AutoPolicyConfig:
""" """
def __init__(self): def __init__(self):
raise EnvironmentError( raise OSError(
"AutoPolicyConfig is designed to be instantiated " "AutoPolicyConfig is designed to be instantiated "
"using the `AutoPolicyConfig.from_pretrained(TODO)` method." "using the `AutoPolicyConfig.from_pretrained(TODO)` method."
) )
@classmethod @classmethod
def for_policy( def for_policy(cls, policy_type: str, *args, **kwargs) -> PreTrainedConfig:
cls,
policy_type: str,
*args,
**kwargs
) -> PreTrainedConfig:
"""Create a new configuration instance for the specified policy type.""" """Create a new configuration instance for the specified policy type."""
if policy_type in POLICY_CONFIG_MAPPING: if policy_type in POLICY_CONFIG_MAPPING:
config_class = POLICY_CONFIG_MAPPING[policy_type] config_class = POLICY_CONFIG_MAPPING[policy_type]
@@ -166,27 +142,9 @@ class AutoPolicyConfig:
) )
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok) POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
@staticmethod
def register_policy_type(policy_type: str, config_class: str, policy_class: str) -> None:
"""
Register a new policy type with its configuration and policy classes.
Args:
policy_type (`str`): The policy type identifier (e.g., 'act')
config_class (`str`): Name of the configuration class
policy_class (`str`): Name of the policy class
"""
if policy_type in POLICY_CONFIG_NAMES_MAPPING:
raise ValueError(f"Policy type {policy_type} is already registered")
POLICY_CONFIG_NAMES_MAPPING[policy_type] = config_class
POLICY_NAMES_MAPPING[policy_type] = policy_class
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls, pretrained_policy_config_name_or_path: Union[str, Path], **kwargs
pretrained_policy_config_name_or_path: Union[str, Path],
**kwargs
) -> PreTrainedConfig: ) -> PreTrainedConfig:
""" """
Instantiate a PreTrainedConfig from a pre-trained policy configuration. Instantiate a PreTrainedConfig from a pre-trained policy configuration.
@@ -225,3 +183,25 @@ class AutoPolicyConfig:
config_class = POLICY_CONFIG_MAPPING[policy_type] config_class = POLICY_CONFIG_MAPPING[policy_type]
return config_class.from_pretrained(pretrained_policy_config_name_or_path, **kwargs) return config_class.from_pretrained(pretrained_policy_config_name_or_path, **kwargs)
def main():
# Simulates a standard policy type being loaded
my_config = AutoPolicyConfig.for_policy("act")
from lerobot.common.policies.act.configuration_act import ACTConfig
assert isinstance(my_config,ACTConfig)
# my_policy = AutoPolicy.from_config(my_config)
# from lerobot.common.policies.act.modeling_act import ACTPolicy
# assert isinstance(my_policy,ACTPolicy)
# Simulates a new policy type being registered
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
AutoPolicyConfig.register("pi0", PI0Config)
my_new_config = AutoPolicyConfig.for_policy("pi0")
assert isinstance(my_new_config,PI0Config)
# from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
# AutoPolicy.register(PI0Config,PI0Policy)
# my_new_policy = AutoPolicy.from_config(my_new_config)
# assert isinstance(my_new_policy,PI0Policy)
if __name__ == "__main__":
main()