[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
761a2dbcb3
commit
8e6d5f504c
@@ -29,7 +29,9 @@ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
draccus.set_config_type("json")
|
||||
|
||||
|
||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||
def get_cli_overrides(
|
||||
field_name: str, args: Sequence[str] | None = None
|
||||
) -> list[str] | None:
|
||||
"""Parses arguments from cli at a given nested attribute level.
|
||||
|
||||
For example, supposing the main script was called with:
|
||||
@@ -42,7 +44,10 @@ def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> lis
|
||||
args = sys.argv[1:]
|
||||
attr_level_args = []
|
||||
detect_string = f"--{field_name}."
|
||||
exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
|
||||
exclude_strings = (
|
||||
f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
|
||||
f"--{field_name}.{PATH_KEY}=",
|
||||
)
|
||||
for arg in args:
|
||||
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
|
||||
denested_arg = f"--{arg.removeprefix(detect_string)}"
|
||||
@@ -153,7 +158,9 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[
|
||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
|
||||
def filter_path_args(
|
||||
fields_to_filter: str | list[str], args: Sequence[str] | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Filters command-line arguments related to fields with specific path arguments.
|
||||
|
||||
@@ -181,7 +188,9 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
argument=None,
|
||||
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
|
||||
)
|
||||
filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
|
||||
filtered_args = [
|
||||
arg for arg in filtered_args if not arg.startswith(f"--{field}.")
|
||||
]
|
||||
|
||||
return filtered_args
|
||||
|
||||
@@ -213,7 +222,9 @@ def wrap(config_path: Path | None = None):
|
||||
load_plugin(plugin_path)
|
||||
except PluginLoadError as e:
|
||||
# add the relevant CLI arg to the error message
|
||||
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
|
||||
raise PluginLoadError(
|
||||
f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}"
|
||||
) from e
|
||||
cli_args = filter_arg(plugin_cli_arg, cli_args)
|
||||
config_path_cli = parse_arg("config_path", cli_args)
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
@@ -223,7 +234,9 @@ def wrap(config_path: Path | None = None):
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype, config_path=config_path, args=cli_args
|
||||
)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
Reference in New Issue
Block a user