forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
cdcf346061
commit
1c8daf11fd
@@ -33,7 +33,9 @@ class DatasetConfig:
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
image_transforms: ImageTransformsConfig = field(
|
||||
default_factory=ImageTransformsConfig
|
||||
)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
|
||||
@@ -40,7 +40,9 @@ class EvalPipelineConfig:
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
else:
|
||||
|
||||
@@ -29,7 +29,9 @@ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
draccus.set_config_type("json")
|
||||
|
||||
|
||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||
def get_cli_overrides(
|
||||
field_name: str, args: Sequence[str] | None = None
|
||||
) -> list[str] | None:
|
||||
"""Parses arguments from cli at a given nested attribute level.
|
||||
|
||||
For example, supposing the main script was called with:
|
||||
@@ -42,7 +44,10 @@ def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> lis
|
||||
args = sys.argv[1:]
|
||||
attr_level_args = []
|
||||
detect_string = f"--{field_name}."
|
||||
exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
|
||||
exclude_strings = (
|
||||
f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
|
||||
f"--{field_name}.{PATH_KEY}=",
|
||||
)
|
||||
for arg in args:
|
||||
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
|
||||
denested_arg = f"--{arg.removeprefix(detect_string)}"
|
||||
@@ -153,7 +158,9 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[
|
||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
|
||||
def filter_path_args(
|
||||
fields_to_filter: str | list[str], args: Sequence[str] | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Filters command-line arguments related to fields with specific path arguments.
|
||||
|
||||
@@ -181,7 +188,9 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
argument=None,
|
||||
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
|
||||
)
|
||||
filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
|
||||
filtered_args = [
|
||||
arg for arg in filtered_args if not arg.startswith(f"--{field}.")
|
||||
]
|
||||
|
||||
return filtered_args
|
||||
|
||||
@@ -213,7 +222,9 @@ def wrap(config_path: Path | None = None):
|
||||
load_plugin(plugin_path)
|
||||
except PluginLoadError as e:
|
||||
# add the relevant CLI arg to the error message
|
||||
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
|
||||
raise PluginLoadError(
|
||||
f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}"
|
||||
) from e
|
||||
cli_args = filter_arg(plugin_cli_arg, cli_args)
|
||||
config_path_cli = parse_arg("config_path", cli_args)
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
@@ -223,7 +234,9 @@ def wrap(config_path: Path | None = None):
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype, config_path=config_path, args=cli_args
|
||||
)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
@@ -26,7 +26,11 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
from lerobot.common.optim.optimizers import OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
from lerobot.common.utils.utils import (
|
||||
auto_select_torch_device,
|
||||
is_amp_available,
|
||||
is_torch_device_available,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
||||
@@ -64,7 +68,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
self.pretrained_path = None
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
logging.warning(
|
||||
f"Device '{self.device}' is not available. Switching to '{auto_device}'."
|
||||
)
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
@@ -118,7 +124,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
return {
|
||||
key: ft
|
||||
for key, ft in self.input_features.items()
|
||||
if ft.type is FeatureType.VISUAL
|
||||
}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
|
||||
@@ -73,7 +73,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = policy_path
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
@@ -97,7 +99,11 @@ class TrainPipelineConfig(HubMixin):
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
|
||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||
if (
|
||||
not self.resume
|
||||
and isinstance(self.output_dir, Path)
|
||||
and self.output_dir.is_dir()
|
||||
):
|
||||
raise FileExistsError(
|
||||
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
|
||||
f"Please change your output directory so that {self.output_dir} is not overwritten."
|
||||
@@ -108,10 +114,16 @@ class TrainPipelineConfig(HubMixin):
|
||||
self.output_dir = Path("outputs/train") / train_dir
|
||||
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
||||
raise NotImplementedError(
|
||||
"LeRobotMultiDataset is not currently implemented."
|
||||
)
|
||||
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
if not self.use_policy_training_preset and (
|
||||
self.optimizer is None or self.scheduler is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Optimizer and Scheduler must be set when the policy presets are not used."
|
||||
)
|
||||
elif self.use_policy_training_preset and not self.resume:
|
||||
self.optimizer = self.policy.get_optimizer_preset()
|
||||
self.scheduler = self.policy.get_scheduler_preset()
|
||||
@@ -125,7 +137,10 @@ class TrainPipelineConfig(HubMixin):
|
||||
return draccus.encode(self)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
with (
|
||||
open(save_directory / TRAIN_CONFIG_NAME, "w") as f,
|
||||
draccus.config_type("json"),
|
||||
):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user