diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-10 09:50:33 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-10 09:50:33 -0600 |
| commit | 039c12d3cf7178db6a7d80b02cf022d67231014e (patch) | |
| tree | b3104310bfaced0d992729f59f1a7ef2e769c6bd /src/training/trainer.py | |
| parent | 80579d6cc254d337a23e71404ae7ecab1849d1e5 (diff) | |
Add auto-resume checkpointing, S1/S2 configs, and experiment results
- Auto-resume: find latest checkpoint in save_dir on startup
- SIGUSR1 handler: save checkpoint before SLURM timeout
- S1 config (constant tau=5, identity init verification)
- S2 config (constant tau=2, gradient flow check)
- Experiment results tracker with S0/S1 data
- Speed estimates and experiment plan
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'src/training/trainer.py')
| -rw-r--r-- | src/training/trainer.py | 34 |
1 file changed, 29 insertions, 5 deletions
diff --git a/src/training/trainer.py b/src/training/trainer.py index 7ebd21e..d157d0c 100644 --- a/src/training/trainer.py +++ b/src/training/trainer.py @@ -8,6 +8,7 @@ from __future__ import annotations import math import os +import signal import warnings from dataclasses import dataclass, field from typing import Any, Optional @@ -22,7 +23,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from src.data.dolma import build_eval_dataloader, build_train_dataloader from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A from src.model.predictor import StructurePredictor -from src.training.checkpointing import load_checkpoint, save_checkpoint +from src.training.checkpointing import find_latest_checkpoint, load_checkpoint, save_checkpoint from src.training.schedulers import lambda_schedule, tau_schedule from src.utils.logging import finish_wandb, init_wandb, log_metrics from src.utils.topology import compute_topology_metrics @@ -136,16 +137,22 @@ class Trainer: self.best_eval_nll = float("inf") self.collapse_counter = 0 # consecutive steps with collapsed A - # Resume from checkpoint if specified - if config.resume_from: + # Resume from checkpoint: explicit path or auto-find latest + resume_path = config.resume_from + if not resume_path: + resume_path = find_latest_checkpoint(config.save_dir) + if resume_path and self.is_main: + print(f"Auto-resume: found {resume_path}") + + if resume_path: state = load_checkpoint( - config.resume_from, + resume_path, self.predictor, self.optimizer, self.lr_scheduler, device=self.device, ) - self.global_step = state["step"] + self.global_step = state["step"] + 1 # resume from NEXT step self.best_eval_nll = state["best_eval_nll"] if self.is_main: print(f"Resumed from step {self.global_step}") @@ -261,11 +268,28 @@ class Trainer: else: self.wandb_run = None + def _save_on_signal(self, signum: int, frame: Any) -> None: + """Save checkpoint when receiving SIGUSR1 (SLURM pre-timeout signal).""" + if self.is_main: + 
print(f"\nReceived signal {signum}, saving checkpoint before exit...") + save_checkpoint( + self.config.save_dir, + self.global_step, + self.predictor, + self.optimizer, + self.lr_scheduler, + self.best_eval_nll, + ) + raise SystemExit(0) + def train(self) -> None: """Main training loop.""" config = self.config train_iter = iter(self.train_loader) + # Register signal handler for graceful SLURM preemption + signal.signal(signal.SIGUSR1, self._save_on_signal) + if self.is_main: print(f"\nStarting training: {config.total_steps} steps") print(f" batch_size={config.batch_size}, micro_batch={config.micro_batch_size}, accum={self.accum_steps}") |
