[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
95 changed files with 1592 additions and 491 deletions

View File

@@ -73,7 +73,9 @@ class TrainPipelineConfig(HubMixin):
if policy_path:
# Only load the policy config
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=cli_overrides
)
self.policy.pretrained_path = policy_path
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
@@ -97,7 +99,11 @@ class TrainPipelineConfig(HubMixin):
else:
self.job_name = f"{self.env.type}_{self.policy.type}"
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
if (
not self.resume
and isinstance(self.output_dir, Path)
and self.output_dir.is_dir()
):
raise FileExistsError(
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
f"Please change your output directory so that {self.output_dir} is not overwritten."
@@ -108,10 +114,16 @@ class TrainPipelineConfig(HubMixin):
self.output_dir = Path("outputs/train") / train_dir
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
raise NotImplementedError(
"LeRobotMultiDataset is not currently implemented."
)
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
if not self.use_policy_training_preset and (
self.optimizer is None or self.scheduler is None
):
raise ValueError(
"Optimizer and Scheduler must be set when the policy presets are not used."
)
elif self.use_policy_training_preset and not self.resume:
self.optimizer = self.policy.get_optimizer_preset()
self.scheduler = self.policy.get_scheduler_preset()
@@ -125,7 +137,10 @@ class TrainPipelineConfig(HubMixin):
return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
with (
open(save_directory / TRAIN_CONFIG_NAME, "w") as f,
draccus.config_type("json"),
):
draccus.dump(self, f, indent=4)
@classmethod