[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
95 changed files with 1592 additions and 491 deletions

View File

@@ -33,7 +33,9 @@ class DatasetConfig:
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
image_transforms: ImageTransformsConfig = field(
default_factory=ImageTransformsConfig
)
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)

View File

@@ -40,7 +40,9 @@ class EvalPipelineConfig:
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=cli_overrides
)
self.policy.pretrained_path = policy_path
else:

View File

@@ -29,7 +29,9 @@ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
draccus.set_config_type("json")
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
def get_cli_overrides(
field_name: str, args: Sequence[str] | None = None
) -> list[str] | None:
"""Parses arguments from cli at a given nested attribute level.
For example, supposing the main script was called with:
@@ -42,7 +44,10 @@ def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> lis
args = sys.argv[1:]
attr_level_args = []
detect_string = f"--{field_name}."
exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
exclude_strings = (
f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
f"--{field_name}.{PATH_KEY}=",
)
for arg in args:
if arg.startswith(detect_string) and not arg.startswith(exclude_strings):
denested_arg = f"--{arg.removeprefix(detect_string)}"
@@ -153,7 +158,9 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
def filter_path_args(
fields_to_filter: str | list[str], args: Sequence[str] | None = None
) -> list[str]:
"""
Filters command-line arguments related to fields with specific path arguments.
@@ -181,7 +188,9 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
argument=None,
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
)
filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
filtered_args = [
arg for arg in filtered_args if not arg.startswith(f"--{field}.")
]
return filtered_args
@@ -213,7 +222,9 @@ def wrap(config_path: Path | None = None):
load_plugin(plugin_path)
except PluginLoadError as e:
# add the relevant CLI arg to the error message
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
raise PluginLoadError(
f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}"
) from e
cli_args = filter_arg(plugin_cli_arg, cli_args)
config_path_cli = parse_arg("config_path", cli_args)
if has_method(argtype, "__get_path_fields__"):
@@ -223,7 +234,9 @@ def wrap(config_path: Path | None = None):
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else:
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
cfg = draccus.parse(
config_class=argtype, config_path=config_path, args=cli_args
)
response = fn(cfg, *args, **kwargs)
return response

View File

@@ -26,7 +26,11 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot.common.optim.optimizers import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.common.utils.utils import (
auto_select_torch_device,
is_amp_available,
is_torch_device_available,
)
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
# Generic variable that is either PreTrainedConfig or a subclass thereof
@@ -64,7 +68,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
logging.warning(
f"Device '{self.device}' is not available. Switching to '{auto_device}'."
)
self.device = auto_device.type
# Automatically deactivate AMP if necessary
@@ -118,7 +124,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property
def image_features(self) -> dict[str, PolicyFeature]:
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
return {
key: ft
for key, ft in self.input_features.items()
if ft.type is FeatureType.VISUAL
}
@property
def action_feature(self) -> PolicyFeature | None:

View File

@@ -73,7 +73,9 @@ class TrainPipelineConfig(HubMixin):
if policy_path:
# Only load the policy config
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=cli_overrides
)
self.policy.pretrained_path = policy_path
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
@@ -97,7 +99,11 @@ class TrainPipelineConfig(HubMixin):
else:
self.job_name = f"{self.env.type}_{self.policy.type}"
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
if (
not self.resume
and isinstance(self.output_dir, Path)
and self.output_dir.is_dir()
):
raise FileExistsError(
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
f"Please change your output directory so that {self.output_dir} is not overwritten."
@@ -108,10 +114,16 @@ class TrainPipelineConfig(HubMixin):
self.output_dir = Path("outputs/train") / train_dir
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
raise NotImplementedError(
"LeRobotMultiDataset is not currently implemented."
)
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
if not self.use_policy_training_preset and (
self.optimizer is None or self.scheduler is None
):
raise ValueError(
"Optimizer and Scheduler must be set when the policy presets are not used."
)
elif self.use_policy_training_preset and not self.resume:
self.optimizer = self.policy.get_optimizer_preset()
self.scheduler = self.policy.get_scheduler_preset()
@@ -125,7 +137,10 @@ class TrainPipelineConfig(HubMixin):
return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
with (
open(save_directory / TRAIN_CONFIG_NAME, "w") as f,
draccus.config_type("json"),
):
draccus.dump(self, f, indent=4)
@classmethod