[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
95 changed files with 1592 additions and 491 deletions

@@ -34,7 +34,13 @@ def make_optimizer_and_scheduler(
     Returns:
         tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
     """
-    params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
+    params = (
+        policy.get_optim_params()
+        if cfg.use_policy_training_preset
+        else policy.parameters()
+    )
     optimizer = cfg.optimizer.build(params)
-    lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
+    lr_scheduler = (
+        cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
+    )
     return optimizer, lr_scheduler
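
For readers skimming the diff: `make_optimizer_and_scheduler` delegates construction to the config objects. A minimal sketch of that config-driven `build` pattern follows; `AdamConfig` and its fields are illustrative assumptions, not names from this commit.

# Minimal sketch of the config-driven build pattern used above; AdamConfig
# and its fields are illustrative assumptions, not part of this commit.
from dataclasses import dataclass

import torch
from torch.optim import Optimizer


@dataclass
class AdamConfig:
    lr: float = 1e-4
    weight_decay: float = 0.0

    def build(self, params) -> Optimizer:
        # Mirrors the cfg.optimizer.build(params) call in the hunk above.
        return torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay)


model = torch.nn.Linear(4, 2)
optimizer = AdamConfig(lr=3e-4).build(model.parameters())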

@@ -102,7 +102,9 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
     write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)


-def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
+def load_optimizer_state(
+    optimizer: torch.optim.Optimizer, save_dir: Path
+) -> torch.optim.Optimizer:
     current_state_dict = optimizer.state_dict()
     flat_state = load_file(save_dir / OPTIMIZER_STATE)
     state = unflatten_dict(flat_state)
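
The load path reads a flat safetensors file and re-nests it before merging into the optimizer. A hedged sketch of that round-trip is below; the "/"-separated key convention in `unflatten_dict` is an assumption about the flat layout, not something this diff confirms.

# Hedged sketch of the flatten/unflatten round-trip; the "/"-separated key
# convention below is an assumption about the flat file layout.
import torch
from safetensors.torch import load_file, save_file


def unflatten_dict(flat: dict, sep: str = "/") -> dict:
    nested: dict = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


flat_state = {"state/0/exp_avg": torch.zeros(3), "state/0/step": torch.tensor([1.0])}
save_file(flat_state, "optimizer_state.safetensors")
state = unflatten_dict(load_file("optimizer_state.safetensors"))
assert state["state"]["0"]["exp_avg"].shape == (3,)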

@@ -36,7 +36,9 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
         return self.get_choice_name(self.__class__)

     @abc.abstractmethod
-    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
+    def build(
+        self, optimizer: Optimizer, num_training_steps: int
+    ) -> LRScheduler | None:
         raise NotImplementedError
@@ -49,7 +51,11 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
     def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
         from diffusers.optimization import get_scheduler

-        kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
+        kwargs = {
+            **asdict(self),
+            "num_training_steps": num_training_steps,
+            "optimizer": optimizer,
+        }
         return get_scheduler(**kwargs)
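
For context, `diffusers.optimization.get_scheduler` resolves a scheduler by name, which is why `asdict(self)` can be splatted straight into the call. A small usage sketch, with illustrative values standing in for the dataclass fields:

# Sketch of the get_scheduler call that the kwargs dict above feeds; the
# concrete name/step values are illustrative, not from this commit.
import torch
from diffusers.optimization import get_scheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_scheduler(
    name="cosine",             # would come from asdict(self) in the real code
    optimizer=optimizer,
    num_warmup_steps=100,      # likewise a dataclass field in the real config
    num_training_steps=10_000,
)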
@@ -71,7 +77,14 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
             progress = float(adjusted_step - self.num_warmup_steps) / float(
                 max(1, num_training_steps - self.num_warmup_steps)
             )
-            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
+            return max(
+                0.0,
+                0.5
+                * (
+                    1.0
+                    + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)
+                ),
+            )

         return LambdaLR(optimizer, lr_lambda, -1)
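
The reformatted `return max(...)` is the cosine half of a warmup-then-cosine multiplier fed to `LambdaLR`. A standalone sketch follows; the warmup branch sits outside this hunk, so its exact form here is an assumption.

# Standalone sketch of the warmup-then-cosine multiplier; only the cosine
# branch appears in the hunk above, so the warmup branch is an assumption.
import math


def lr_lambda(step, num_warmup_steps=500, num_training_steps=10_000, num_cycles=0.5):
    if step < num_warmup_steps:
        return float(step) / float(max(1, num_warmup_steps))
    progress = float(step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))


print(lr_lambda(0), lr_lambda(500), lr_lambda(10_000))  # 0.0 -> 1.0 -> ~0.0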
@@ -98,7 +111,9 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
         def cosine_decay_schedule(current_step):
             step = min(current_step, self.num_decay_steps)
-            cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
+            cosine_decay = 0.5 * (
+                1 + math.cos(math.pi * step / self.num_decay_steps)
+            )
             alpha = self.decay_lr / self.peak_lr
             decayed = (1 - alpha) * cosine_decay + alpha
             return decayed
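
This lambda decays the multiplier from 1.0 down to a floor of `alpha = decay_lr / peak_lr` over `num_decay_steps`. A quick numeric check with illustrative values:

# Quick numeric check of the cosine-decay-to-a-floor formula above;
# peak_lr, decay_lr, and num_decay_steps are illustrative values.
import math

peak_lr, decay_lr, num_decay_steps = 1e-4, 1e-6, 1_000
alpha = decay_lr / peak_lr  # the floor, as a fraction of peak_lr


def cosine_decay_schedule(current_step):
    step = min(current_step, num_decay_steps)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * step / num_decay_steps))
    return (1 - alpha) * cosine_decay + alpha


assert abs(cosine_decay_schedule(0) - 1.0) < 1e-9        # starts at peak_lr
assert abs(cosine_decay_schedule(2_000) - alpha) < 1e-9  # clamps to decay_lr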
@@ -117,6 +132,8 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:


 def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
-    state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
+    state_dict = deserialize_json_into_object(
+        save_dir / SCHEDULER_STATE, scheduler.state_dict()
+    )
     scheduler.load_state_dict(state_dict)
     return scheduler
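
Since `scheduler.state_dict()` is plain Python data for simple schedulers, it survives a JSON round-trip. The sketch below reduces `deserialize_json_into_object` to a bare `json.load`, which is an assumption about its role here rather than a description of the real helper.

# Hedged sketch of the save/load cycle; deserialize_json_into_object is
# reduced to a plain json.load, an assumption about its role, not its API.
import json

import torch
from torch.optim.lr_scheduler import LambdaLR

param = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.SGD([param], lr=1e-3)
scheduler = LambdaLR(optimizer, lambda step: 0.95**step)
scheduler.step()

# Save: for simple schedulers, state_dict() is JSON-serializable
# (lambdas are stored as None placeholders and skipped on load).
with open("scheduler_state.json", "w") as f:
    json.dump(scheduler.state_dict(), f)

# Load: rebuild the scheduler, then restore counters such as last_epoch.
restored = LambdaLR(optimizer, lambda step: 0.95**step)
with open("scheduler_state.json") as f:
    restored.load_state_dict(json.load(f))
assert restored.last_epoch == scheduler.last_epoch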