Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co> Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
125
lerobot/configs/parser.py
Normal file
125
lerobot/configs/parser.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import inspect
|
||||
import sys
|
||||
from argparse import ArgumentError
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.utils.utils import has_method
|
||||
|
||||
PATH_KEY = "path"
|
||||
draccus.set_config_type("json")
|
||||
|
||||
|
||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||
"""Parses arguments from cli at a given nested attribute level.
|
||||
|
||||
For example, supposing the main script was called with:
|
||||
python myscript.py --arg1=1 --arg2.subarg1=abc --arg2.subarg2=some/path
|
||||
|
||||
If called during execution of myscript.py, get_cli_overrides("arg2") will return:
|
||||
["--subarg1=abc" "--subarg2=some/path"]
|
||||
"""
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
attr_level_args = []
|
||||
detect_string = f"--{field_name}."
|
||||
exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
|
||||
for arg in args:
|
||||
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
|
||||
denested_arg = f"--{arg.removeprefix(detect_string)}"
|
||||
attr_level_args.append(denested_arg)
|
||||
|
||||
return attr_level_args
|
||||
|
||||
|
||||
def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
prefix = f"--{arg_name}="
|
||||
for arg in args:
|
||||
if arg.startswith(prefix):
|
||||
return arg[len(prefix) :]
|
||||
return None
|
||||
|
||||
|
||||
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
|
||||
|
||||
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return parse_arg(f"{field_name}.{draccus.CHOICE_TYPE_KEY}", args)
|
||||
|
||||
|
||||
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
|
||||
"""
|
||||
Filters command-line arguments related to fields with specific path arguments.
|
||||
|
||||
Args:
|
||||
fields_to_filter (str | list[str]): A single str or a list of str whose arguments need to be filtered.
|
||||
args (Sequence[str] | None): The sequence of command-line arguments to be filtered.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
list[str]: A filtered list of arguments, with arguments related to the specified
|
||||
fields removed.
|
||||
|
||||
Raises:
|
||||
ArgumentError: If both a path argument (e.g., `--field_name.path`) and a type
|
||||
argument (e.g., `--field_name.type`) are specified for the same field.
|
||||
"""
|
||||
if isinstance(fields_to_filter, str):
|
||||
fields_to_filter = [fields_to_filter]
|
||||
|
||||
filtered_args = args
|
||||
for field in fields_to_filter:
|
||||
if get_path_arg(field, args):
|
||||
if get_type_arg(field, args):
|
||||
raise ArgumentError(
|
||||
argument=None,
|
||||
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
|
||||
)
|
||||
filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
|
||||
|
||||
return filtered_args
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None):
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does two additional things:
|
||||
- Will remove '.path' arguments from CLI in order to process them later on.
|
||||
- If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
|
||||
initialize it from there to allow to fetch configs from the hub directly
|
||||
"""
|
||||
|
||||
def wrapper_outer(fn):
|
||||
@wraps(fn)
|
||||
def wrapper_inner(*args, **kwargs):
|
||||
argspec = inspect.getfullargspec(fn)
|
||||
argtype = argspec.annotations[argspec.args[0]]
|
||||
if len(args) > 0 and type(args[0]) is argtype:
|
||||
cfg = args[0]
|
||||
args = args[1:]
|
||||
else:
|
||||
cli_args = sys.argv[1:]
|
||||
config_path_cli = parse_arg("config_path", cli_args)
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
path_fields = argtype.__get_path_fields__()
|
||||
cli_args = filter_path_args(path_fields, cli_args)
|
||||
if has_method(argtype, "from_pretrained") and config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
return wrapper_inner
|
||||
|
||||
return wrapper_outer
|
||||
Reference in New Issue
Block a user