diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-10 09:50:33 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-10 09:50:33 -0600 |
| commit | 039c12d3cf7178db6a7d80b02cf022d67231014e (patch) | |
| tree | b3104310bfaced0d992729f59f1a7ef2e769c6bd /src/training/checkpointing.py | |
| parent | 80579d6cc254d337a23e71404ae7ecab1849d1e5 (diff) | |
Add auto-resume checkpointing, S1/S2 configs, and experiment results
- Auto-resume: find latest checkpoint in save_dir on startup
- SIGUSR1 handler: save checkpoint before SLURM timeout
- S1 config (constant tau=5, identity init verification)
- S2 config (constant tau=2, gradient flow check)
- Experiment results tracker with S0/S1 data
- Speed estimates and experiment plan
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'src/training/checkpointing.py')
| -rw-r--r-- | src/training/checkpointing.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/src/training/checkpointing.py b/src/training/checkpointing.py index 9ff02df..b53ce4f 100644 --- a/src/training/checkpointing.py +++ b/src/training/checkpointing.py @@ -6,7 +6,9 @@ Frozen models (OLMo, Qwen) are not checkpointed — they load from HuggingFace. from __future__ import annotations +import glob import os +import re from typing import Any, Optional import torch @@ -55,6 +57,35 @@ def save_checkpoint( return path +def find_latest_checkpoint(save_dir: str) -> Optional[str]: + """Find the latest checkpoint in save_dir by step number. + + Returns: + Path to latest checkpoint, or None if no checkpoints found. + """ + if not os.path.isdir(save_dir): + return None + + pattern = os.path.join(save_dir, "checkpoint_step*.pt") + files = glob.glob(pattern) + if not files: + return None + + # Extract step numbers and find max + step_re = re.compile(r"checkpoint_step(\d+)\.pt$") + best_step = -1 + best_path = None + for f in files: + m = step_re.search(f) + if m: + step = int(m.group(1)) + if step > best_step: + best_step = step + best_path = f + + return best_path + + def load_checkpoint( path: str, predictor: nn.Module, |
