summaryrefslogtreecommitdiff
path: root/src/training/checkpointing.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-10 09:50:33 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-10 09:50:33 -0600
commit039c12d3cf7178db6a7d80b02cf022d67231014e (patch)
treeb3104310bfaced0d992729f59f1a7ef2e769c6bd /src/training/checkpointing.py
parent80579d6cc254d337a23e71404ae7ecab1849d1e5 (diff)
Add auto-resume checkpointing, S1/S2 configs, and experiment results
- Auto-resume: find latest checkpoint in save_dir on startup
- SIGUSR1 handler: save checkpoint before SLURM timeout
- S1 config (constant tau=5, identity init verification)
- S2 config (constant tau=2, gradient flow check)
- Experiment results tracker with S0/S1 data
- Speed estimates and experiment plan

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'src/training/checkpointing.py')
-rw-r--r--src/training/checkpointing.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/src/training/checkpointing.py b/src/training/checkpointing.py
index 9ff02df..b53ce4f 100644
--- a/src/training/checkpointing.py
+++ b/src/training/checkpointing.py
@@ -6,7 +6,9 @@ Frozen models (OLMo, Qwen) are not checkpointed — they load from HuggingFace.
from __future__ import annotations
+import glob
import os
+import re
from typing import Any, Optional
import torch
@@ -55,6 +57,35 @@ def save_checkpoint(
return path
def find_latest_checkpoint(save_dir: str) -> Optional[str]:
    """Locate the checkpoint with the highest step number in *save_dir*.

    Checkpoints are files named ``checkpoint_step<N>.pt``; the one with
    the largest ``<N>`` wins.

    Returns:
        Path to the most recent checkpoint, or ``None`` if the directory
        does not exist or contains no matching files.
    """
    if not os.path.isdir(save_dir):
        return None

    # Glob narrows to candidate names; the regex both validates the exact
    # filename shape and extracts the step number.
    step_pattern = re.compile(r"checkpoint_step(\d+)\.pt$")
    candidates = []
    for path in glob.glob(os.path.join(save_dir, "checkpoint_step*.pt")):
        match = step_pattern.search(path)
        if match:
            candidates.append((int(match.group(1)), path))

    if not candidates:
        return None
    # Highest step number; within one directory step numbers are unique.
    return max(candidates)[1]
+
+
def load_checkpoint(
path: str,
predictor: nn.Module,