From 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Mon, 9 Feb 2026 11:00:39 -0600
Subject: Initial implementation: DAGFormer Phase 1

- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A
  - Proportional attribution for post-norm decomposition
  - All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing
- 43/43 unit tests pass

Co-Authored-By: Claude Opus 4.6
---
 src/training/checkpointing.py | 92 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 92 insertions(+)
 create mode 100644 src/training/checkpointing.py
(limited to 'src/training/checkpointing.py')

diff --git a/src/training/checkpointing.py b/src/training/checkpointing.py
new file mode 100644
index 0000000..9ff02df
--- /dev/null
+++ b/src/training/checkpointing.py
@@ -0,0 +1,92 @@
+"""Checkpoint save/load for predictor + optimizer + schedule state.
+
+Only saves trainable components (predictor MLP, optimizer, schedule state).
+Frozen models (OLMo, Qwen) are not checkpointed — they load from HuggingFace.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+
+def save_checkpoint(
+    save_dir: str,
+    step: int,
+    predictor: nn.Module,
+    optimizer: optim.Optimizer,
+    scheduler: Any,
+    best_eval_nll: float,
+    extra: Optional[dict] = None,
+) -> str:
+    """Save training checkpoint.
+
+    Args:
+        save_dir: directory to save checkpoint
+        step: current global step
+        predictor: the structure predictor (only MLP params are saved)
+        optimizer: AdamW optimizer
+        scheduler: LR scheduler
+        best_eval_nll: best eval NLL so far
+        extra: any additional state to save
+
+    Returns:
+        path: path to saved checkpoint
+    """
+    os.makedirs(save_dir, exist_ok=True)
+    path = os.path.join(save_dir, f"checkpoint_step{step}.pt")
+
+    state = {
+        "step": step,
+        "predictor_state_dict": predictor.state_dict(),
+        "optimizer_state_dict": optimizer.state_dict(),
+        "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
+        "best_eval_nll": best_eval_nll,
+    }
+    if extra:
+        state.update(extra)
+
+    torch.save(state, path)
+    print(f"Checkpoint saved: {path}")
+    return path
+
+
+def load_checkpoint(
+    path: str,
+    predictor: nn.Module,
+    optimizer: Optional[optim.Optimizer] = None,
+    scheduler: Optional[Any] = None,
+    device: Optional[torch.device] = None,
+) -> dict:
+    """Load training checkpoint.
+
+    Args:
+        path: path to checkpoint file
+        predictor: structure predictor to load weights into
+        optimizer: optimizer to restore state (optional — skip for eval)
+        scheduler: LR scheduler to restore state (optional)
+        device: device to map tensors to
+
+    Returns:
+        state dict with step, best_eval_nll, and any extras
+    """
+    map_location = device if device is not None else "cpu"
+    state = torch.load(path, map_location=map_location)
+
+    predictor.load_state_dict(state["predictor_state_dict"])
+    print(f"Predictor state loaded from {path}")
+
+    if optimizer is not None and "optimizer_state_dict" in state:
+        optimizer.load_state_dict(state["optimizer_state_dict"])
+
+    if scheduler is not None and state.get("scheduler_state_dict") is not None:
+        scheduler.load_state_dict(state["scheduler_state_dict"])
+
+    return {
+        "step": state["step"],
+        "best_eval_nll": state.get("best_eval_nll", float("inf")),
+    }
--
cgit v1.2.3
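
For reference, a minimal round-trip using the two functions added in this patch could look like the sketch below. Only save_checkpoint and load_checkpoint come from the patch itself; the import path, the toy stand-in predictor, the directory name, and all numeric values are illustrative assumptions, not taken from the repository.

    import torch
    import torch.nn as nn

    from src.training.checkpointing import save_checkpoint, load_checkpoint

    # Stand-in for the structure predictor: any nn.Module with trainable parameters.
    predictor = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 4))
    optimizer = torch.optim.AdamW(predictor.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

    # ... training steps would update predictor/optimizer/scheduler here ...

    ckpt_path = save_checkpoint(
        save_dir="checkpoints/demo_run",   # hypothetical output directory
        step=500,
        predictor=predictor,
        optimizer=optimizer,
        scheduler=scheduler,
        best_eval_nll=2.31,                # placeholder value
    )

    # To resume: rebuild identical modules, then restore their state in place.
    resumed = load_checkpoint(
        ckpt_path,
        predictor=predictor,
        optimizer=optimizer,
        scheduler=scheduler,
        device=torch.device("cpu"),
    )
    print(resumed["step"], resumed["best_eval_nll"])

Because the frozen OLMo and Qwen weights are reloaded from HuggingFace rather than checkpointed, resuming only requires reconstructing the predictor, optimizer, and scheduler with the same shapes before calling load_checkpoint.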
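
The commit message also mentions a Gumbel-Sigmoid step in predictor.py, a file not included in this patch. The snippet below is a generic straight-through Gumbel-Sigmoid sketch for sampling a relaxed binary adjacency matrix, not the repository's implementation; the function name, temperature, and the 256x256 shape are placeholders for illustration.

    import torch

    def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1.0, hard: bool = True) -> torch.Tensor:
        """Per-entry relaxed Bernoulli sample, e.g. over 256x256 edge logits.

        Adds logistic noise (a difference of two Gumbel variables), applies a
        temperature-scaled sigmoid, and optionally thresholds with a straight-through
        estimator so the forward pass is binary while gradients follow the soft sample.
        """
        u = torch.rand_like(logits).clamp_(1e-6, 1 - 1e-6)
        noise = torch.log(u) - torch.log1p(-u)          # Logistic(0, 1) noise
        y_soft = torch.sigmoid((logits + noise) / tau)  # relaxed sample in (0, 1)
        if not hard:
            return y_soft
        y_hard = (y_soft > 0.5).float()
        return y_hard + y_soft - y_soft.detach()        # straight-through gradient path

    # Illustrative shapes only: 256 nodes -> 256x256 edge logits.
    edge_logits = torch.zeros(256, 256, requires_grad=True)
    A = gumbel_sigmoid(edge_logits, tau=0.5)
    A.sum().backward()   # gradients reach edge_logits despite the hard threshold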