import abc
from dataclasses import asdict, dataclass

import draccus
import torch


@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    """Base config for optimizers, registered as a draccus choice (selected via `type`)."""

    lr: float
    betas: tuple[float, float]
    eps: float
    weight_decay: float
    grad_clip_norm: float

    @property
    def type(self) -> str:
        # The name under which the concrete subclass was registered, e.g. "adam".
        return self.get_choice_name(self.__class__)

    @classmethod
    def default_choice_name(cls) -> str | None:
        # Subclass used when no explicit `type` is given.
        return "adam"

    @abc.abstractmethod
    def build(self, params: dict) -> torch.optim.Optimizer:
        """Construct the torch optimizer for the given model parameters."""
        raise NotImplementedError
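

# Illustrative note (not from the original file): because OptimizerConfig is a
# draccus.ChoiceRegistry, a config file or CLI flag can select a registered
# subclass by its choice name. A sketch of what that selection might look like
# in YAML:
#
#   optimizer:
#     type: adamw
#     lr: 3e-4
#
# draccus would then instantiate AdamWConfig(lr=3e-4), with the remaining
# fields keeping their defaults.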


@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        # grad_clip_norm is not a torch.optim.Adam argument, so it is removed
        # before forwarding the remaining fields to the constructor.
        kwargs.pop("grad_clip_norm")
        return torch.optim.Adam(params, **kwargs)


@OptimizerConfig.register_subclass("adamw")
@dataclass
class AdamWConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")  # not a torch.optim.AdamW argument
        return torch.optim.AdamW(params, **kwargs)
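

# Usage sketch (illustrative, not part of the original module): build a torch
# optimizer from a config. The Linear model below is a stand-in defined only
# for this example.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)

    cfg = AdamWConfig(lr=3e-4)
    optimizer = cfg.build(model.parameters())

    print(cfg.type)         # "adamw" -- the registered choice name
    print(type(optimizer))  # <class 'torch.optim.AdamW'>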