Fix typehint and address the mypy errors of src/lerobot/configs (#1746)

* fix: update policy handling and type annotations
Added type hints and addressed the mypy errors.

* fix: rename should_push_to_hub to push_to_hub
I found that there are other dependencies on push_to_hub, so I changed the property name back to the original one.

* fix: typo

* fix: changed the position of try-except block
As Copilot suggested, raising before `hf_hub_download` would stop the program even when the download could succeed.

* fix: update pre-commit configuration and mypy settings
Added the argument --follow-imports=silent to suppress errors that are unrelated to src/lerobot/configs.

* fix: remove the specific path in .pre-commit-config.yaml

* feat: enhance type hints to comply with mypy strict mode

* fix: remove duplicate FileNotFoundError check in PreTrainedConfig

* fix: make "pre-commit run --all-files" pass

* fix: replace logging with logger for better logging practices

* fix: revert unintended lint and format changes

* fix: revert unintended changes outside the "configs" module

* Update src/lerobot/configs/policies.py

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: tetsugo02 <131431116+tetsugo02@users.noreply.github.com>

* fix: add logging for scratch job

---------

Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Signed-off-by: tetsugo02 <131431116+tetsugo02@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
tetsugo02
2025-10-20 19:57:32 +09:00
committed by GitHub
parent c54cd529a2
commit eff8a6fd12
7 changed files with 90 additions and 50 deletions

View File

@@ -289,9 +289,19 @@ ignore_errors = false
# module = "lerobot.utils.*" # module = "lerobot.utils.*"
# ignore_errors = false # ignore_errors = false
# [[tool.mypy.overrides]] [[tool.mypy.overrides]]
# module = "lerobot.configs.*" module = "lerobot.configs.*"
# ignore_errors = false ignore_errors = false
# extra strictness for configs
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
# ignores
disable_error_code = ["attr-defined"] #TODO: draccus issue
# include = "src/lerobot/configs/**/*.py"
# [[tool.mypy.overrides]] # [[tool.mypy.overrides]]
# module = "lerobot.optim.*" # module = "lerobot.optim.*"

View File

@@ -57,7 +57,7 @@ class EvalConfig:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing). # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: bool = False use_async_envs: bool = False
def __post_init__(self): def __post_init__(self) -> None:
if self.batch_size > self.n_episodes: if self.batch_size > self.n_episodes:
raise ValueError( raise ValueError(
"The eval batch size is greater than the number of eval episodes " "The eval batch size is greater than the number of eval episodes "

View File

@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
import datetime as dt import datetime as dt
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path from pathlib import Path
from lerobot import envs, policies # noqa: F401 from lerobot import envs, policies # noqa: F401
@@ -22,6 +22,8 @@ from lerobot.configs import parser
from lerobot.configs.default import EvalConfig from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
logger = getLogger(__name__)
@dataclass @dataclass
class EvalPipelineConfig: class EvalPipelineConfig:
@@ -35,24 +37,28 @@ class EvalPipelineConfig:
job_name: str | None = None job_name: str | None = None
seed: int | None = 1000 seed: int | None = 1000
def __post_init__(self): def __post_init__(self) -> None:
# HACK: We parse again the cli args here to get the pretrained path if there was one. # HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy") policy_path = parser.get_path_arg("policy")
if policy_path: if policy_path:
cli_overrides = parser.get_cli_overrides("policy") cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path self.policy.pretrained_path = Path(policy_path)
else: else:
logging.warning( logger.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)." "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
) )
if not self.job_name: if not self.job_name:
if self.env is None: if self.env is None:
self.job_name = f"{self.policy.type}" self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
else: else:
self.job_name = f"{self.env.type}_{self.policy.type}" self.job_name = (
f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
)
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
if not self.output_dir: if not self.output_dir:
now = dt.datetime.now() now = dt.datetime.now()

View File

@@ -16,14 +16,19 @@ import inspect
import pkgutil import pkgutil
import sys import sys
from argparse import ArgumentError from argparse import ArgumentError
from collections.abc import Sequence from collections.abc import Callable, Iterable, Sequence
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from pkgutil import ModuleInfo
from types import ModuleType
from typing import Any, TypeVar, cast
import draccus import draccus
from lerobot.utils.utils import has_method from lerobot.utils.utils import has_method
F = TypeVar("F", bound=Callable[..., object])
PATH_KEY = "path" PATH_KEY = "path"
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path" PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
@@ -60,7 +65,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
return None return None
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict: def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
"""Parse plugin-related arguments from command-line arguments. """Parse plugin-related arguments from command-line arguments.
This function extracts arguments from command-line arguments that match a specified suffix pattern. This function extracts arguments from command-line arguments that match a specified suffix pattern.
@@ -127,7 +132,7 @@ def load_plugin(plugin_path: str) -> None:
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}" f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
) from e ) from e
def iter_namespace(ns_pkg): def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".") return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
try: try:
@@ -148,6 +153,8 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]: def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
if args is None:
return []
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")] return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
@@ -171,7 +178,8 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
if isinstance(fields_to_filter, str): if isinstance(fields_to_filter, str):
fields_to_filter = [fields_to_filter] fields_to_filter = [fields_to_filter]
filtered_args = args filtered_args = [] if args is None else list(args)
for field in fields_to_filter: for field in fields_to_filter:
if get_path_arg(field, args): if get_path_arg(field, args):
if get_type_arg(field, args): if get_type_arg(field, args):
@@ -184,7 +192,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
return filtered_args return filtered_args
def wrap(config_path: Path | None = None): def wrap(config_path: Path | None = None) -> Callable[[F], F]:
""" """
HACK: Similar to draccus.wrap but does three additional things: HACK: Similar to draccus.wrap but does three additional things:
- Will remove '.path' arguments from CLI in order to process them later on. - Will remove '.path' arguments from CLI in order to process them later on.
@@ -195,9 +203,9 @@ def wrap(config_path: Path | None = None):
from the CLI '.type' arguments from the CLI '.type' arguments
""" """
def wrapper_outer(fn): def wrapper_outer(fn: F) -> F:
@wraps(fn) @wraps(fn)
def wrapper_inner(*args, **kwargs): def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
argspec = inspect.getfullargspec(fn) argspec = inspect.getfullargspec(fn)
argtype = argspec.annotations[argspec.args[0]] argtype = argspec.annotations[argspec.args[0]]
if len(args) > 0 and type(args[0]) is argtype: if len(args) > 0 and type(args[0]) is argtype:
@@ -225,6 +233,6 @@ def wrap(config_path: Path | None = None):
response = fn(cfg, *args, **kwargs) response = fn(cfg, *args, **kwargs)
return response return response
return wrapper_inner return cast(F, wrapper_inner)
return wrapper_outer return cast(Callable[[F], F], wrapper_outer)

View File

@@ -14,12 +14,12 @@
import abc import abc
import builtins import builtins
import json import json
import logging
import os import os
import tempfile import tempfile
from dataclasses import dataclass, field from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path from pathlib import Path
from typing import TypeVar from typing import Any, TypeVar
import draccus import draccus
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
@@ -34,10 +34,11 @@ from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
T = TypeVar("T", bound="PreTrainedConfig") T = TypeVar("T", bound="PreTrainedConfig")
logger = getLogger(__name__)
@dataclass @dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
""" """
Base configuration class for policy models. Base configuration class for policy models.
@@ -62,7 +63,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
# automatic gradient scaling is used. # automatic gradient scaling is used.
use_amp: bool = False use_amp: bool = False
push_to_hub: bool = True push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
repo_id: str | None = None repo_id: str | None = None
# Upload on private repository on the Hugging Face hub. # Upload on private repository on the Hugging Face hub.
@@ -73,38 +74,41 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
license: str | None = None license: str | None = None
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights # Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch. # saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: str | None = None pretrained_path: Path | None = None
def __post_init__(self): def __post_init__(self) -> None:
if not self.device or not is_torch_device_available(self.device): if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device() auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.") logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type self.device = auto_device.type
# Automatically deactivate AMP if necessary # Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device): if self.use_amp and not is_amp_available(self.device):
logging.warning( logger.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP." f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
) )
self.use_amp = False self.use_amp = False
@property @property
def type(self) -> str: def type(self) -> str:
return self.get_choice_name(self.__class__) choice_name = self.get_choice_name(self.__class__)
if not isinstance(choice_name, str):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
@property @property
@abc.abstractmethod @abc.abstractmethod
def observation_delta_indices(self) -> list | None: def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError raise NotImplementedError
@property @property
@abc.abstractmethod @abc.abstractmethod
def action_delta_indices(self) -> list | None: def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError raise NotImplementedError
@property @property
@abc.abstractmethod @abc.abstractmethod
def reward_delta_indices(self) -> list | None: def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
@@ -154,13 +158,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
pretrained_name_or_path: str | Path, pretrained_name_or_path: str | Path,
*, *,
force_download: bool = False, force_download: bool = False,
resume_download: bool = None, resume_download: bool | None = None,
proxies: dict | None = None, proxies: dict[Any, Any] | None = None,
token: str | bool | None = None, token: str | bool | None = None,
cache_dir: str | Path | None = None, cache_dir: str | Path | None = None,
local_files_only: bool = False, local_files_only: bool = False,
revision: str | None = None, revision: str | None = None,
**policy_kwargs, **policy_kwargs: Any,
) -> T: ) -> T:
model_id = str(pretrained_name_or_path) model_id = str(pretrained_name_or_path)
config_file: str | None = None config_file: str | None = None
@@ -168,7 +172,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
if CONFIG_NAME in os.listdir(model_id): if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME) config_file = os.path.join(model_id, CONFIG_NAME)
else: else:
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}") logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else: else:
try: try:
config_file = hf_hub_download( config_file = hf_hub_download(
@@ -194,6 +198,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
with draccus.config_type("json"): with draccus.config_type("json"):
orig_config = draccus.parse(cls, config_file, args=[]) orig_config = draccus.parse(cls, config_file, args=[])
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
with open(config_file) as f: with open(config_file) as f:
config = json.load(f) config = json.load(f)

View File

@@ -16,6 +16,7 @@ import datetime as dt
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any
import draccus import draccus
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
@@ -63,18 +64,16 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig) eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig) wandb: WandBConfig = field(default_factory=WandBConfig)
checkpoint_path: Path | None = field(init=False, default=None)
def __post_init__(self): def validate(self) -> None:
self.checkpoint_path = None
def validate(self):
# HACK: We parse again the cli args here to get the pretrained paths if there was some. # HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy") policy_path = parser.get_path_arg("policy")
if policy_path: if policy_path:
# Only load the policy config # Only load the policy config
cli_overrides = parser.get_cli_overrides("policy") cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path self.policy.pretrained_path = Path(policy_path)
elif self.resume: elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir # The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path") config_path = parser.parse_arg("config_path")
@@ -82,14 +81,22 @@ class TrainPipelineConfig(HubMixin):
raise ValueError( raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}" f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
) )
if not Path(config_path).resolve().exists(): if not Path(config_path).resolve().exists():
raise NotADirectoryError( raise NotADirectoryError(
f"{config_path=} is expected to be a local path. " f"{config_path=} is expected to be a local path. "
"Resuming from the hub is not supported for now." "Resuming from the hub is not supported for now."
) )
policy_path = Path(config_path).parent
self.policy.pretrained_path = policy_path policy_dir = Path(config_path).parent
self.checkpoint_path = policy_path.parent if self.policy is not None:
self.policy.pretrained_path = policy_dir
self.checkpoint_path = policy_dir.parent
if self.policy is None:
raise ValueError(
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
)
if not self.job_name: if not self.job_name:
if self.env is None: if self.env is None:
@@ -126,8 +133,8 @@ class TrainPipelineConfig(HubMixin):
"""This enables the parser to load config from the policy using `--policy.path=local/dir`""" """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"] return ["policy"]
def to_dict(self) -> dict: def to_dict(self) -> dict[str, Any]:
return draccus.encode(self) return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
def _save_pretrained(self, save_directory: Path) -> None: def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"): with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
@@ -139,13 +146,13 @@ class TrainPipelineConfig(HubMixin):
pretrained_name_or_path: str | Path, pretrained_name_or_path: str | Path,
*, *,
force_download: bool = False, force_download: bool = False,
resume_download: bool = None, resume_download: bool | None = None,
proxies: dict | None = None, proxies: dict[Any, Any] | None = None,
token: str | bool | None = None, token: str | bool | None = None,
cache_dir: str | Path | None = None, cache_dir: str | Path | None = None,
local_files_only: bool = False, local_files_only: bool = False,
revision: str | None = None, revision: str | None = None,
**kwargs, **kwargs: Any,
) -> "TrainPipelineConfig": ) -> "TrainPipelineConfig":
model_id = str(pretrained_name_or_path) model_id = str(pretrained_name_or_path)
config_file: str | None = None config_file: str | None = None
@@ -181,4 +188,6 @@ class TrainPipelineConfig(HubMixin):
@dataclass(kw_only=True) @dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig): class TrainRLServerPipelineConfig(TrainPipelineConfig):
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset # NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional

View File

@@ -42,4 +42,4 @@ class NormalizationMode(str, Enum):
@dataclass @dataclass
class PolicyFeature: class PolicyFeature:
type: FeatureType type: FeatureType
shape: tuple shape: tuple[int, ...]