summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-10 09:50:33 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-10 09:50:33 -0600
commit039c12d3cf7178db6a7d80b02cf022d67231014e (patch)
treeb3104310bfaced0d992729f59f1a7ef2e769c6bd
parent80579d6cc254d337a23e71404ae7ecab1849d1e5 (diff)
Add auto-resume checkpointing, S1/S2 configs, and experiment results
- Auto-resume: find latest checkpoint in save_dir on startup - SIGUSR1 handler: save checkpoint before SLURM timeout - S1 config (constant tau=5, identity init verification) - S2 config (constant tau=2, gradient flow check) - Experiment results tracker with S0/S1 data - Speed estimates and experiment plan Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
-rw-r--r--configs/s1_identity_init.yaml51
-rw-r--r--configs/s2_gradient_flow.yaml52
-rw-r--r--experiment_plan.xlsxbin0 -> 11524 bytes
-rw-r--r--experiments/results.md225
-rw-r--r--scripts/slurm_s1.sh23
-rw-r--r--scripts/slurm_s2.sh23
-rw-r--r--scripts/slurm_train.sh4
-rw-r--r--src/training/checkpointing.py31
-rw-r--r--src/training/trainer.py34
9 files changed, 437 insertions, 6 deletions
diff --git a/configs/s1_identity_init.yaml b/configs/s1_identity_init.yaml
new file mode 100644
index 0000000..b68e147
--- /dev/null
+++ b/configs/s1_identity_init.yaml
@@ -0,0 +1,51 @@
+# S1 — Predictor identity init (constant tau=5, ~10M tokens)
+# Purpose: Verify init reproduces dense topology. NLL should match S0 (2.4569) within 1%.
+# Run: python scripts/train.py --config configs/s1_identity_init.yaml
+
+# Model
+olmo_model_id: "allenai/OLMo-2-0425-1B"
+qwen_model_id: "Qwen/Qwen3-Embedding-0.6B"
+
+# Predictor
+predictor_hidden_dim: 1024
+predictor_rank: 32
+cascading_gate_k: 5.0
+input_norm: "none"
+
+# Data
+dataset: "allenai/dolma"
+dataset_name: "v1_7"
+seq_len: 1024
+batch_size: 4
+micro_batch_size: 2
+qwen_input_prefix: ""
+
+# Eval
+eval_skip: 10000
+eval_size: 50
+
+# Training — ~10M tokens = 2500 steps @ batch=4, seq=1024
+total_steps: 2500
+lr: 3e-4
+weight_decay: 0.01
+optimizer: "adamw"
+
+# Schedules — constant tau=5 (no annealing), no sparsity
+tau_init: 5.0
+tau_final: 5.0
+tau_schedule: "cosine"
+lambda_max: 0.0
+lambda_warmup_frac: 0.2
+
+# Logging
+wandb_project: "dagformer"
+wandb_run_name: "s1-identity-init"
+log_every: 10
+eval_every: 100
+
+# Checkpointing
+save_every: 1000
+save_dir: "checkpoints/s1/"
+
+# Hardware
+num_gpus: 1
diff --git a/configs/s2_gradient_flow.yaml b/configs/s2_gradient_flow.yaml
new file mode 100644
index 0000000..fcd2724
--- /dev/null
+++ b/configs/s2_gradient_flow.yaml
@@ -0,0 +1,52 @@
+# S2 — Gradient flow check (constant tau=2, ~50M tokens)
+# Purpose: Lower tau gives sharper gates. Does predictor learn useful topology?
+# Any NLL drop below baseline (2.4569) = gradient flows correctly.
+# Run: python scripts/train.py --config configs/s2_gradient_flow.yaml
+
+# Model
+olmo_model_id: "allenai/OLMo-2-0425-1B"
+qwen_model_id: "Qwen/Qwen3-Embedding-0.6B"
+
+# Predictor
+predictor_hidden_dim: 1024
+predictor_rank: 32
+cascading_gate_k: 5.0
+input_norm: "none"
+
+# Data
+dataset: "allenai/dolma"
+dataset_name: "v1_7"
+seq_len: 1024
+batch_size: 4
+micro_batch_size: 2
+qwen_input_prefix: ""
+
+# Eval
+eval_skip: 10000
+eval_size: 50
+
+# Training — ~50M tokens = 12500 steps @ batch=4, seq=1024
+total_steps: 12500
+lr: 3e-4
+weight_decay: 0.01
+optimizer: "adamw"
+
+# Schedules — constant tau=2 (sharper gates than S1), no sparsity
+tau_init: 2.0
+tau_final: 2.0
+tau_schedule: "cosine"
+lambda_max: 0.0
+lambda_warmup_frac: 0.2
+
+# Logging
+wandb_project: "dagformer"
+wandb_run_name: "s2-gradient-flow"
+log_every: 10
+eval_every: 500
+
+# Checkpointing
+save_every: 2500
+save_dir: "checkpoints/s2/"
+
+# Hardware
+num_gpus: 1
diff --git a/experiment_plan.xlsx b/experiment_plan.xlsx
new file mode 100644
index 0000000..33a60b8
--- /dev/null
+++ b/experiment_plan.xlsx
Binary files differ
diff --git a/experiments/results.md b/experiments/results.md
new file mode 100644
index 0000000..d362f7e
--- /dev/null
+++ b/experiments/results.md
@@ -0,0 +1,225 @@
+# DAGFormer Experiment Results
+
+## Sanity Checks
+
+### S0 — Dense Baseline (no predictor)
+
+| Item | Value |
+|------|-------|
+| Status | **DONE** (from sanity training eval) |
+| Date | 2026-02-09 |
+| Job ID | 15785016 |
+| Hardware | A40×1 |
+| Eval set | skip=10000, size=50, seq_len=1024 |
+| **NLL_base** | **2.4569** |
+| Notes | All experiments must beat this. Consider re-running with eval_size=1000 for more robust estimate. |
+
+---
+
+### S1 — Predictor identity init (constant tau=5, ~10M tokens)
+
+| Item | Value |
+|------|-------|
+| Status | **DONE** |
+| Date | 2026-02-09 |
+| Job ID | 15788145 |
+| Config | r=32, tau=5→5 (constant), k=5, lambda=0 |
+| Tokens | ~10M (2500 steps @ batch=4, seq=1024) |
+| Hardware | A40×1 (gpub073) |
+| Wall time | ~2 hrs |
+| Target | NLL ≈ NLL_base (within 1%) |
+| Purpose | Verify init reproduces dense topology |
+| **Result** | **PASS** — NLL within 0.3% of baseline |
+
+| Metric | Value (final) |
+|--------|---------------|
+| eval/nll_soft | **2.4500** (baseline: 2.4569, diff: -0.3%) |
+| eval/nll_hard | **2.4506** (diff: -0.3%) |
+| eval/nll_baseline | 2.4569 |
+| topology/mean_A | 0.975 |
+| topology/seq_gate_frac | 0.986 |
+| topology/hyp_gate_frac | 0.988 |
+
+**Per-eval-step data:**
+
+| Step | nll_soft | nll_hard | nll_base | mean_A |
+|------|----------|----------|----------|--------|
+| 100 | 2.4531 | 2.4512 | 2.4569 | 0.970 |
+| 500 | 2.4588 | 2.4609 | 2.4569 | 0.974 |
+| 1000 | 2.4506 | 2.4506 | 2.4569 | 0.978 |
+| 1500 | 2.4562 | 2.4578 | 2.4569 | 0.972 |
+| 2000 | 2.4500 | 2.4506 | 2.4569 | 0.978 |
+| 2500 | 2.4500 | 2.4506 | 2.4569 | 0.975 |
+
+**Observations:**
+- Init NLL matches baseline from step 0 — identity init working correctly
+- Step 700 had transient dip (mean_A=0.916, nll_soft=2.496) but recovered — Gumbel noise exploration at high tau
+- nll_hard ≈ nll_soft throughout — at tau=5, soft gates ≈ 0.95, so hard threshold (>0) gives similar A
+
+---
+
+### S2 — Gradient flow check (constant tau=2, ~50M tokens)
+
+| Item | Value |
+|------|-------|
+| Status | **RUNNING** (attempt 2) |
+| Config | r=32, tau=2→2 (constant), k=5, lambda=0 |
+| Tokens | ~50M (12,500 steps @ batch=4, seq=1024) |
+| Hardware | A40×1 |
+| Est. Time | ~15 hrs (within 48h limit) |
+| Target | NLL < NLL_base (2.4569) |
+| Purpose | Lower tau gives sharper gates — does predictor learn useful topology? |
+
+**Attempt 1** — Job 15789537, crashed at step ~1860 (Dolma HTTP range request error)
+
+| Step | nll_soft | nll_hard | nll_baseline | mean_A |
+|------|----------|----------|--------------|--------|
+| 500 | 2.4581 | 2.4581 | 2.4569 | 0.993 |
+| 1000 | 2.4575 | 2.4569 | 2.4569 | 0.999 |
+| 1500 | 2.4547 | 2.4559 | 2.4569 | 0.993 |
+
+Observations (attempt 1):
+- Eval NLL ≈ baseline throughout — predictor still near init (mean_A ≈ 0.99)
+- Train NLL high variance (0.27–2.96) is normal batch-to-batch variation at batch_size=4
+- No checkpoint saved (save_every=2500, crashed at 1860)
+- Crashed due to Dolma streaming HTTP error, not code bug
+
+**Attempt 2** — Job 15798568 (fresh start, no checkpoint from attempt 1)
+
+| Metric | Value |
+|--------|-------|
+| eval/nll_soft | |
+| eval/nll_hard | |
+| topology/mean_A | |
+
+---
+
+## Phase 1 Core
+
+### P1 — Phase 1 default config (5B tokens)
+
+| Item | Value |
+|------|-------|
+| Status | NOT STARTED |
+| Config | r=32, tau=5→0.2 cosine, k=5, lambda=0→0.01 ramp |
+| Tokens | 5B |
+| Hardware | A40×4 |
+| Est. Time | ~4 days |
+
+| Metric | Value |
+|--------|-------|
+| eval/nll_soft | |
+| eval/nll_hard | |
+| topology/mean_A | |
+| topology/seq_gate_frac | |
+| topology/hyp_gate_frac | |
+
+---
+
+### P2 — Phase 1 extended (10B tokens)
+
+| Item | Value |
+|------|-------|
+| Status | NOT STARTED |
+| Config | Continue P1 if still improving at 5B |
+| Tokens | 10B |
+| Hardware | A40×4 |
+| Est. Time | ~7 days |
+
+---
+
+## Ablations
+
+### A1–A4: Rank r
+
+| ID | Rank | NLL_soft | NLL_hard | Sparsity | Notes |
+|----|------|----------|----------|----------|-------|
+| A1 | 8 | | | | |
+| A2 | 16 | | | | |
+| P1 | 32 | | | | (reference) |
+| A3 | 64 | | | | |
+| A4 | 256 | | | | full rank |
+
+### A5–A7: Temperature schedule
+
+| ID | tau_init | tau_final | NLL_soft | NLL_hard | A entropy | Notes |
+|----|----------|-----------|----------|----------|-----------|-------|
+| A5 | 1 | 1 | | | | constant, perpetually soft |
+| P1 | 5 | 0.2 | | | | (reference) |
+| A6 | 5 | 0.05 | | | | aggressive anneal |
+| A7 | 10 | 1.0 | | | | slow anneal |
+
+### A8–A9: Sparsity lambda
+
+| ID | lambda | NLL_soft | NLL_hard | Density | Notes |
+|----|--------|----------|----------|---------|-------|
+| A8 | 0 | | | | no sparsity |
+| P1 | 0→0.01 | | | | (reference) |
+| A9 | 0→0.05 | | | | high sparsity |
+
+### A10–A11: Cascading gate
+
+| ID | Gate | NLL_soft | NLL_hard | Dead heads | Notes |
+|----|------|----------|----------|------------|-------|
+| A10 | OFF | | | | |
+| P1 | k=5 fixed | | | | (reference) |
+| A11 | k=5 learnable | | | | |
+
+---
+
+## Analysis Experiments
+
+### X1 — Topology variance analysis
+| Item | Value |
+|------|-------|
+| Status | NOT STARTED |
+| Result | |
+
+### X2 — Domain-specific topology
+| Item | Value |
+|------|-------|
+| Status | NOT STARTED |
+| Result | |
+
+### X3 — Topology-NLL sensitivity
+| Item | Value |
+|------|-------|
+| Status | NOT STARTED |
+| Result | |
+
+---
+
+## Speed Estimates (A40×1, batch=4, micro_batch=2, seq=1024)
+
+| Component | Time | Notes |
+|-----------|------|-------|
+| Training step | ~3s | Forward + backward + optimizer |
+| Eval round (50 samples) | ~2 min | 25 batches × 3 modes (soft/hard/baseline) |
+| Model loading | ~10 min | OLMo + Qwen + eval set build |
+| 1K steps (no eval) | ~50 min | |
+| 1K steps (eval every 100) | ~70 min | 10 eval rounds add ~20 min |
+| 10K steps | ~12 hrs | |
+| 100K steps | ~5 days | Exceeds 48h SLURM limit, needs auto-resume |
+
+**Previous 14s/step estimate was wrong** — it included model loading and eval overhead in wall-clock average.
+
+---
+
+## Preliminary Data (from sanity training job 15785016)
+
+Run with cascading gate bug (layer 0 not exempted). 500/1000 steps completed before timeout.
+
+| Step | train/nll | eval/nll_soft | eval/nll_hard | eval/nll_baseline | mean_A | tau |
+|------|-----------|---------------|---------------|-------------------|--------|-----|
+| 0 | 3.539 | — | — | — | 0.417 | 5.00 |
+| 100 | 2.750 | 2.635 | 4.744 | 2.457 | 0.416 | 4.88 |
+| 200 | 3.102 | 2.630 | 4.570 | 2.457 | 0.416 | 4.54 |
+| 300 | 2.844 | 2.621 | 4.680 | 2.457 | 0.418 | 4.01 |
+| 400 | 2.492 | 2.641 | 4.893 | 2.457 | 0.419 | 3.34 |
+| 500 | 1.805 | 2.639 | 4.503 | 2.457 | 0.419 | 2.60 |
+
+**Key observations:**
+- train/nll decreasing (3.54 → 1.80) but eval/nll_soft flat (~2.63) — overfitting or predictor not generalizing
+- eval/nll_hard very high (4.5-4.9) due to cascading gate layer 0 bug (now fixed in `80579d6`)
+- mean_A stable ~0.42 (= ~0.89 over valid entries), no collapse
+- Baseline NLL = 2.4569 confirmed correct after double-shift fix
diff --git a/scripts/slurm_s1.sh b/scripts/slurm_s1.sh
new file mode 100644
index 0000000..8a13e9b
--- /dev/null
+++ b/scripts/slurm_s1.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+#SBATCH --signal=SIGUSR1@120
+export HF_HOME=/projects/bfqt/users/yurenh2/hf_cache
+export TRANSFORMERS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/transformers
+export HF_HUB_CACHE=/projects/bfqt/users/yurenh2/hf_cache/hub
+export HF_DATASETS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/datasets
+export TOKENIZERS_PARALLELISM=false
+
+export PYTHONPATH=/projects/bfqt/users/yurenh2/ml-projects/DAGFormer:$PYTHONPATH
+export PATH=$HOME/.local/bin:$PATH
+
+cd /projects/bfqt/users/yurenh2/ml-projects/DAGFormer
+mkdir -p logs checkpoints/s1
+
+echo "=== Job Info ==="
+echo "Job ID: $SLURM_JOB_ID"
+echo "Node: $SLURM_NODELIST"
+echo "GPU: $(nvidia-smi --query-gpu=name,memory.total --format=csv,noheader)"
+echo ""
+
+echo "=== Starting S1: identity init training ==="
+echo " Auto-resume enabled: will pick up from latest checkpoint in checkpoints/s1/"
+python3 -u scripts/train.py --config configs/s1_identity_init.yaml
diff --git a/scripts/slurm_s2.sh b/scripts/slurm_s2.sh
new file mode 100644
index 0000000..fe00552
--- /dev/null
+++ b/scripts/slurm_s2.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+#SBATCH --signal=SIGUSR1@120
+export HF_HOME=/projects/bfqt/users/yurenh2/hf_cache
+export TRANSFORMERS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/transformers
+export HF_HUB_CACHE=/projects/bfqt/users/yurenh2/hf_cache/hub
+export HF_DATASETS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/datasets
+export TOKENIZERS_PARALLELISM=false
+
+export PYTHONPATH=/projects/bfqt/users/yurenh2/ml-projects/DAGFormer:$PYTHONPATH
+export PATH=$HOME/.local/bin:$PATH
+
+cd /projects/bfqt/users/yurenh2/ml-projects/DAGFormer
+mkdir -p logs checkpoints/s2
+
+echo "=== Job Info ==="
+echo "Job ID: $SLURM_JOB_ID"
+echo "Node: $SLURM_NODELIST"
+echo "GPU: $(nvidia-smi --query-gpu=name,memory.total --format=csv,noheader)"
+echo ""
+
+echo "=== Starting S2: gradient flow check ==="
+echo " Auto-resume enabled: will pick up from latest checkpoint in checkpoints/s2/"
+python3 -u scripts/train.py --config configs/s2_gradient_flow.yaml
diff --git a/scripts/slurm_train.sh b/scripts/slurm_train.sh
index e1df687..361ba94 100644
--- a/scripts/slurm_train.sh
+++ b/scripts/slurm_train.sh
@@ -1,4 +1,5 @@
#!/bin/bash
+#SBATCH --signal=SIGUSR1@120
export HF_HOME=/projects/bfqt/users/yurenh2/hf_cache
export TRANSFORMERS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/transformers
export HF_HUB_CACHE=/projects/bfqt/users/yurenh2/hf_cache/hub
@@ -18,4 +19,5 @@ echo "GPU: $(nvidia-smi --query-gpu=name,memory.total --format=csv,noheader)"
echo ""
echo "=== Starting training ==="
-python3 -u scripts/train.py --config configs/sanity_check.yaml
+echo " Auto-resume enabled: will pick up from latest checkpoint"
+python3 -u scripts/train.py --config ${CONFIG:-configs/sanity_check.yaml}
diff --git a/src/training/checkpointing.py b/src/training/checkpointing.py
index 9ff02df..b53ce4f 100644
--- a/src/training/checkpointing.py
+++ b/src/training/checkpointing.py
@@ -6,7 +6,9 @@ Frozen models (OLMo, Qwen) are not checkpointed — they load from HuggingFace.
from __future__ import annotations
+import glob
import os
+import re
from typing import Any, Optional
import torch
@@ -55,6 +57,35 @@ def save_checkpoint(
return path
+def find_latest_checkpoint(save_dir: str) -> Optional[str]:
+ """Find the latest checkpoint in save_dir by step number.
+
+ Returns:
+ Path to latest checkpoint, or None if no checkpoints found.
+ """
+ if not os.path.isdir(save_dir):
+ return None
+
+ pattern = os.path.join(save_dir, "checkpoint_step*.pt")
+ files = glob.glob(pattern)
+ if not files:
+ return None
+
+ # Extract step numbers and find max
+ step_re = re.compile(r"checkpoint_step(\d+)\.pt$")
+ best_step = -1
+ best_path = None
+ for f in files:
+ m = step_re.search(f)
+ if m:
+ step = int(m.group(1))
+ if step > best_step:
+ best_step = step
+ best_path = f
+
+ return best_path
+
+
def load_checkpoint(
path: str,
predictor: nn.Module,
diff --git a/src/training/trainer.py b/src/training/trainer.py
index 7ebd21e..d157d0c 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -8,6 +8,7 @@ from __future__ import annotations
import math
import os
+import signal
import warnings
from dataclasses import dataclass, field
from typing import Any, Optional
@@ -22,7 +23,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from src.data.dolma import build_eval_dataloader, build_train_dataloader
from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A
from src.model.predictor import StructurePredictor
-from src.training.checkpointing import load_checkpoint, save_checkpoint
+from src.training.checkpointing import find_latest_checkpoint, load_checkpoint, save_checkpoint
from src.training.schedulers import lambda_schedule, tau_schedule
from src.utils.logging import finish_wandb, init_wandb, log_metrics
from src.utils.topology import compute_topology_metrics
@@ -136,16 +137,22 @@ class Trainer:
self.best_eval_nll = float("inf")
self.collapse_counter = 0 # consecutive steps with collapsed A
- # Resume from checkpoint if specified
- if config.resume_from:
+ # Resume from checkpoint: explicit path or auto-find latest
+ resume_path = config.resume_from
+ if not resume_path:
+ resume_path = find_latest_checkpoint(config.save_dir)
+ if resume_path and self.is_main:
+ print(f"Auto-resume: found {resume_path}")
+
+ if resume_path:
state = load_checkpoint(
- config.resume_from,
+ resume_path,
self.predictor,
self.optimizer,
self.lr_scheduler,
device=self.device,
)
- self.global_step = state["step"]
+ self.global_step = state["step"] + 1 # resume from NEXT step
self.best_eval_nll = state["best_eval_nll"]
if self.is_main:
print(f"Resumed from step {self.global_step}")
@@ -261,11 +268,28 @@ class Trainer:
else:
self.wandb_run = None
+ def _save_on_signal(self, signum: int, frame: Any) -> None:
+ """Save checkpoint when receiving SIGUSR1 (SLURM pre-timeout signal)."""
+ if self.is_main:
+ print(f"\nReceived signal {signum}, saving checkpoint before exit...")
+ save_checkpoint(
+ self.config.save_dir,
+ self.global_step,
+ self.predictor,
+ self.optimizer,
+ self.lr_scheduler,
+ self.best_eval_nll,
+ )
+ raise SystemExit(0)
+
def train(self) -> None:
"""Main training loop."""
config = self.config
train_iter = iter(self.train_loader)
+ # Register signal handler for graceful SLURM preemption
+ signal.signal(signal.SIGUSR1, self._save_on_signal)
+
if self.is_main:
print(f"\nStarting training: {config.total_steps} steps")
print(f" batch_size={config.batch_size}, micro_batch={config.micro_batch_size}, accum={self.accum_steps}")