From 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 9 Feb 2026 11:00:39 -0600 Subject: Initial implementation: DAGFormer Phase 1 - olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A - Proportional attribution for post-norm decomposition - All 6 GPU sanity checks pass (baseline diff = 0.000001) - predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate - pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL) - trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing - dolma.py: Streaming Dolma v1.7 with sequence packing - 43/43 unit tests pass Co-Authored-By: Claude Opus 4.6 --- src/training/schedulers.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 src/training/schedulers.py (limited to 'src/training/schedulers.py') diff --git a/src/training/schedulers.py b/src/training/schedulers.py new file mode 100644 index 0000000..7cda3b4 --- /dev/null +++ b/src/training/schedulers.py @@ -0,0 +1,35 @@ +"""Schedule functions for temperature (τ), sparsity (λ), and learning rate. + +All schedules are deterministic functions of the current step. +See CLAUDE.md §3.1 for exact formulas. +""" + +from __future__ import annotations + +import math + + +def tau_schedule(step: int, total_steps: int, tau_init: float, tau_final: float) -> float: + """Cosine annealing for Gumbel-Sigmoid temperature. + + τ(t) = τ_f + 0.5(τ_i - τ_f)(1 + cos(πt/T)) + + Starts at tau_init, ends at tau_final. + """ + if total_steps <= 0: + return tau_final + progress = min(step / total_steps, 1.0) + return tau_final + 0.5 * (tau_init - tau_final) * (1 + math.cos(math.pi * progress)) + + +def lambda_schedule(step: int, total_steps: int, lambda_max: float, warmup_frac: float = 0.2) -> float: + """Linear ramp for sparsity coefficient. + + Ramps linearly from 0 to lambda_max over first warmup_frac of training. + """ + if lambda_max == 0.0: + return 0.0 + warmup_steps = int(total_steps * warmup_frac) + if warmup_steps <= 0: + return lambda_max + return lambda_max * min(step / warmup_steps, 1.0) -- cgit v1.2.3