Fix typehint and address the mypy errors of src/lerobot/configs (#1746)
* fix: update policy handling and type annotations added typehint and addressed the error of mypy * fix: rename should_push_to_hub to push_to_hub I find that there are other dependencies of push_to_hub so I fix the property name back to original one. * fix: typo * fix: changed the position of try-except block As the copilot said, use raise before `hf_hub_download` would stop program even it is able to download * fix: update pre-commit configuration and mypy settings add args: --follow-imports=silent to pass error which have no relationship with src/lerobot/configs * fix: remove the specific path in .pre-commit-config.yaml * feat: enhance typehint to adapt mypy strict mode. * fix: remove duplicate FileNotFoundError check in PreTrainedConfig * fix: make "pre-commit run --all-files" pass * fix: replace logging with logger for better logging practices * fix: fixed extra changes of lint and format changes * fix: fixed extra changes out of "configs" module * Update src/lerobot/configs/policies.py Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Signed-off-by: tetsugo02 <131431116+tetsugo02@users.noreply.github.com> * fix: add logging for scratch job --------- Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> Signed-off-by: tetsugo02 <131431116+tetsugo02@users.noreply.github.com> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -289,9 +289,19 @@ ignore_errors = false
|
|||||||
# module = "lerobot.utils.*"
|
# module = "lerobot.utils.*"
|
||||||
# ignore_errors = false
|
# ignore_errors = false
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
# module = "lerobot.configs.*"
|
module = "lerobot.configs.*"
|
||||||
# ignore_errors = false
|
ignore_errors = false
|
||||||
|
|
||||||
|
# extra strictness for configs
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
disallow_incomplete_defs = true
|
||||||
|
check_untyped_defs = true
|
||||||
|
|
||||||
|
# ignores
|
||||||
|
disable_error_code = ["attr-defined"] #TODO: draccus issue
|
||||||
|
|
||||||
|
# include = "src/lerobot/configs/**/*.py"
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
# [[tool.mypy.overrides]]
|
||||||
# module = "lerobot.optim.*"
|
# module = "lerobot.optim.*"
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class EvalConfig:
|
|||||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||||
use_async_envs: bool = False
|
use_async_envs: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
if self.batch_size > self.n_episodes:
|
if self.batch_size > self.n_episodes:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The eval batch size is greater than the number of eval episodes "
|
"The eval batch size is greater than the number of eval episodes "
|
||||||
|
|||||||
@@ -13,8 +13,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lerobot import envs, policies # noqa: F401
|
from lerobot import envs, policies # noqa: F401
|
||||||
@@ -22,6 +22,8 @@ from lerobot.configs import parser
|
|||||||
from lerobot.configs.default import EvalConfig
|
from lerobot.configs.default import EvalConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EvalPipelineConfig:
|
class EvalPipelineConfig:
|
||||||
@@ -35,24 +37,28 @@ class EvalPipelineConfig:
|
|||||||
job_name: str | None = None
|
job_name: str | None = None
|
||||||
seed: int | None = 1000
|
seed: int | None = 1000
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||||
policy_path = parser.get_path_arg("policy")
|
policy_path = parser.get_path_arg("policy")
|
||||||
if policy_path:
|
if policy_path:
|
||||||
cli_overrides = parser.get_cli_overrides("policy")
|
cli_overrides = parser.get_cli_overrides("policy")
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = policy_path
|
self.policy.pretrained_path = Path(policy_path)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
logger.warning(
|
||||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.job_name:
|
if not self.job_name:
|
||||||
if self.env is None:
|
if self.env is None:
|
||||||
self.job_name = f"{self.policy.type}"
|
self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
|
||||||
else:
|
else:
|
||||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
self.job_name = (
|
||||||
|
f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
|
||||||
|
|
||||||
if not self.output_dir:
|
if not self.output_dir:
|
||||||
now = dt.datetime.now()
|
now = dt.datetime.now()
|
||||||
|
|||||||
@@ -16,14 +16,19 @@ import inspect
|
|||||||
import pkgutil
|
import pkgutil
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentError
|
from argparse import ArgumentError
|
||||||
from collections.abc import Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from pkgutil import ModuleInfo
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any, TypeVar, cast
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
from lerobot.utils.utils import has_method
|
from lerobot.utils.utils import has_method
|
||||||
|
|
||||||
|
F = TypeVar("F", bound=Callable[..., object])
|
||||||
|
|
||||||
PATH_KEY = "path"
|
PATH_KEY = "path"
|
||||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||||
|
|
||||||
@@ -60,7 +65,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
|
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
|
||||||
"""Parse plugin-related arguments from command-line arguments.
|
"""Parse plugin-related arguments from command-line arguments.
|
||||||
|
|
||||||
This function extracts arguments from command-line arguments that match a specified suffix pattern.
|
This function extracts arguments from command-line arguments that match a specified suffix pattern.
|
||||||
@@ -127,7 +132,7 @@ def load_plugin(plugin_path: str) -> None:
|
|||||||
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
def iter_namespace(ns_pkg):
|
def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
|
||||||
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
|
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -148,6 +153,8 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
|
|||||||
|
|
||||||
|
|
||||||
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
||||||
|
if args is None:
|
||||||
|
return []
|
||||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||||
|
|
||||||
|
|
||||||
@@ -171,7 +178,8 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
|||||||
if isinstance(fields_to_filter, str):
|
if isinstance(fields_to_filter, str):
|
||||||
fields_to_filter = [fields_to_filter]
|
fields_to_filter = [fields_to_filter]
|
||||||
|
|
||||||
filtered_args = args
|
filtered_args = [] if args is None else list(args)
|
||||||
|
|
||||||
for field in fields_to_filter:
|
for field in fields_to_filter:
|
||||||
if get_path_arg(field, args):
|
if get_path_arg(field, args):
|
||||||
if get_type_arg(field, args):
|
if get_type_arg(field, args):
|
||||||
@@ -184,7 +192,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
|||||||
return filtered_args
|
return filtered_args
|
||||||
|
|
||||||
|
|
||||||
def wrap(config_path: Path | None = None):
|
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||||
"""
|
"""
|
||||||
HACK: Similar to draccus.wrap but does three additional things:
|
HACK: Similar to draccus.wrap but does three additional things:
|
||||||
- Will remove '.path' arguments from CLI in order to process them later on.
|
- Will remove '.path' arguments from CLI in order to process them later on.
|
||||||
@@ -195,9 +203,9 @@ def wrap(config_path: Path | None = None):
|
|||||||
from the CLI '.type' arguments
|
from the CLI '.type' arguments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper_outer(fn):
|
def wrapper_outer(fn: F) -> F:
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
def wrapper_inner(*args, **kwargs):
|
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
|
||||||
argspec = inspect.getfullargspec(fn)
|
argspec = inspect.getfullargspec(fn)
|
||||||
argtype = argspec.annotations[argspec.args[0]]
|
argtype = argspec.annotations[argspec.args[0]]
|
||||||
if len(args) > 0 and type(args[0]) is argtype:
|
if len(args) > 0 and type(args[0]) is argtype:
|
||||||
@@ -225,6 +233,6 @@ def wrap(config_path: Path | None = None):
|
|||||||
response = fn(cfg, *args, **kwargs)
|
response = fn(cfg, *args, **kwargs)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
return wrapper_inner
|
return cast(F, wrapper_inner)
|
||||||
|
|
||||||
return wrapper_outer
|
return cast(Callable[[F], F], wrapper_outer)
|
||||||
|
|||||||
@@ -14,12 +14,12 @@
|
|||||||
import abc
|
import abc
|
||||||
import builtins
|
import builtins
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@@ -34,10 +34,11 @@ from lerobot.utils.hub import HubMixin
|
|||||||
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||||
|
|
||||||
T = TypeVar("T", bound="PreTrainedConfig")
|
T = TypeVar("T", bound="PreTrainedConfig")
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
|
||||||
"""
|
"""
|
||||||
Base configuration class for policy models.
|
Base configuration class for policy models.
|
||||||
|
|
||||||
@@ -62,7 +63,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
# automatic gradient scaling is used.
|
# automatic gradient scaling is used.
|
||||||
use_amp: bool = False
|
use_amp: bool = False
|
||||||
|
|
||||||
push_to_hub: bool = True
|
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
|
||||||
repo_id: str | None = None
|
repo_id: str | None = None
|
||||||
|
|
||||||
# Upload on private repository on the Hugging Face hub.
|
# Upload on private repository on the Hugging Face hub.
|
||||||
@@ -73,38 +74,41 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
license: str | None = None
|
license: str | None = None
|
||||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||||
pretrained_path: str | None = None
|
pretrained_path: Path | None = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
if not self.device or not is_torch_device_available(self.device):
|
if not self.device or not is_torch_device_available(self.device):
|
||||||
auto_device = auto_select_torch_device()
|
auto_device = auto_select_torch_device()
|
||||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||||
self.device = auto_device.type
|
self.device = auto_device.type
|
||||||
|
|
||||||
# Automatically deactivate AMP if necessary
|
# Automatically deactivate AMP if necessary
|
||||||
if self.use_amp and not is_amp_available(self.device):
|
if self.use_amp and not is_amp_available(self.device):
|
||||||
logging.warning(
|
logger.warning(
|
||||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||||
)
|
)
|
||||||
self.use_amp = False
|
self.use_amp = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return self.get_choice_name(self.__class__)
|
choice_name = self.get_choice_name(self.__class__)
|
||||||
|
if not isinstance(choice_name, str):
|
||||||
|
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||||
|
return choice_name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def observation_delta_indices(self) -> list | None:
|
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def action_delta_indices(self) -> list | None:
|
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def reward_delta_indices(self) -> list | None:
|
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -154,13 +158,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
pretrained_name_or_path: str | Path,
|
pretrained_name_or_path: str | Path,
|
||||||
*,
|
*,
|
||||||
force_download: bool = False,
|
force_download: bool = False,
|
||||||
resume_download: bool = None,
|
resume_download: bool | None = None,
|
||||||
proxies: dict | None = None,
|
proxies: dict[Any, Any] | None = None,
|
||||||
token: str | bool | None = None,
|
token: str | bool | None = None,
|
||||||
cache_dir: str | Path | None = None,
|
cache_dir: str | Path | None = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
**policy_kwargs,
|
**policy_kwargs: Any,
|
||||||
) -> T:
|
) -> T:
|
||||||
model_id = str(pretrained_name_or_path)
|
model_id = str(pretrained_name_or_path)
|
||||||
config_file: str | None = None
|
config_file: str | None = None
|
||||||
@@ -168,7 +172,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
if CONFIG_NAME in os.listdir(model_id):
|
if CONFIG_NAME in os.listdir(model_id):
|
||||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||||
else:
|
else:
|
||||||
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
config_file = hf_hub_download(
|
config_file = hf_hub_download(
|
||||||
@@ -194,6 +198,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
with draccus.config_type("json"):
|
with draccus.config_type("json"):
|
||||||
orig_config = draccus.parse(cls, config_file, args=[])
|
orig_config = draccus.parse(cls, config_file, args=[])
|
||||||
|
|
||||||
|
if config_file is None:
|
||||||
|
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||||
|
|
||||||
with open(config_file) as f:
|
with open(config_file) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import datetime as dt
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@@ -63,18 +64,16 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
scheduler: LRSchedulerConfig | None = None
|
scheduler: LRSchedulerConfig | None = None
|
||||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||||
|
checkpoint_path: Path | None = field(init=False, default=None)
|
||||||
|
|
||||||
def __post_init__(self):
|
def validate(self) -> None:
|
||||||
self.checkpoint_path = None
|
|
||||||
|
|
||||||
def validate(self):
|
|
||||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||||
policy_path = parser.get_path_arg("policy")
|
policy_path = parser.get_path_arg("policy")
|
||||||
if policy_path:
|
if policy_path:
|
||||||
# Only load the policy config
|
# Only load the policy config
|
||||||
cli_overrides = parser.get_cli_overrides("policy")
|
cli_overrides = parser.get_cli_overrides("policy")
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = policy_path
|
self.policy.pretrained_path = Path(policy_path)
|
||||||
elif self.resume:
|
elif self.resume:
|
||||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||||
config_path = parser.parse_arg("config_path")
|
config_path = parser.parse_arg("config_path")
|
||||||
@@ -82,14 +81,22 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not Path(config_path).resolve().exists():
|
if not Path(config_path).resolve().exists():
|
||||||
raise NotADirectoryError(
|
raise NotADirectoryError(
|
||||||
f"{config_path=} is expected to be a local path. "
|
f"{config_path=} is expected to be a local path. "
|
||||||
"Resuming from the hub is not supported for now."
|
"Resuming from the hub is not supported for now."
|
||||||
)
|
)
|
||||||
policy_path = Path(config_path).parent
|
|
||||||
self.policy.pretrained_path = policy_path
|
policy_dir = Path(config_path).parent
|
||||||
self.checkpoint_path = policy_path.parent
|
if self.policy is not None:
|
||||||
|
self.policy.pretrained_path = policy_dir
|
||||||
|
self.checkpoint_path = policy_dir.parent
|
||||||
|
|
||||||
|
if self.policy is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
||||||
|
)
|
||||||
|
|
||||||
if not self.job_name:
|
if not self.job_name:
|
||||||
if self.env is None:
|
if self.env is None:
|
||||||
@@ -126,8 +133,8 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||||
return ["policy"]
|
return ["policy"]
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return draccus.encode(self)
|
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path) -> None:
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||||
@@ -139,13 +146,13 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
pretrained_name_or_path: str | Path,
|
pretrained_name_or_path: str | Path,
|
||||||
*,
|
*,
|
||||||
force_download: bool = False,
|
force_download: bool = False,
|
||||||
resume_download: bool = None,
|
resume_download: bool | None = None,
|
||||||
proxies: dict | None = None,
|
proxies: dict[Any, Any] | None = None,
|
||||||
token: str | bool | None = None,
|
token: str | bool | None = None,
|
||||||
cache_dir: str | Path | None = None,
|
cache_dir: str | Path | None = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> "TrainPipelineConfig":
|
) -> "TrainPipelineConfig":
|
||||||
model_id = str(pretrained_name_or_path)
|
model_id = str(pretrained_name_or_path)
|
||||||
config_file: str | None = None
|
config_file: str | None = None
|
||||||
@@ -181,4 +188,6 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
# NOTE: In RL, we don't need an offline dataset
|
||||||
|
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||||
|
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||||
|
|||||||
@@ -42,4 +42,4 @@ class NormalizationMode(str, Enum):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class PolicyFeature:
|
class PolicyFeature:
|
||||||
type: FeatureType
|
type: FeatureType
|
||||||
shape: tuple
|
shape: tuple[int, ...]
|
||||||
|
|||||||
Reference in New Issue
Block a user