diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py index 022d1fb5..7caac957 100644 --- a/lerobot/configs/policies.py +++ b/lerobot/configs/policies.py @@ -34,7 +34,9 @@ T = TypeVar("T", bound="PreTrainedConfig") @dataclass -class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): +class PreTrainedConfig( + draccus.PluginRegistry, HubMixin, abc.ABC, discover_packages_path="lerobot.common.policies" +): """ Base configuration class for policy models. @@ -174,3 +176,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # something like --policy.path (in addition to --policy.type) cli_overrides = policy_kwargs.pop("cli_overrides", []) return draccus.parse(cls, config_file, args=cli_overrides) + + @classmethod + def register(cls, config_type: str, config: "PreTrainedConfig", exist_ok: bool = False): + """Register a new configuration for this class.""" + if config_type in cls._choice_registry and not exist_ok: + raise ValueError( + f"'{config_type}' is already used by a {cls.__name__}, please pick another name." + ) + + cls._choice_registry[config_type] = config