From 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Sat, 13 Jun 2026 12:35:36 -0500 Subject: rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipeline Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 --- trm/models/ema.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 trm/models/ema.py (limited to 'trm/models/ema.py') diff --git a/trm/models/ema.py b/trm/models/ema.py new file mode 100644 index 0000000..2e52933 --- /dev/null +++ b/trm/models/ema.py @@ -0,0 +1,40 @@ +import copy +import torch.nn as nn + +class EMAHelper(object): + def __init__(self, mu=0.999): + self.mu = mu + self.shadow = {} + + def register(self, module): + if isinstance(module, nn.DataParallel): + module = module.module + for name, param in module.named_parameters(): + if param.requires_grad: + self.shadow[name] = param.data.clone() + + def update(self, module): + if isinstance(module, nn.DataParallel): + module = module.module + for name, param in module.named_parameters(): + if param.requires_grad: + self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data + + def ema(self, module): + if isinstance(module, nn.DataParallel): + module = module.module + for name, param in module.named_parameters(): + if param.requires_grad: + param.data.copy_(self.shadow[name].data) + + def ema_copy(self, module): + module_copy = copy.deepcopy(module) + self.ema(module_copy) + return module_copy + + def state_dict(self): + return self.shadow + + def load_state_dict(self, state_dict): + self.shadow = state_dict + -- cgit v1.2.3