[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by AdilZouitine
parent 761a2dbcb3
commit 8e6d5f504c
97 changed files with 1596 additions and 492 deletions

View File

@@ -29,7 +29,9 @@ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
draccus.set_config_type("json")
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
def get_cli_overrides(
field_name: str, args: Sequence[str] | None = None
) -> list[str] | None:
"""Parses arguments from cli at a given nested attribute level.
For example, supposing the main script was called with:
@@ -42,7 +44,10 @@ def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> lis
args = sys.argv[1:]
attr_level_args = []
detect_string = f"--{field_name}."
exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
exclude_strings = (
f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
f"--{field_name}.{PATH_KEY}=",
)
for arg in args:
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
denested_arg = f"--{arg.removeprefix(detect_string)}"
@@ -153,7 +158,9 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
def filter_path_args(
fields_to_filter: str | list[str], args: Sequence[str] | None = None
) -> list[str]:
"""
Filters command-line arguments related to fields with specific path arguments.
@@ -181,7 +188,9 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
argument=None,
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
)
filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
filtered_args = [
arg for arg in filtered_args if not arg.startswith(f"--{field}.")
]
return filtered_args
@@ -213,7 +222,9 @@ def wrap(config_path: Path | None = None):
load_plugin(plugin_path)
except PluginLoadError as e:
# add the relevant CLI arg to the error message
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
raise PluginLoadError(
f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}"
) from e
cli_args = filter_arg(plugin_cli_arg, cli_args)
config_path_cli = parse_arg("config_path", cli_args)
if has_method(argtype, "__get_path_fields__"):
@@ -223,7 +234,9 @@ def wrap(config_path: Path | None = None):
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else:
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
cfg = draccus.parse(
config_class=argtype, config_path=config_path, args=cli_args
)
response = fn(cfg, *args, **kwargs)
return response