diff --git a/pyproject.toml b/pyproject.toml
index 25039503..36746e79 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -289,9 +289,19 @@ ignore_errors = false
 # module = "lerobot.utils.*"
 # ignore_errors = false
 
-# [[tool.mypy.overrides]]
-# module = "lerobot.configs.*"
-# ignore_errors = false
+[[tool.mypy.overrides]]
+module = "lerobot.configs.*"
+ignore_errors = false
+
+# extra strictness for configs
+disallow_untyped_defs = true
+disallow_incomplete_defs = true
+check_untyped_defs = true
+
+# error codes to ignore
+disable_error_code = ["attr-defined"]  # TODO: draccus issue
+
+# include = "src/lerobot/configs/**/*.py"
 
 # [[tool.mypy.overrides]]
 # module = "lerobot.optim.*"
diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py
index afd644e1..630d63f1 100644
--- a/src/lerobot/configs/default.py
+++ b/src/lerobot/configs/default.py
@@ -57,7 +57,7 @@ class EvalConfig:
     # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
     use_async_envs: bool = False
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         if self.batch_size > self.n_episodes:
             raise ValueError(
                 "The eval batch size is greater than the number of eval episodes "
diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py
index cfe48cf8..e9e05a7e 100644
--- a/src/lerobot/configs/eval.py
+++ b/src/lerobot/configs/eval.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import datetime as dt
-import logging
 from dataclasses import dataclass, field
+from logging import getLogger
 from pathlib import Path
 
 from lerobot import envs, policies  # noqa: F401
@@ -22,6 +22,8 @@ from lerobot.configs import parser
 from lerobot.configs.default import EvalConfig
 from lerobot.configs.policies import PreTrainedConfig
 
+logger = getLogger(__name__)
+
 
 @dataclass
 class EvalPipelineConfig:
@@ -35,24 +37,28 @@ class EvalPipelineConfig:
     job_name: str | None = None
     seed: int | None = 1000
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # HACK: We parse again the cli args here to get the pretrained path if there was one.
         policy_path = parser.get_path_arg("policy")
         if policy_path:
             cli_overrides = parser.get_cli_overrides("policy")
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
-            self.policy.pretrained_path = policy_path
+            self.policy.pretrained_path = Path(policy_path)
         else:
-            logging.warning(
+            logger.warning(
                 "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
             )
 
         if not self.job_name:
             if self.env is None:
-                self.job_name = f"{self.policy.type}"
+                self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
             else:
-                self.job_name = f"{self.env.type}_{self.policy.type}"
+                self.job_name = (
+                    f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
+                )
+
+            logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
 
         if not self.output_dir:
             now = dt.datetime.now()
diff --git a/src/lerobot/configs/parser.py b/src/lerobot/configs/parser.py
index 2296eaa2..57ebaf8f 100644
--- a/src/lerobot/configs/parser.py
+++ b/src/lerobot/configs/parser.py
@@ -16,14 +16,19 @@
 import inspect
 import pkgutil
 import sys
 from argparse import ArgumentError
-from collections.abc import Sequence
+from collections.abc import Callable, Iterable, Sequence
 from functools import wraps
 from pathlib import Path
+from pkgutil import ModuleInfo
+from types import ModuleType
+from typing import Any, TypeVar, cast
 
 import draccus
 
 from lerobot.utils.utils import has_method
 
+F = TypeVar("F", bound=Callable[..., object])
+
 PATH_KEY = "path"
 PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
@@ -60,7 +65,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
     return None
 
 
-def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
+def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
     """Parse plugin-related arguments from command-line arguments.
 
     This function extracts arguments from command-line arguments that match a specified suffix pattern.
@@ -127,7 +132,7 @@ def load_plugin(plugin_path: str) -> None:
             f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
         ) from e
 
-    def iter_namespace(ns_pkg):
+    def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
         return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
 
     try:
@@ -148,6 +153,8 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
 
 
 def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
+    if args is None:
+        return []
     return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
 
 
@@ -171,7 +178,8 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
     if isinstance(fields_to_filter, str):
         fields_to_filter = [fields_to_filter]
 
-    filtered_args = args
+    filtered_args = [] if args is None else list(args)
+
     for field in fields_to_filter:
         if get_path_arg(field, args):
             if get_type_arg(field, args):
@@ -184,7 +192,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
     return filtered_args
 
 
-def wrap(config_path: Path | None = None):
+def wrap(config_path: Path | None = None) -> Callable[[F], F]:
     """
     HACK: Similar to draccus.wrap but does three additional things:
         - Will remove '.path' arguments from CLI in order to process them later on.
@@ -195,9 +203,9 @@ def wrap(config_path: Path | None = None):
             from the CLI '.type' arguments
     """
 
-    def wrapper_outer(fn):
+    def wrapper_outer(fn: F) -> F:
         @wraps(fn)
-        def wrapper_inner(*args, **kwargs):
+        def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
             argspec = inspect.getfullargspec(fn)
             argtype = argspec.annotations[argspec.args[0]]
             if len(args) > 0 and type(args[0]) is argtype:
@@ -225,6 +233,6 @@ def wrap(config_path: Path | None = None):
             response = fn(cfg, *args, **kwargs)
             return response
 
-        return wrapper_inner
+        return cast(F, wrapper_inner)
 
-    return wrapper_outer
+    return cast(Callable[[F], F], wrapper_outer)
diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py
index 98dd4df3..b1cc19a4 100644
--- a/src/lerobot/configs/policies.py
+++ b/src/lerobot/configs/policies.py
@@ -14,12 +14,12 @@
 import abc
 import builtins
 import json
-import logging
 import os
 import tempfile
 from dataclasses import dataclass, field
+from logging import getLogger
 from pathlib import Path
-from typing import TypeVar
+from typing import Any, TypeVar
 
 import draccus
 from huggingface_hub import hf_hub_download
@@ -34,10 +34,11 @@ from lerobot.utils.hub import HubMixin
 from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 
 T = TypeVar("T", bound="PreTrainedConfig")
+logger = getLogger(__name__)
 
 
 @dataclass
-class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
+class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):  # type: ignore[misc,name-defined]  # TODO: draccus issue
     """
     Base configuration class for policy models.
 
@@ -62,7 +63,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
     # automatic gradient scaling is used.
     use_amp: bool = False
 
-    push_to_hub: bool = True
+    push_to_hub: bool = True  # type: ignore[assignment]  # TODO: use a different name to avoid overriding the parent attribute
     repo_id: str | None = None
 
     # Upload on private repository on the Hugging Face hub.
@@ -73,38 +74,41 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
     license: str | None = None
     # Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
     # saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
-    pretrained_path: str | None = None
+    pretrained_path: Path | None = None
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         if not self.device or not is_torch_device_available(self.device):
             auto_device = auto_select_torch_device()
-            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
+            logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
             self.device = auto_device.type
 
         # Automatically deactivate AMP if necessary
         if self.use_amp and not is_amp_available(self.device):
-            logging.warning(
+            logger.warning(
                 f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
             )
             self.use_amp = False
 
     @property
     def type(self) -> str:
-        return self.get_choice_name(self.__class__)
+        choice_name = self.get_choice_name(self.__class__)
+        if not isinstance(choice_name, str):
+            raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
+        return choice_name
 
     @property
     @abc.abstractmethod
-    def observation_delta_indices(self) -> list | None:
+    def observation_delta_indices(self) -> list | None:  # type: ignore[type-arg]  # TODO: no implementation
         raise NotImplementedError
 
     @property
     @abc.abstractmethod
-    def action_delta_indices(self) -> list | None:
+    def action_delta_indices(self) -> list | None:  # type: ignore[type-arg]  # TODO: no implementation
         raise NotImplementedError
 
     @property
     @abc.abstractmethod
-    def reward_delta_indices(self) -> list | None:
+    def reward_delta_indices(self) -> list | None:  # type: ignore[type-arg]  # TODO: no implementation
         raise NotImplementedError
 
     @abc.abstractmethod
@@ -154,13 +158,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
         pretrained_name_or_path: str | Path,
         *,
         force_download: bool = False,
-        resume_download: bool = None,
-        proxies: dict | None = None,
+        resume_download: bool | None = None,
+        proxies: dict[Any, Any] | None = None,
         token: str | bool | None = None,
         cache_dir: str | Path | None = None,
         local_files_only: bool = False,
         revision: str | None = None,
-        **policy_kwargs,
+        **policy_kwargs: Any,
     ) -> T:
         model_id = str(pretrained_name_or_path)
         config_file: str | None = None
@@ -168,7 +172,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
             if CONFIG_NAME in os.listdir(model_id):
                 config_file = os.path.join(model_id, CONFIG_NAME)
             else:
-                print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
+                logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
         else:
             try:
                 config_file = hf_hub_download(
@@ -194,6 +198,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
         with draccus.config_type("json"):
             orig_config = draccus.parse(cls, config_file, args=[])
 
+        if config_file is None:
+            raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
+
         with open(config_file) as f:
             config = json.load(f)
 
diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py
index 60a4d81d..2f3a65db 100644
--- a/src/lerobot/configs/train.py
+++ b/src/lerobot/configs/train.py
@@ -16,6 +16,7 @@ import datetime as dt
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
+from typing import Any
 
 import draccus
 from huggingface_hub import hf_hub_download
@@ -63,18 +64,16 @@ class TrainPipelineConfig(HubMixin):
     scheduler: LRSchedulerConfig | None = None
     eval: EvalConfig = field(default_factory=EvalConfig)
     wandb: WandBConfig = field(default_factory=WandBConfig)
+    checkpoint_path: Path | None = field(init=False, default=None)
 
-    def __post_init__(self):
-        self.checkpoint_path = None
-
-    def validate(self):
+    def validate(self) -> None:
         # HACK: We parse again the cli args here to get the pretrained paths if there was some.
         policy_path = parser.get_path_arg("policy")
         if policy_path:
             # Only load the policy config
             cli_overrides = parser.get_cli_overrides("policy")
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
-            self.policy.pretrained_path = policy_path
+            self.policy.pretrained_path = Path(policy_path)
         elif self.resume:
             # The entire train config is already loaded, we just need to get the checkpoint dir
             config_path = parser.parse_arg("config_path")
@@ -82,14 +81,22 @@ class TrainPipelineConfig(HubMixin):
                 raise ValueError(
                     f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
                 )
+
             if not Path(config_path).resolve().exists():
                 raise NotADirectoryError(
                     f"{config_path=} is expected to be a local path. "
                     "Resuming from the hub is not supported for now."
                 )
-            policy_path = Path(config_path).parent
-            self.policy.pretrained_path = policy_path
-            self.checkpoint_path = policy_path.parent
+
+            policy_dir = Path(config_path).parent
+            if self.policy is not None:
+                self.policy.pretrained_path = policy_dir
+            self.checkpoint_path = policy_dir.parent
+
+        if self.policy is None:
+            raise ValueError(
+                "Policy is not configured. Please specify a pretrained policy with `--policy.path`."
+            )
 
         if not self.job_name:
             if self.env is None:
@@ -126,8 +133,8 @@ class TrainPipelineConfig(HubMixin):
         """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
         return ["policy"]
 
-    def to_dict(self) -> dict:
-        return draccus.encode(self)
+    def to_dict(self) -> dict[str, Any]:
+        return draccus.encode(self)  # type: ignore[no-any-return]  # the third-party library draccus uses Any as the return type
 
     def _save_pretrained(self, save_directory: Path) -> None:
         with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
@@ -139,13 +146,13 @@ class TrainPipelineConfig(HubMixin):
         pretrained_name_or_path: str | Path,
         *,
         force_download: bool = False,
-        resume_download: bool = None,
-        proxies: dict | None = None,
+        resume_download: bool | None = None,
+        proxies: dict[Any, Any] | None = None,
         token: str | bool | None = None,
         cache_dir: str | Path | None = None,
         local_files_only: bool = False,
         revision: str | None = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> "TrainPipelineConfig":
         model_id = str(pretrained_name_or_path)
         config_file: str | None = None
@@ -181,4 +188,6 @@ class TrainPipelineConfig(HubMixin):
 
 @dataclass(kw_only=True)
 class TrainRLServerPipelineConfig(TrainPipelineConfig):
-    dataset: DatasetConfig | None = None  # NOTE: In RL, we don't need an offline dataset
+    # NOTE: In RL, we don't need an offline dataset
+    # TODO: Make `TrainPipelineConfig.dataset` optional
+    dataset: DatasetConfig | None = None  # type: ignore[assignment]  # the parent class declares this field as non-optional
diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py
index cb578060..11a1f8d7 100644
--- a/src/lerobot/configs/types.py
+++ b/src/lerobot/configs/types.py
@@ -42,4 +42,4 @@ class NormalizationMode(str, Enum):
 @dataclass
 class PolicyFeature:
     type: FeatureType
-    shape: tuple
+    shape: tuple[int, ...]