path: root/src/training/checkpointing.py
author     YurenHao0426 <blackhao0426@gmail.com>    2026-02-09 11:00:39 -0600
committer  YurenHao0426 <blackhao0426@gmail.com>    2026-02-09 11:00:39 -0600
commit     13ddc8dc583d8b1355909970cb8c27f85b7d3c8b (patch)
tree       073534138604c1c49021ca7e334322262129f6ac /src/training/checkpointing.py
Initial implementation: DAGFormer Phase 1
- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A
  - Proportional attribution for post-norm decomposition
  - All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing
- 43/43 unit tests pass

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'src/training/checkpointing.py')
-rw-r--r--  src/training/checkpointing.py  92
1 file changed, 92 insertions, 0 deletions
diff --git a/src/training/checkpointing.py b/src/training/checkpointing.py
new file mode 100644
index 0000000..9ff02df
--- /dev/null
+++ b/src/training/checkpointing.py
@@ -0,0 +1,92 @@
+"""Checkpoint save/load for predictor + optimizer + schedule state.
+
+Only saves trainable components (predictor MLP, optimizer, schedule state).
+Frozen models (OLMo, Qwen) are not checkpointed — they load from HuggingFace.
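+
+Minimal usage sketch (ckpt_dir, best_nll, and meta are illustrative names;
+trainer.py owns the real call sites):
+
+    path = save_checkpoint(ckpt_dir, step, predictor, optimizer, scheduler, best_nll)
+    meta = load_checkpoint(path, predictor, optimizer, scheduler, device=device)
+    start_step = meta["step"]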
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+
+def save_checkpoint(
+ save_dir: str,
+ step: int,
+ predictor: nn.Module,
+ optimizer: optim.Optimizer,
+ scheduler: Any,
+ best_eval_nll: float,
+ extra: Optional[dict] = None,
+) -> str:
+ """Save training checkpoint.
+
+ Args:
+ save_dir: directory to save checkpoint
+ step: current global step
+ predictor: the structure predictor (only MLP params are saved)
+ optimizer: AdamW optimizer
+ scheduler: LR scheduler
+ best_eval_nll: best eval NLL so far
+ extra: any additional state to save
+
+ Returns:
+ path: path to saved checkpoint
+ """
+ os.makedirs(save_dir, exist_ok=True)
+ path = os.path.join(save_dir, f"checkpoint_step{step}.pt")
+
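+    # Only lightweight trainable state is serialized here; the frozen OLMo/Qwen
+    # backbones are reloaded from HuggingFace rather than checkpointed.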
+ state = {
+ "step": step,
+ "predictor_state_dict": predictor.state_dict(),
+ "optimizer_state_dict": optimizer.state_dict(),
+ "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
+ "best_eval_nll": best_eval_nll,
+ }
+ if extra:
+ state.update(extra)
+
+ torch.save(state, path)
+ print(f"Checkpoint saved: {path}")
+ return path
+
+
+def load_checkpoint(
+ path: str,
+ predictor: nn.Module,
+ optimizer: Optional[optim.Optimizer] = None,
+ scheduler: Optional[Any] = None,
+ device: Optional[torch.device] = None,
+) -> dict:
+ """Load training checkpoint.
+
+ Args:
+ path: path to checkpoint file
+ predictor: structure predictor to load weights into
+ optimizer: optimizer to restore state (optional — skip for eval)
+ scheduler: LR scheduler to restore state (optional)
+ device: device to map tensors to
+
+ Returns:
+ state dict with step, best_eval_nll, and any extras
+ """
+ map_location = device if device is not None else "cpu"
+ state = torch.load(path, map_location=map_location)
+
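+    # The predictor is always restored; optimizer/scheduler restoration is optional
+    # so eval-only runs can skip it.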
+ predictor.load_state_dict(state["predictor_state_dict"])
+ print(f"Predictor state loaded from {path}")
+
+ if optimizer is not None and "optimizer_state_dict" in state:
+ optimizer.load_state_dict(state["optimizer_state_dict"])
+
+ if scheduler is not None and state.get("scheduler_state_dict") is not None:
+ scheduler.load_state_dict(state["scheduler_state_dict"])
+
+    # Return lightweight metadata (step, best_eval_nll, and any extras saved via
+    # `extra=`), dropping the bulky state dicts already consumed above.
+    out = {k: v for k, v in state.items() if not k.endswith("_state_dict")}
+    out.setdefault("best_eval_nll", float("inf"))
+    return out