summaryrefslogtreecommitdiff
path: root/src/training
diff options
context:
space:
mode:
Diffstat (limited to 'src/training')
-rw-r--r--src/training/checkpointing.py31
-rw-r--r--src/training/trainer.py34
2 files changed, 60 insertions, 5 deletions
diff --git a/src/training/checkpointing.py b/src/training/checkpointing.py
index 9ff02df..b53ce4f 100644
--- a/src/training/checkpointing.py
+++ b/src/training/checkpointing.py
@@ -6,7 +6,9 @@ Frozen models (OLMo, Qwen) are not checkpointed — they load from HuggingFace.
from __future__ import annotations
+import glob
import os
+import re
from typing import Any, Optional
import torch
@@ -55,6 +57,35 @@ def save_checkpoint(
return path
+def find_latest_checkpoint(save_dir: str) -> Optional[str]:
+ """Find the latest checkpoint in save_dir by step number.
+
+ Returns:
+ Path to latest checkpoint, or None if no checkpoints found.
+ """
+ if not os.path.isdir(save_dir):
+ return None
+
+ pattern = os.path.join(save_dir, "checkpoint_step*.pt")
+ files = glob.glob(pattern)
+ if not files:
+ return None
+
+ # Extract step numbers and find max
+ step_re = re.compile(r"checkpoint_step(\d+)\.pt$")
+ best_step = -1
+ best_path = None
+ for f in files:
+ m = step_re.search(f)
+ if m:
+ step = int(m.group(1))
+ if step > best_step:
+ best_step = step
+ best_path = f
+
+ return best_path
+
+
def load_checkpoint(
path: str,
predictor: nn.Module,
diff --git a/src/training/trainer.py b/src/training/trainer.py
index 7ebd21e..d157d0c 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -8,6 +8,7 @@ from __future__ import annotations
import math
import os
+import signal
import warnings
from dataclasses import dataclass, field
from typing import Any, Optional
@@ -22,7 +23,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from src.data.dolma import build_eval_dataloader, build_train_dataloader
from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A
from src.model.predictor import StructurePredictor
-from src.training.checkpointing import load_checkpoint, save_checkpoint
+from src.training.checkpointing import find_latest_checkpoint, load_checkpoint, save_checkpoint
from src.training.schedulers import lambda_schedule, tau_schedule
from src.utils.logging import finish_wandb, init_wandb, log_metrics
from src.utils.topology import compute_topology_metrics
@@ -136,16 +137,22 @@ class Trainer:
self.best_eval_nll = float("inf")
self.collapse_counter = 0 # consecutive steps with collapsed A
- # Resume from checkpoint if specified
- if config.resume_from:
+ # Resume from checkpoint: explicit path or auto-find latest
+ resume_path = config.resume_from
+ if not resume_path:
+ resume_path = find_latest_checkpoint(config.save_dir)
+ if resume_path and self.is_main:
+ print(f"Auto-resume: found {resume_path}")
+
+ if resume_path:
state = load_checkpoint(
- config.resume_from,
+ resume_path,
self.predictor,
self.optimizer,
self.lr_scheduler,
device=self.device,
)
- self.global_step = state["step"]
+ self.global_step = state["step"] + 1 # resume from NEXT step
self.best_eval_nll = state["best_eval_nll"]
if self.is_main:
print(f"Resumed from step {self.global_step}")
@@ -261,11 +268,28 @@ class Trainer:
else:
self.wandb_run = None
+ def _save_on_signal(self, signum: int, frame: Any) -> None:
+ """Save checkpoint when receiving SIGUSR1 (SLURM pre-timeout signal)."""
+ if self.is_main:
+ print(f"\nReceived signal {signum}, saving checkpoint before exit...")
+ save_checkpoint(
+ self.config.save_dir,
+ self.global_step,
+ self.predictor,
+ self.optimizer,
+ self.lr_scheduler,
+ self.best_eval_nll,
+ )
+ raise SystemExit(0)
+
def train(self) -> None:
"""Main training loop."""
config = self.config
train_iter = iter(self.train_loader)
+ # Register signal handler for graceful SLURM preemption
+ signal.signal(signal.SIGUSR1, self._save_on_signal)
+
if self.is_main:
print(f"\nStarting training: {config.total_steps} steps")
print(f" batch_size={config.batch_size}, micro_batch={config.micro_batch_size}, accum={self.accum_steps}")