diff options
Diffstat (limited to 'trm/models/ema.py')
| -rw-r--r-- | trm/models/ema.py | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/trm/models/ema.py b/trm/models/ema.py new file mode 100644 index 0000000..2e52933 --- /dev/null +++ b/trm/models/ema.py @@ -0,0 +1,40 @@ +import copy +import torch.nn as nn + +class EMAHelper(object): + def __init__(self, mu=0.999): + self.mu = mu + self.shadow = {} + + def register(self, module): + if isinstance(module, nn.DataParallel): + module = module.module + for name, param in module.named_parameters(): + if param.requires_grad: + self.shadow[name] = param.data.clone() + + def update(self, module): + if isinstance(module, nn.DataParallel): + module = module.module + for name, param in module.named_parameters(): + if param.requires_grad: + self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data + + def ema(self, module): + if isinstance(module, nn.DataParallel): + module = module.module + for name, param in module.named_parameters(): + if param.requires_grad: + param.data.copy_(self.shadow[name].data) + + def ema_copy(self, module): + module_copy = copy.deepcopy(module) + self.ema(module_copy) + return module_copy + + def state_dict(self): + return self.shadow + + def load_state_dict(self, state_dict): + self.shadow = state_dict + |
