import abc
import math
from dataclasses import asdict, dataclass

import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler


@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
    """Base config for LR schedulers, registered as a draccus choice type."""

    num_warmup_steps: int

    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)

    @abc.abstractmethod
    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
        raise NotImplementedError


@LRSchedulerConfig.register_subclass("diffuser")
@dataclass
class DiffuserSchedulerConfig(LRSchedulerConfig):
    name: str = "cosine"
    num_warmup_steps: int | None = None

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        from diffusers.optimization import get_scheduler

        # Forward all config fields (name, num_warmup_steps) to diffusers'
        # scheduler factory, together with the runtime arguments.
        kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
        return get_scheduler(**kwargs)


@LRSchedulerConfig.register_subclass("vqbet")
@dataclass
class VQBeTSchedulerConfig(LRSchedulerConfig):
    num_warmup_steps: int
    num_vqvae_training_steps: int
    num_cycles: float = 0.5

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        def lr_lambda(current_step: int) -> float:
            # Phase 1: constant LR while the VQ-VAE is being trained.
            if current_step < self.num_vqvae_training_steps:
                return 1.0
            # Phase 2: linear warmup, counted from the end of VQ-VAE training.
            adjusted_step = current_step - self.num_vqvae_training_steps
            if adjusted_step < self.num_warmup_steps:
                return float(adjusted_step) / float(max(1, self.num_warmup_steps))
            # Phase 3: cosine decay over the remaining steps.
            progress = float(adjusted_step - self.num_warmup_steps) / float(
                max(1, num_training_steps - self.num_warmup_steps)
            )
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))

        return LambdaLR(optimizer, lr_lambda, last_epoch=-1)
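

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module above): build a
# registered scheduler config and turn it into a torch scheduler. The dummy
# parameter, SGD optimizer, and step counts below are assumptions for the
# demo, not values taken from any real training setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    # Minimal optimizer over a single dummy parameter.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1e-3)

    # Constant LR for the first 1000 (VQ-VAE) steps, then 500 warmup steps,
    # then cosine decay over the remaining training budget.
    cfg = VQBeTSchedulerConfig(num_warmup_steps=500, num_vqvae_training_steps=1000)
    scheduler = cfg.build(optimizer, num_training_steps=10_000)

    for step in range(3):
        optimizer.step()
        scheduler.step()
        print(f"step {step}: lr={scheduler.get_last_lr()[0]:.2e}")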