backup wip

Alexander Soare
2024-04-11 18:33:54 +01:00
parent 94cc22da9e
commit 5666ec3ec7
3 changed files with 53 additions and 89 deletions


@@ -32,7 +32,7 @@ class EMAModel:
         self.min_value = min_value
         self.max_value = max_value
-        self.decay = 0.0
+        self.alpha = 0.0
         self.optimization_step = 0

     def get_decay(self, optimization_step):
@@ -49,23 +49,20 @@ class EMAModel:
     @torch.no_grad()
     def step(self, new_model):
-        self.decay = self.get_decay(self.optimization_step)
+        self.alpha = self.get_decay(self.optimization_step)

-        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
+        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
+            # Iterate over immediate parameters only.
             for param, ema_param in zip(
-                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
+                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
             ):
-                # iterative over immediate parameters only.
                 if isinstance(param, dict):
                     raise RuntimeError("Dict parameter not supported")
-                if isinstance(module, _BatchNorm):
-                    # skip batchnorms
-                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
-                elif not param.requires_grad:
+                if isinstance(module, _BatchNorm) or not param.requires_grad:
+                    # Copy BatchNorm parameters, and non-trainable parameters directly.
                     ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                 else:
-                    ema_param.mul_(self.decay)
-                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+                    ema_param.mul_(self.alpha)
+                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)

         self.optimization_step += 1
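
For context, the trainable-parameter branch of this hunk applies the standard in-place exponential moving average, ema = alpha * ema + (1 - alpha) * param, while BatchNorm and non-trainable parameters are copied directly. Below is a minimal, self-contained sketch of that update outside the class; the free function ema_step, the toy models, and the fixed alpha value are illustrative assumptions, not part of this commit.

    # Illustrative sketch only; ema_step, the toy models, and alpha=0.99 are assumptions.
    import torch
    from torch import nn
    from torch.nn.modules.batchnorm import _BatchNorm

    @torch.no_grad()
    def ema_step(new_model: nn.Module, averaged_model: nn.Module, alpha: float) -> None:
        for module, ema_module in zip(new_model.modules(), averaged_model.modules(), strict=True):
            # Iterate over immediate parameters only, mirroring the hunk above.
            for param, ema_param in zip(
                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
            ):
                if isinstance(module, _BatchNorm) or not param.requires_grad:
                    # Copy BatchNorm parameters and non-trainable parameters directly.
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                else:
                    # ema = alpha * ema + (1 - alpha) * param, in place.
                    ema_param.mul_(alpha)
                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - alpha)

    model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
    ema_model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
    ema_step(model, ema_model, alpha=0.99)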