diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 36e6ea2e..05f3296b 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import json import logging import os +import tempfile from dataclasses import dataclass, field from pathlib import Path from typing import Type, TypeVar @@ -183,8 +185,22 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}" ) from e - # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus + # HACK: Parse the original config to get the config subclass, so that we can + # apply cli overrides. + # This is very ugly, ideally we'd like to be able to do that natively with draccus # something like --policy.path (in addition to --policy.type) - cli_overrides = policy_kwargs.pop("cli_overrides", []) with draccus.config_type("json"): - return draccus.parse(cls, config_file, args=cli_overrides) + orig_config = draccus.parse(cls, config_file, args=[]) + + with open(config_file) as f: + config = json.load(f) + + config.pop("type") + with tempfile.NamedTemporaryFile("w+") as f: + json.dump(config, f) + config_file = f.name + f.flush() + + cli_overrides = policy_kwargs.pop("cli_overrides", []) + with draccus.config_type("json"): + return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)