chore(autopolicyconfig): test main + format

This commit is contained in:
Steven Palma
2025-03-05 21:52:39 +01:00
parent 85099f45f4
commit 78df84f758

View File

@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import importlib
from lerobot.configs.policies import PreTrainedConfig
from typing import Type, Union, Dict, Optional
import os
from collections import OrderedDict
from pathlib import Path
from typing import Union
from lerobot.configs.policies import PreTrainedConfig
POLICY_CONFIG_NAMES_MAPPING = OrderedDict(
[
@@ -24,31 +26,10 @@ POLICY_CONFIG_NAMES_MAPPING = OrderedDict(
]
)
POLICY_NAMES_MAPPING = OrderedDict(
[
("act", "ACTPolicy"),
]
)
def model_type_to_module_name(model_type: str) -> str:
"""Convert model type to module name format."""
return model_type.replace("-", "_")
def find_policy_type_from_config(config_dict: dict) -> str:
"""Find the policy type from a config dictionary."""
if "type" in config_dict:
return config_dict["type"]
# Fallback: try to infer from class name
if "policy_class" in config_dict:
for policy_type, class_name in POLICY_CONFIG_NAMES_MAPPING.items():
if class_name in config_dict["policy_class"]:
return policy_type
raise ValueError(
"Could not determine policy type from config. "
"Config must contain either 'type' or 'policy_class' field."
)
def policy_type_to_module_name(policy_type: str) -> str:
"""Convert policy type to module name format."""
return policy_type.replace("-", "_")
class _LazyPolicyConfigMapping(OrderedDict):
"""
@@ -65,10 +46,10 @@ class _LazyPolicyConfigMapping(OrderedDict):
return self._extra_content[key]
if key not in self._mapping:
raise KeyError(key)
value = self._mapping[key]
module_name = model_type_to_module_name(key)
module_name = policy_type_to_module_name(key)
# Try standard import path first
try:
if key not in self._modules:
@@ -88,19 +69,17 @@ class _LazyPolicyConfigMapping(OrderedDict):
return getattr(self._modules[key], value)
except ImportError:
continue
raise ImportError(
f"Could not find configuration class {value} for policy type {key}"
)
raise ImportError(f"Could not find configuration class {value} for policy type {key}")
def keys(self):
return list(self._mapping.keys()) + list(self._extra_content.keys())
def values(self):
return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
return [self[k] for k in self._mapping] + list(self._extra_content.values())
def items(self):
return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
return [(k, self[k]) for k in self._mapping] + list(self._extra_content.items())
def __iter__(self):
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
@@ -112,35 +91,32 @@ class _LazyPolicyConfigMapping(OrderedDict):
"""
Register a new configuration in this mapping.
"""
if key in self._mapping.keys() and not exist_ok:
if key in self._mapping and not exist_ok:
raise ValueError(f"'{key}' is already used by a Policy Config, pick another name.")
self._extra_content[key] = value
POLICY_CONFIG_MAPPING = _LazyPolicyConfigMapping(POLICY_CONFIG_NAMES_MAPPING)
class AutoPolicyConfig:
"""
Factory class for automatically loading policy configurations.
This class provides methods to:
- Load pre-trained policy configurations from local files or the Hub
- Register new policy types dynamically
- Create policy configurations for specific policy types
"""
def __init__(self):
raise EnvironmentError(
raise OSError(
"AutoPolicyConfig is designed to be instantiated "
"using the `AutoPolicyConfig.from_pretrained(TODO)` method."
)
@classmethod
def for_policy(
cls,
policy_type: str,
*args,
**kwargs
) -> PreTrainedConfig:
def for_policy(cls, policy_type: str, *args, **kwargs) -> PreTrainedConfig:
"""Create a new configuration instance for the specified policy type."""
if policy_type in POLICY_CONFIG_MAPPING:
config_class = POLICY_CONFIG_MAPPING[policy_type]
@@ -148,7 +124,7 @@ class AutoPolicyConfig:
raise ValueError(
f"Unrecognized policy identifier: {policy_type}. Should contain one of {', '.join(POLICY_CONFIG_MAPPING.keys())}"
)
@staticmethod
def register(policy_type, config, exist_ok=False):
"""
@@ -166,27 +142,9 @@ class AutoPolicyConfig:
)
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
@staticmethod
def register_policy_type(policy_type: str, config_class: str, policy_class: str) -> None:
"""
Register a new policy type with its configuration and policy classes.
Args:
policy_type (`str`): The policy type identifier (e.g., 'act')
config_class (`str`): Name of the configuration class
policy_class (`str`): Name of the policy class
"""
if policy_type in POLICY_CONFIG_NAMES_MAPPING:
raise ValueError(f"Policy type {policy_type} is already registered")
POLICY_CONFIG_NAMES_MAPPING[policy_type] = config_class
POLICY_NAMES_MAPPING[policy_type] = policy_class
@classmethod
def from_pretrained(
cls,
pretrained_policy_config_name_or_path: Union[str, Path],
**kwargs
cls, pretrained_policy_config_name_or_path: Union[str, Path], **kwargs
) -> PreTrainedConfig:
"""
Instantiate a PreTrainedConfig from a pre-trained policy configuration.
@@ -194,9 +152,9 @@ class AutoPolicyConfig:
Args:
pretrained_policy_config_name_or_path (`str` or `Path`):
Can be either:
- A string with the `policy_type` of a pre-trained policy configuration listed on
- A string with the `policy_type` of a pre-trained policy configuration listed on
the Hub or locally (e.g., 'act')
- A path to a `directory` containing a configuration file saved
- A path to a `directory` containing a configuration file saved
using [`~PreTrainedConfig.save_pretrained`].
- A path or url to a saved configuration JSON `file`.
**kwargs: Additional kwargs passed to PreTrainedConfig.from_pretrained()
@@ -215,13 +173,35 @@ class AutoPolicyConfig:
else:
# Assume it's a policy_type identifier
policy_type = pretrained_policy_config_name_or_path
if policy_type not in POLICY_CONFIG_MAPPING:
raise ValueError(
f"Unrecognized policy type {policy_type}. "
f"Should be one of {', '.join(POLICY_CONFIG_MAPPING.keys())}"
)
config_class = POLICY_CONFIG_MAPPING[policy_type]
return config_class.from_pretrained(pretrained_policy_config_name_or_path, **kwargs)
def main():
# Simulates a standard policy type being loaded
my_config = AutoPolicyConfig.for_policy("act")
from lerobot.common.policies.act.configuration_act import ACTConfig
assert isinstance(my_config,ACTConfig)
# my_policy = AutoPolicy.from_config(my_config)
# from lerobot.common.policies.act.modeling_act import ACTPolicy
# assert isinstance(my_policy,ACTPolicy)
# Simulates a new policy type being registered
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
AutoPolicyConfig.register("pi0", PI0Config)
my_new_config = AutoPolicyConfig.for_policy("pi0")
assert isinstance(my_new_config,PI0Config)
# from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
# AutoPolicy.register(PI0Config,PI0Policy)
# my_new_policy = AutoPolicy.from_config(my_new_config)
# assert isinstance(my_new_policy,PI0Policy)
if __name__ == "__main__":
main()