Parse draccus subclass overrides when using --policy.path (#1501)
* Parse draccus subclass overrides when using --policy.path * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
@@ -12,8 +12,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Type, TypeVar
|
from typing import Type, TypeVar
|
||||||
@@ -183,8 +185,22 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
|
# HACK: Parse the original config to get the config subclass, so that we can
|
||||||
|
# apply cli overrides.
|
||||||
|
# This is very ugly, ideally we'd like to be able to do that natively with draccus
|
||||||
# something like --policy.path (in addition to --policy.type)
|
# something like --policy.path (in addition to --policy.type)
|
||||||
cli_overrides = policy_kwargs.pop("cli_overrides", [])
|
|
||||||
with draccus.config_type("json"):
|
with draccus.config_type("json"):
|
||||||
return draccus.parse(cls, config_file, args=cli_overrides)
|
orig_config = draccus.parse(cls, config_file, args=[])
|
||||||
|
|
||||||
|
with open(config_file) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
config.pop("type")
|
||||||
|
with tempfile.NamedTemporaryFile("w+") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
config_file = f.name
|
||||||
|
f.flush()
|
||||||
|
|
||||||
|
cli_overrides = policy_kwargs.pop("cli_overrides", [])
|
||||||
|
with draccus.config_type("json"):
|
||||||
|
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
|
||||||
|
|||||||
Reference in New Issue
Block a user