From 1c0ac8e3415adff7411846c73c6e9dfb94941eb1 Mon Sep 17 00:00:00 2001 From: Ben Zhang <5977478+ben-z@users.noreply.github.com> Date: Tue, 15 Jul 2025 03:29:07 -0700 Subject: [PATCH] Parse draccus subclass overrides when using `--policy.path` (#1501) * Parse draccus subclass overrides when using --policy.path * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- src/lerobot/configs/policies.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 36e6ea2e..05f3296b 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import json import logging import os +import tempfile from dataclasses import dataclass, field from pathlib import Path from typing import Type, TypeVar @@ -183,8 +185,22 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}" ) from e - # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus + # HACK: Parse the original config to get the config subclass, so that we can + # apply cli overrides. + # This is very ugly, ideally we'd like to be able to do that natively with draccus # something like --policy.path (in addition to --policy.type) - cli_overrides = policy_kwargs.pop("cli_overrides", []) with draccus.config_type("json"): - return draccus.parse(cls, config_file, args=cli_overrides) + orig_config = draccus.parse(cls, config_file, args=[]) + + with open(config_file) as f: + config = json.load(f) + + config.pop("type") + with tempfile.NamedTemporaryFile("w+") as f: + json.dump(config, f) + config_file = f.name + f.flush() + + cli_overrides = policy_kwargs.pop("cli_overrides", []) + with draccus.config_type("json"): + return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)