forked from tangger/lerobot
chore(autopolicyconfig): test main + format
This commit is contained in:
@@ -12,11 +12,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
import importlib
|
import importlib
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
import os
|
||||||
from typing import Type, Union, Dict, Optional
|
from collections import OrderedDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
|
||||||
POLICY_CONFIG_NAMES_MAPPING = OrderedDict(
|
POLICY_CONFIG_NAMES_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
@@ -24,31 +26,10 @@ POLICY_CONFIG_NAMES_MAPPING = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
POLICY_NAMES_MAPPING = OrderedDict(
|
|
||||||
[
|
|
||||||
("act", "ACTPolicy"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def model_type_to_module_name(model_type: str) -> str:
|
def policy_type_to_module_name(policy_type: str) -> str:
|
||||||
"""Convert model type to module name format."""
|
"""Convert policy type to module name format."""
|
||||||
return model_type.replace("-", "_")
|
return policy_type.replace("-", "_")
|
||||||
|
|
||||||
def find_policy_type_from_config(config_dict: dict) -> str:
|
|
||||||
"""Find the policy type from a config dictionary."""
|
|
||||||
if "type" in config_dict:
|
|
||||||
return config_dict["type"]
|
|
||||||
|
|
||||||
# Fallback: try to infer from class name
|
|
||||||
if "policy_class" in config_dict:
|
|
||||||
for policy_type, class_name in POLICY_CONFIG_NAMES_MAPPING.items():
|
|
||||||
if class_name in config_dict["policy_class"]:
|
|
||||||
return policy_type
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
"Could not determine policy type from config. "
|
|
||||||
"Config must contain either 'type' or 'policy_class' field."
|
|
||||||
)
|
|
||||||
|
|
||||||
class _LazyPolicyConfigMapping(OrderedDict):
|
class _LazyPolicyConfigMapping(OrderedDict):
|
||||||
"""
|
"""
|
||||||
@@ -65,10 +46,10 @@ class _LazyPolicyConfigMapping(OrderedDict):
|
|||||||
return self._extra_content[key]
|
return self._extra_content[key]
|
||||||
if key not in self._mapping:
|
if key not in self._mapping:
|
||||||
raise KeyError(key)
|
raise KeyError(key)
|
||||||
|
|
||||||
value = self._mapping[key]
|
value = self._mapping[key]
|
||||||
module_name = model_type_to_module_name(key)
|
module_name = policy_type_to_module_name(key)
|
||||||
|
|
||||||
# Try standard import path first
|
# Try standard import path first
|
||||||
try:
|
try:
|
||||||
if key not in self._modules:
|
if key not in self._modules:
|
||||||
@@ -88,19 +69,17 @@ class _LazyPolicyConfigMapping(OrderedDict):
|
|||||||
return getattr(self._modules[key], value)
|
return getattr(self._modules[key], value)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
raise ImportError(
|
raise ImportError(f"Could not find configuration class {value} for policy type {key}")
|
||||||
f"Could not find configuration class {value} for policy type {key}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return list(self._mapping.keys()) + list(self._extra_content.keys())
|
return list(self._mapping.keys()) + list(self._extra_content.keys())
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
|
return [self[k] for k in self._mapping] + list(self._extra_content.values())
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
|
return [(k, self[k]) for k in self._mapping] + list(self._extra_content.items())
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
|
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||||
@@ -112,35 +91,32 @@ class _LazyPolicyConfigMapping(OrderedDict):
|
|||||||
"""
|
"""
|
||||||
Register a new configuration in this mapping.
|
Register a new configuration in this mapping.
|
||||||
"""
|
"""
|
||||||
if key in self._mapping.keys() and not exist_ok:
|
if key in self._mapping and not exist_ok:
|
||||||
raise ValueError(f"'{key}' is already used by a Policy Config, pick another name.")
|
raise ValueError(f"'{key}' is already used by a Policy Config, pick another name.")
|
||||||
self._extra_content[key] = value
|
self._extra_content[key] = value
|
||||||
|
|
||||||
|
|
||||||
POLICY_CONFIG_MAPPING = _LazyPolicyConfigMapping(POLICY_CONFIG_NAMES_MAPPING)
|
POLICY_CONFIG_MAPPING = _LazyPolicyConfigMapping(POLICY_CONFIG_NAMES_MAPPING)
|
||||||
|
|
||||||
|
|
||||||
class AutoPolicyConfig:
|
class AutoPolicyConfig:
|
||||||
"""
|
"""
|
||||||
Factory class for automatically loading policy configurations.
|
Factory class for automatically loading policy configurations.
|
||||||
|
|
||||||
This class provides methods to:
|
This class provides methods to:
|
||||||
- Load pre-trained policy configurations from local files or the Hub
|
- Load pre-trained policy configurations from local files or the Hub
|
||||||
- Register new policy types dynamically
|
- Register new policy types dynamically
|
||||||
- Create policy configurations for specific policy types
|
- Create policy configurations for specific policy types
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
raise EnvironmentError(
|
raise OSError(
|
||||||
"AutoPolicyConfig is designed to be instantiated "
|
"AutoPolicyConfig is designed to be instantiated "
|
||||||
"using the `AutoPolicyConfig.from_pretrained(TODO)` method."
|
"using the `AutoPolicyConfig.from_pretrained(TODO)` method."
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_policy(
|
def for_policy(cls, policy_type: str, *args, **kwargs) -> PreTrainedConfig:
|
||||||
cls,
|
|
||||||
policy_type: str,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
) -> PreTrainedConfig:
|
|
||||||
"""Create a new configuration instance for the specified policy type."""
|
"""Create a new configuration instance for the specified policy type."""
|
||||||
if policy_type in POLICY_CONFIG_MAPPING:
|
if policy_type in POLICY_CONFIG_MAPPING:
|
||||||
config_class = POLICY_CONFIG_MAPPING[policy_type]
|
config_class = POLICY_CONFIG_MAPPING[policy_type]
|
||||||
@@ -148,7 +124,7 @@ class AutoPolicyConfig:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized policy identifier: {policy_type}. Should contain one of {', '.join(POLICY_CONFIG_MAPPING.keys())}"
|
f"Unrecognized policy identifier: {policy_type}. Should contain one of {', '.join(POLICY_CONFIG_MAPPING.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def register(policy_type, config, exist_ok=False):
|
def register(policy_type, config, exist_ok=False):
|
||||||
"""
|
"""
|
||||||
@@ -166,27 +142,9 @@ class AutoPolicyConfig:
|
|||||||
)
|
)
|
||||||
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
|
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def register_policy_type(policy_type: str, config_class: str, policy_class: str) -> None:
|
|
||||||
"""
|
|
||||||
Register a new policy type with its configuration and policy classes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
policy_type (`str`): The policy type identifier (e.g., 'act')
|
|
||||||
config_class (`str`): Name of the configuration class
|
|
||||||
policy_class (`str`): Name of the policy class
|
|
||||||
"""
|
|
||||||
if policy_type in POLICY_CONFIG_NAMES_MAPPING:
|
|
||||||
raise ValueError(f"Policy type {policy_type} is already registered")
|
|
||||||
|
|
||||||
POLICY_CONFIG_NAMES_MAPPING[policy_type] = config_class
|
|
||||||
POLICY_NAMES_MAPPING[policy_type] = policy_class
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls, pretrained_policy_config_name_or_path: Union[str, Path], **kwargs
|
||||||
pretrained_policy_config_name_or_path: Union[str, Path],
|
|
||||||
**kwargs
|
|
||||||
) -> PreTrainedConfig:
|
) -> PreTrainedConfig:
|
||||||
"""
|
"""
|
||||||
Instantiate a PreTrainedConfig from a pre-trained policy configuration.
|
Instantiate a PreTrainedConfig from a pre-trained policy configuration.
|
||||||
@@ -194,9 +152,9 @@ class AutoPolicyConfig:
|
|||||||
Args:
|
Args:
|
||||||
pretrained_policy_config_name_or_path (`str` or `Path`):
|
pretrained_policy_config_name_or_path (`str` or `Path`):
|
||||||
Can be either:
|
Can be either:
|
||||||
- A string with the `policy_type` of a pre-trained policy configuration listed on
|
- A string with the `policy_type` of a pre-trained policy configuration listed on
|
||||||
the Hub or locally (e.g., 'act')
|
the Hub or locally (e.g., 'act')
|
||||||
- A path to a `directory` containing a configuration file saved
|
- A path to a `directory` containing a configuration file saved
|
||||||
using [`~PreTrainedConfig.save_pretrained`].
|
using [`~PreTrainedConfig.save_pretrained`].
|
||||||
- A path or url to a saved configuration JSON `file`.
|
- A path or url to a saved configuration JSON `file`.
|
||||||
**kwargs: Additional kwargs passed to PreTrainedConfig.from_pretrained()
|
**kwargs: Additional kwargs passed to PreTrainedConfig.from_pretrained()
|
||||||
@@ -215,13 +173,35 @@ class AutoPolicyConfig:
|
|||||||
else:
|
else:
|
||||||
# Assume it's a policy_type identifier
|
# Assume it's a policy_type identifier
|
||||||
policy_type = pretrained_policy_config_name_or_path
|
policy_type = pretrained_policy_config_name_or_path
|
||||||
|
|
||||||
if policy_type not in POLICY_CONFIG_MAPPING:
|
if policy_type not in POLICY_CONFIG_MAPPING:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized policy type {policy_type}. "
|
f"Unrecognized policy type {policy_type}. "
|
||||||
f"Should be one of {', '.join(POLICY_CONFIG_MAPPING.keys())}"
|
f"Should be one of {', '.join(POLICY_CONFIG_MAPPING.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
config_class = POLICY_CONFIG_MAPPING[policy_type]
|
config_class = POLICY_CONFIG_MAPPING[policy_type]
|
||||||
return config_class.from_pretrained(pretrained_policy_config_name_or_path, **kwargs)
|
return config_class.from_pretrained(pretrained_policy_config_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Simulates a standard policy type being loaded
|
||||||
|
my_config = AutoPolicyConfig.for_policy("act")
|
||||||
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
|
assert isinstance(my_config,ACTConfig)
|
||||||
|
# my_policy = AutoPolicy.from_config(my_config)
|
||||||
|
# from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||||
|
# assert isinstance(my_policy,ACTPolicy)
|
||||||
|
|
||||||
|
# Simulates a new policy type being registered
|
||||||
|
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||||
|
AutoPolicyConfig.register("pi0", PI0Config)
|
||||||
|
my_new_config = AutoPolicyConfig.for_policy("pi0")
|
||||||
|
assert isinstance(my_new_config,PI0Config)
|
||||||
|
# from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||||
|
# AutoPolicy.register(PI0Config,PI0Policy)
|
||||||
|
# my_new_policy = AutoPolicy.from_config(my_new_config)
|
||||||
|
# assert isinstance(my_new_policy,PI0Policy)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user