author    YurenHao0426 <blackhao0426@gmail.com>    2026-02-09 11:00:39 -0600
committer YurenHao0426 <blackhao0426@gmail.com>    2026-02-09 11:00:39 -0600
commit    13ddc8dc583d8b1355909970cb8c27f85b7d3c8b (patch)
tree      073534138604c1c49021ca7e334322262129f6ac /src/training/schedulers.py
Initial implementation: DAGFormer Phase 1
- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A
  - Proportional attribution for post-norm decomposition
  - All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing
- 43/43 unit tests pass

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
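The Gumbel-Sigmoid gate named above is what tau_schedule in the diff below anneals. predictor.py is not part of this diff, so the following is only a minimal sketch of the standard Gumbel-Sigmoid (binary-concrete) relaxation; the function name and exact gate wiring are assumptions, not the committed implementation.

    # Hypothetical sketch: Gumbel-Sigmoid relaxation whose temperature tau
    # is annealed by tau_schedule below. Not taken from predictor.py.
    import torch

    def gumbel_sigmoid(logits: torch.Tensor, tau: float) -> torch.Tensor:
        """Sample a relaxed binary gate in (0, 1).

        High tau keeps samples soft (good gradients early in training);
        low tau pushes them toward hard 0/1 decisions.
        """
        u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
        noise = torch.log(u) - torch.log1p(-u)  # Logistic(0, 1) noise
        return torch.sigmoid((logits + noise) / tau)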
Diffstat (limited to 'src/training/schedulers.py')
-rw-r--r--  src/training/schedulers.py  35
1 file changed, 35 insertions, 0 deletions
diff --git a/src/training/schedulers.py b/src/training/schedulers.py
new file mode 100644
index 0000000..7cda3b4
--- /dev/null
+++ b/src/training/schedulers.py
@@ -0,0 +1,35 @@
+"""Schedule functions for temperature (τ), sparsity (λ), and learning rate.
+
+All schedules are deterministic functions of the current step.
+See CLAUDE.md §3.1 for exact formulas.
+"""
+
+from __future__ import annotations
+
+import math
+
+
+def tau_schedule(step: int, total_steps: int, tau_init: float, tau_final: float) -> float:
+ """Cosine annealing for Gumbel-Sigmoid temperature.
+
+ τ(t) = τ_f + 0.5(τ_i - τ_f)(1 + cos(πt/T))
+
+ Starts at tau_init, ends at tau_final.
+ """
+ if total_steps <= 0:
+ return tau_final
+ progress = min(step / total_steps, 1.0)
+ return tau_final + 0.5 * (tau_init - tau_final) * (1 + math.cos(math.pi * progress))
+
+
+def lambda_schedule(step: int, total_steps: int, lambda_max: float, warmup_frac: float = 0.2) -> float:
+ """Linear ramp for sparsity coefficient.
+
+ Ramps linearly from 0 to lambda_max over first warmup_frac of training.
+ """
+ if lambda_max == 0.0:
+ return 0.0
+ warmup_steps = int(total_steps * warmup_frac)
+ if warmup_steps <= 0:
+ return lambda_max
+ return lambda_max * min(step / warmup_steps, 1.0)
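A minimal sketch of driving both schedules from a training loop. The actual wiring lives in trainer.py, which is not in this diff, so total_steps, tau_init/tau_final, and lambda_max below are illustrative assumptions.

    # Sanity-check the schedule shapes at a few steps; values are assumed.
    from src.training.schedulers import tau_schedule, lambda_schedule

    TOTAL_STEPS = 10_000
    for step in (0, 2_000, 5_000, 10_000):
        tau = tau_schedule(step, TOTAL_STEPS, tau_init=1.0, tau_final=0.1)
        lam = lambda_schedule(step, TOTAL_STEPS, lambda_max=1e-3)  # default 20% warmup
        print(f"step {step:>6}: tau={tau:.3f}, lambda={lam:.2e}")

    # tau falls on a cosine from 1.0 at step 0, through 0.55 at the midpoint,
    # to 0.1 at step 10,000; lambda ramps linearly to 1e-3 by step 2,000
    # (warmup_frac=0.2) and stays flat afterwards.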