Parse draccus subclass overrides when using --policy.path (#1501)

* Parse draccus subclass overrides when using --policy.path

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Ben Zhang
2025-07-15 03:29:07 -07:00
committed by GitHub
parent c4c0105a47
commit 1c0ac8e341

View File

@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc import abc
import json
import logging import logging
import os import os
import tempfile
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Type, TypeVar from typing import Type, TypeVar
@@ -183,8 +185,22 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}" f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e ) from e
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus # HACK: Parse the original config to get the config subclass, so that we can
# apply cli overrides.
# This is very ugly, ideally we'd like to be able to do that natively with draccus
# something like --policy.path (in addition to --policy.type) # something like --policy.path (in addition to --policy.type)
cli_overrides = policy_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"): with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_overrides) orig_config = draccus.parse(cls, config_file, args=[])
with open(config_file) as f:
config = json.load(f)
config.pop("type")
with tempfile.NamedTemporaryFile("w+") as f:
json.dump(config, f)
config_file = f.name
f.flush()
cli_overrides = policy_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)