summaryrefslogtreecommitdiff
path: root/trm/models/ema.py
diff options
context:
space:
mode:
Diffstat (limited to 'trm/models/ema.py')
-rw-r--r--trm/models/ema.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/trm/models/ema.py b/trm/models/ema.py
new file mode 100644
index 0000000..2e52933
--- /dev/null
+++ b/trm/models/ema.py
@@ -0,0 +1,40 @@
+import copy
+import torch.nn as nn
+
+class EMAHelper(object):
+ def __init__(self, mu=0.999):
+ self.mu = mu
+ self.shadow = {}
+
+ def register(self, module):
+ if isinstance(module, nn.DataParallel):
+ module = module.module
+ for name, param in module.named_parameters():
+ if param.requires_grad:
+ self.shadow[name] = param.data.clone()
+
+ def update(self, module):
+ if isinstance(module, nn.DataParallel):
+ module = module.module
+ for name, param in module.named_parameters():
+ if param.requires_grad:
+ self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
+
+ def ema(self, module):
+ if isinstance(module, nn.DataParallel):
+ module = module.module
+ for name, param in module.named_parameters():
+ if param.requires_grad:
+ param.data.copy_(self.shadow[name].data)
+
+ def ema_copy(self, module):
+ module_copy = copy.deepcopy(module)
+ self.ema(module_copy)
+ return module_copy
+
+ def state_dict(self):
+ return self.shadow
+
+ def load_state_dict(self, state_dict):
+ self.shadow = state_dict
+