summaryrefslogtreecommitdiff
path: root/trm/models/ema.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /trm/models/ema.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'trm/models/ema.py')
-rw-r--r--trm/models/ema.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/trm/models/ema.py b/trm/models/ema.py
new file mode 100644
index 0000000..2e52933
--- /dev/null
+++ b/trm/models/ema.py
@@ -0,0 +1,40 @@
+import copy
+import torch.nn as nn
+
+class EMAHelper(object):
+ def __init__(self, mu=0.999):
+ self.mu = mu
+ self.shadow = {}
+
+ def register(self, module):
+ if isinstance(module, nn.DataParallel):
+ module = module.module
+ for name, param in module.named_parameters():
+ if param.requires_grad:
+ self.shadow[name] = param.data.clone()
+
+ def update(self, module):
+ if isinstance(module, nn.DataParallel):
+ module = module.module
+ for name, param in module.named_parameters():
+ if param.requires_grad:
+ self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
+
+ def ema(self, module):
+ if isinstance(module, nn.DataParallel):
+ module = module.module
+ for name, param in module.named_parameters():
+ if param.requires_grad:
+ param.data.copy_(self.shadow[name].data)
+
+ def ema_copy(self, module):
+ module_copy = copy.deepcopy(module)
+ self.ema(module_copy)
+ return module_copy
+
+ def state_dict(self):
+ return self.shadow
+
+ def load_state_dict(self, state_dict):
+ self.shadow = state_dict
+