backup wip
This commit is contained in:
@@ -32,7 +32,7 @@ class EMAModel:
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
|
||||
self.decay = 0.0
|
||||
self.alpha = 0.0
|
||||
self.optimization_step = 0
|
||||
|
||||
def get_decay(self, optimization_step):
|
||||
@@ -49,23 +49,20 @@ class EMAModel:
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, new_model):
|
||||
self.decay = self.get_decay(self.optimization_step)
|
||||
self.alpha = self.get_decay(self.optimization_step)
|
||||
|
||||
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
|
||||
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
|
||||
# Iterate over immediate parameters only.
|
||||
for param, ema_param in zip(
|
||||
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
|
||||
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
|
||||
):
|
||||
# iterative over immediate parameters only.
|
||||
if isinstance(param, dict):
|
||||
raise RuntimeError("Dict parameter not supported")
|
||||
|
||||
if isinstance(module, _BatchNorm):
|
||||
# skip batchnorms
|
||||
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
||||
elif not param.requires_grad:
|
||||
if isinstance(module, _BatchNorm) or not param.requires_grad:
|
||||
# Copy BatchNorm parameters, and non-trainable parameters directly.
|
||||
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
||||
else:
|
||||
ema_param.mul_(self.decay)
|
||||
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
|
||||
ema_param.mul_(self.alpha)
|
||||
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
|
||||
|
||||
self.optimization_step += 1
|
||||
|
||||
Reference in New Issue
Block a user