Parse draccus subclass overrides when using --policy.path (#1501)

* Parse draccus subclass overrides when using --policy.path

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Ben Zhang
2025-07-15 03:29:07 -07:00
committed by GitHub
parent c4c0105a47
commit 1c0ac8e341

View File

@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import json
import logging
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type, TypeVar
@@ -183,8 +185,22 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
# HACK: Parse the original config to get the config subclass, so that we can
# apply cli overrides.
# This is very ugly, ideally we'd like to be able to do that natively with draccus
# something like --policy.path (in addition to --policy.type)
cli_overrides = policy_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_overrides)
orig_config = draccus.parse(cls, config_file, args=[])
with open(config_file) as f:
config = json.load(f)
config.pop("type")
with tempfile.NamedTemporaryFile("w+") as f:
json.dump(config, f)
config_file = f.name
f.flush()
cli_overrides = policy_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)