from typing import Optional, Union

from torch.optim import Optimizer

from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, SchedulerType


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    **kwargs,
):
"""
|
|
Added kwargs vs diffuser's original implementation
|
|
|
|
Unified API to get any scheduler from its name.
|
|
|
|
Args:
|
|
name (`str` or `SchedulerType`):
|
|
The name of the scheduler to use.
|
|
optimizer (`torch.optim.Optimizer`):
|
|
The optimizer that will be used during training.
|
|
num_warmup_steps (`int`, *optional*):
|
|
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
|
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
|
num_training_steps (`int``, *optional*):
|
|
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
|
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
|
"""
|
|
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # The constant schedule needs neither warmup nor a total step count
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, **kwargs)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    return schedule_func(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs
    )
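

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical example of wiring `get_scheduler` to a toy model and
# optimizer. Scheduler names follow diffusers' `SchedulerType` values, e.g.
# "constant", "cosine", or "constant_with_warmup".
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # "cosine" requires both warmup and total step counts; extra kwargs such as
    # `num_cycles` are forwarded to the underlying schedule function.
    lr_scheduler = get_scheduler(
        "cosine",
        optimizer,
        num_warmup_steps=100,
        num_training_steps=1_000,
        num_cycles=0.5,
    )

    # Step the scheduler once per optimizer step, as in a training loop.
    for _ in range(10):
        optimizer.step()
        lr_scheduler.step()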