Co-authored-by: Simon Alibert <simon.alibert@huggingface.co> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Pablo <pablo.montalvo.leroux@gmail.com>
71 lines
1.7 KiB
Python
71 lines
1.7 KiB
Python
import abc
|
|
from dataclasses import asdict, dataclass
|
|
|
|
import draccus
|
|
import torch
|
|
|
|
|
|
@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    """Abstract base class for optimizer configurations.

    Concrete configurations are registered under a choice name with
    ``OptimizerConfig.register_subclass(<name>)`` and must implement
    :meth:`build`.
    """

    # Hyper-parameters every optimizer configuration has to provide.
    lr: float
    weight_decay: float
    grad_clip_norm: float

    @property
    def type(self) -> str:
        """Return the choice name under which this instance's class is registered."""
        return self.get_choice_name(self.__class__)

    @classmethod
    def default_choice_name(cls) -> str | None:
        """Return the default choice name, ``"adam"``."""
        return "adam"

    @abc.abstractmethod
    def build(self, params: dict) -> torch.optim.Optimizer:
        """Create the optimizer described by this configuration.

        The signature mirrors the concrete configurations in this module,
        which all accept the parameters to optimize as ``params``.

        Args:
            params: Parameters handed to the optimizer constructor.

        Returns:
            The constructed ``torch.optim.Optimizer``.

        Raises:
            NotImplementedError: Always, in this abstract base implementation.
        """
        raise NotImplementedError
|
|
|
|
|
|
@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
    """Configuration for ``torch.optim.Adam``, registered as ``"adam"``."""

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        """Instantiate ``torch.optim.Adam`` for ``params`` from this config.

        Every dataclass field except ``grad_clip_norm`` is forwarded to the
        optimizer constructor as a keyword argument.
        """
        # grad_clip_norm is stored on the config but not passed to the optimizer.
        optimizer_kwargs = {
            name: value
            for name, value in asdict(self).items()
            if name != "grad_clip_norm"
        }
        return torch.optim.Adam(params, **optimizer_kwargs)
|
|
|
|
|
|
@OptimizerConfig.register_subclass("adamw")
@dataclass
class AdamWConfig(OptimizerConfig):
    """Configuration for ``torch.optim.AdamW``, registered as ``"adamw"``."""

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        """Instantiate ``torch.optim.AdamW`` for ``params`` from this config.

        Every dataclass field except ``grad_clip_norm`` is forwarded to the
        optimizer constructor as a keyword argument.
        """
        # grad_clip_norm is stored on the config but not passed to the optimizer.
        optimizer_kwargs = {
            name: value
            for name, value in asdict(self).items()
            if name != "grad_clip_norm"
        }
        return torch.optim.AdamW(params, **optimizer_kwargs)
|
|
|
|
|
|
@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
    """Configuration for ``torch.optim.SGD``, registered as ``"sgd"``."""

    lr: float = 1e-3
    momentum: float = 0.0
    dampening: float = 0.0
    nesterov: bool = False
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        """Instantiate ``torch.optim.SGD`` for ``params`` from this config.

        Every dataclass field except ``grad_clip_norm`` is forwarded to the
        optimizer constructor as a keyword argument.
        """
        # grad_clip_norm is stored on the config but not passed to the optimizer.
        optimizer_kwargs = {
            name: value
            for name, value in asdict(self).items()
            if name != "grad_clip_norm"
        }
        return torch.optim.SGD(params, **optimizer_kwargs)
|