diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /trm/models/ema.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'trm/models/ema.py')
| -rw-r--r-- | trm/models/ema.py | 40 |
1 files changed, 40 insertions, 0 deletions
# File: trm/models/ema.py (added as a new file by this commit).
import copy

import torch.nn as nn


class EMAHelper:
    """Keep an exponential moving average (EMA) of a module's trainable parameters.

    The averages live in ``self.shadow``, a dict keyed by the names yielded by
    ``module.named_parameters()``. Only parameters with ``requires_grad=True``
    are tracked. An ``nn.DataParallel`` wrapper is unwrapped first, so the
    keys are those of the wrapped module.
    """

    def __init__(self, mu=0.999):
        """Create an empty helper.

        Args:
            mu: Weight given to the previous average on each ``update``;
                the current parameter value gets weight ``1 - mu``.
        """
        self.mu = mu
        self.shadow = {}

    @staticmethod
    def _unwrap(module):
        """Return the module wrapped by ``nn.DataParallel``, else ``module`` itself."""
        if isinstance(module, nn.DataParallel):
            return module.module
        return module

    def register(self, module):
        """Seed the shadow with a copy of every trainable parameter of ``module``."""
        for name, param in self._unwrap(module).named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        """Blend the current trainable parameters of ``module`` into the shadow.

        Each tracked entry becomes ``(1 - mu) * param + mu * shadow``.

        Raises:
            KeyError: If a trainable parameter has no shadow entry (it was not
                present, or not trainable, when ``register`` /
                ``load_state_dict`` was called).
        """
        for name, param in self._unwrap(module).named_parameters():
            if not param.requires_grad:
                continue
            shadow = self.shadow[name]
            # A shadow restored from a checkpoint may sit on another device
            # than the live parameter; ``.to`` is a no-op when they match.
            previous = shadow.data.to(param.device)
            shadow.data = (1. - self.mu) * param.data + self.mu * previous

    def ema(self, module):
        """Overwrite the trainable parameters of ``module`` with the shadow values.

        Raises:
            KeyError: If a trainable parameter has no shadow entry.
        """
        for name, param in self._unwrap(module).named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        """Return a deep copy of ``module`` carrying the averaged parameters.

        ``module`` itself is left untouched. If it is wrapped in
        ``nn.DataParallel``, the returned copy is wrapped as well.
        """
        module_copy = copy.deepcopy(module)
        self.ema(module_copy)
        return module_copy

    def state_dict(self):
        """Return the shadow dict (the live object, not a copy)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Replace the shadow with detached copies of the tensors in ``state_dict``.

        Copying keeps this helper independent of the object it was loaded
        from. Without it, loading another helper's ``state_dict()`` would make
        both helpers share one dict, and an ``update`` on either would
        silently change the other.
        """
        self.shadow = {
            name: tensor.detach().clone() for name, tensor in state_dict.items()
        }
