[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
2945bbb221
commit
7c05755823
@@ -33,9 +33,7 @@ class DatasetConfig:
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(
|
||||
default_factory=ImageTransformsConfig
|
||||
)
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
|
||||
@@ -40,9 +40,7 @@ class EvalPipelineConfig:
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
else:
|
||||
|
||||
@@ -29,9 +29,7 @@ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
draccus.set_config_type("json")
|
||||
|
||||
|
||||
def get_cli_overrides(
|
||||
field_name: str, args: Sequence[str] | None = None
|
||||
) -> list[str] | None:
|
||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||
"""Parses arguments from cli at a given nested attribute level.
|
||||
|
||||
For example, supposing the main script was called with:
|
||||
@@ -158,9 +156,7 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[
|
||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
def filter_path_args(
|
||||
fields_to_filter: str | list[str], args: Sequence[str] | None = None
|
||||
) -> list[str]:
|
||||
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
|
||||
"""
|
||||
Filters command-line arguments related to fields with specific path arguments.
|
||||
|
||||
@@ -188,9 +184,7 @@ def filter_path_args(
|
||||
argument=None,
|
||||
message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
|
||||
)
|
||||
filtered_args = [
|
||||
arg for arg in filtered_args if not arg.startswith(f"--{field}.")
|
||||
]
|
||||
filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]
|
||||
|
||||
return filtered_args
|
||||
|
||||
@@ -222,9 +216,7 @@ def wrap(config_path: Path | None = None):
|
||||
load_plugin(plugin_path)
|
||||
except PluginLoadError as e:
|
||||
# add the relevant CLI arg to the error message
|
||||
raise PluginLoadError(
|
||||
f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}"
|
||||
) from e
|
||||
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
|
||||
cli_args = filter_arg(plugin_cli_arg, cli_args)
|
||||
config_path_cli = parse_arg("config_path", cli_args)
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
@@ -234,9 +226,7 @@ def wrap(config_path: Path | None = None):
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype, config_path=config_path, args=cli_args
|
||||
)
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
@@ -68,9 +68,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
self.pretrained_path = None
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(
|
||||
f"Device '{self.device}' is not available. Switching to '{auto_device}'."
|
||||
)
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
@@ -124,11 +122,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
return {
|
||||
key: ft
|
||||
for key, ft in self.input_features.items()
|
||||
if ft.type is FeatureType.VISUAL
|
||||
}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
|
||||
@@ -73,9 +73,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
@@ -99,11 +97,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
|
||||
if (
|
||||
not self.resume
|
||||
and isinstance(self.output_dir, Path)
|
||||
and self.output_dir.is_dir()
|
||||
):
|
||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||
raise FileExistsError(
|
||||
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
|
||||
f"Please change your output directory so that {self.output_dir} is not overwritten."
|
||||
@@ -114,16 +108,10 @@ class TrainPipelineConfig(HubMixin):
|
||||
self.output_dir = Path("outputs/train") / train_dir
|
||||
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError(
|
||||
"LeRobotMultiDataset is not currently implemented."
|
||||
)
|
||||
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
||||
|
||||
if not self.use_policy_training_preset and (
|
||||
self.optimizer is None or self.scheduler is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Optimizer and Scheduler must be set when the policy presets are not used."
|
||||
)
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
elif self.use_policy_training_preset and not self.resume:
|
||||
self.optimizer = self.policy.get_optimizer_preset()
|
||||
self.scheduler = self.policy.get_scheduler_preset()
|
||||
|
||||
Reference in New Issue
Block a user