Fix typehint and address the mypy errors of src/lerobot/configs (#1746)
* fix: update policy handling and type annotations added typehint and addressed the error of mypy * fix: rename should_push_to_hub to push_to_hub I find that there are other dependencies of push_to_hub so I fix the property name back to original one. * fix: typo * fix: changed the position of try-except block As the copilot said, use raise before `hf_hub_download` would stop program even it is able to download * fix: update pre-commit configuration and mypy settings add args: --follow-imports=silent to pass error which have no relationship with src/lerobot/configs * fix: remove the specific path in .pre-commit-config.yaml * feat: enhance typehint to adapt mypy strict mode. * fix: remove duplicate FileNotFoundError check in PreTrainedConfig * fix: make "pre-commit run --all-files" pass * fix: replace logging with logger for better logging practices * fix: fixed extra changes of lint and format changes * fix: fixed extra changes out of "configs" module * Update src/lerobot/configs/policies.py Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Signed-off-by: tetsugo02 <131431116+tetsugo02@users.noreply.github.com> * fix: add logging for scratch job --------- Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> Signed-off-by: tetsugo02 <131431116+tetsugo02@users.noreply.github.com> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -289,9 +289,19 @@ ignore_errors = false
|
||||
# module = "lerobot.utils.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.configs.*"
|
||||
# ignore_errors = false
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.configs.*"
|
||||
ignore_errors = false
|
||||
|
||||
# extra strictness for configs
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
|
||||
# ignores
|
||||
disable_error_code = ["attr-defined"] #TODO: draccus issue
|
||||
|
||||
# include = "src/lerobot/configs/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.optim.*"
|
||||
|
||||
@@ -57,7 +57,7 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.batch_size > self.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import datetime as dt
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import envs, policies # noqa: F401
|
||||
@@ -22,6 +22,8 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.default import EvalConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalPipelineConfig:
|
||||
@@ -35,24 +37,28 @@ class EvalPipelineConfig:
|
||||
job_name: str | None = None
|
||||
seed: int | None = 1000
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
|
||||
else:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{self.policy.type}"
|
||||
self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
self.job_name = (
|
||||
f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
|
||||
)
|
||||
|
||||
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
|
||||
|
||||
if not self.output_dir:
|
||||
now = dt.datetime.now()
|
||||
|
||||
@@ -16,14 +16,19 @@ import inspect
|
||||
import pkgutil
|
||||
import sys
|
||||
from argparse import ArgumentError
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from pkgutil import ModuleInfo
|
||||
from types import ModuleType
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.utils.utils import has_method
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
PATH_KEY = "path"
|
||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
|
||||
@@ -60,7 +65,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
|
||||
"""Parse plugin-related arguments from command-line arguments.
|
||||
|
||||
This function extracts arguments from command-line arguments that match a specified suffix pattern.
|
||||
@@ -127,7 +132,7 @@ def load_plugin(plugin_path: str) -> None:
|
||||
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
||||
) from e
|
||||
|
||||
def iter_namespace(ns_pkg):
|
||||
def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
|
||||
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
|
||||
|
||||
try:
|
||||
@@ -148,6 +153,8 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
|
||||
|
||||
|
||||
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
||||
if args is None:
|
||||
return []
|
||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
@@ -171,7 +178,8 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
if isinstance(fields_to_filter, str):
|
||||
fields_to_filter = [fields_to_filter]
|
||||
|
||||
filtered_args = args
|
||||
filtered_args = [] if args is None else list(args)
|
||||
|
||||
for field in fields_to_filter:
|
||||
if get_path_arg(field, args):
|
||||
if get_type_arg(field, args):
|
||||
@@ -184,7 +192,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
return filtered_args
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None):
|
||||
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does three additional things:
|
||||
- Will remove '.path' arguments from CLI in order to process them later on.
|
||||
@@ -195,9 +203,9 @@ def wrap(config_path: Path | None = None):
|
||||
from the CLI '.type' arguments
|
||||
"""
|
||||
|
||||
def wrapper_outer(fn):
|
||||
def wrapper_outer(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper_inner(*args, **kwargs):
|
||||
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
|
||||
argspec = inspect.getfullargspec(fn)
|
||||
argtype = argspec.annotations[argspec.args[0]]
|
||||
if len(args) > 0 and type(args[0]) is argtype:
|
||||
@@ -225,6 +233,6 @@ def wrap(config_path: Path | None = None):
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
return wrapper_inner
|
||||
return cast(F, wrapper_inner)
|
||||
|
||||
return wrapper_outer
|
||||
return cast(Callable[[F], F], wrapper_outer)
|
||||
|
||||
@@ -14,12 +14,12 @@
|
||||
import abc
|
||||
import builtins
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -34,10 +34,11 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedConfig")
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
|
||||
"""
|
||||
Base configuration class for policy models.
|
||||
|
||||
@@ -62,7 +63,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
|
||||
push_to_hub: bool = True
|
||||
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
|
||||
repo_id: str | None = None
|
||||
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
@@ -73,38 +74,41 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
license: str | None = None
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: str | None = None
|
||||
pretrained_path: Path | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None:
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -154,13 +158,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**policy_kwargs,
|
||||
**policy_kwargs: Any,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
@@ -168,7 +172,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
@@ -194,6 +198,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
with draccus.config_type("json"):
|
||||
orig_config = draccus.parse(cls, config_file, args=[])
|
||||
|
||||
if config_file is None:
|
||||
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import datetime as dt
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -63,18 +64,16 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
self.checkpoint_path = None
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
@@ -82,14 +81,22 @@ class TrainPipelineConfig(HubMixin):
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
"Resuming from the hub is not supported for now."
|
||||
)
|
||||
policy_path = Path(config_path).parent
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.checkpoint_path = policy_path.parent
|
||||
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
if self.policy is None:
|
||||
raise ValueError(
|
||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
@@ -126,8 +133,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return draccus.encode(self)
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
@@ -139,13 +146,13 @@ class TrainPipelineConfig(HubMixin):
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> "TrainPipelineConfig":
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
@@ -181,4 +188,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
@@ -42,4 +42,4 @@ class NormalizationMode(str, Enum):
|
||||
@dataclass
|
||||
class PolicyFeature:
|
||||
type: FeatureType
|
||||
shape: tuple
|
||||
shape: tuple[int, ...]
|
||||
|
||||
Reference in New Issue
Block a user