Diffstat (limited to 'src/training/schedulers.py')
-rw-r--r--  src/training/schedulers.py  35
1 file changed, 35 insertions, 0 deletions
diff --git a/src/training/schedulers.py b/src/training/schedulers.py
new file mode 100644
index 0000000..7cda3b4
--- /dev/null
+++ b/src/training/schedulers.py
@@ -0,0 +1,35 @@
+"""Schedule functions for temperature (τ), sparsity (λ), and learning rate.
+
+All schedules are deterministic functions of the current step.
+See CLAUDE.md §3.1 for exact formulas.
+"""
+
+from __future__ import annotations
+
+import math
+
+
+def tau_schedule(step: int, total_steps: int, tau_init: float, tau_final: float) -> float:
+ """Cosine annealing for Gumbel-Sigmoid temperature.
+
+ τ(t) = τ_f + 0.5(τ_i - τ_f)(1 + cos(πt/T))
+
+ Starts at tau_init, ends at tau_final.
+ """
+ if total_steps <= 0:
+ return tau_final
+    progress = min(max(step, 0) / total_steps, 1.0)
+ return tau_final + 0.5 * (tau_init - tau_final) * (1 + math.cos(math.pi * progress))
+
+
+def lambda_schedule(step: int, total_steps: int, lambda_max: float, warmup_frac: float = 0.2) -> float:
+ """Linear ramp for sparsity coefficient.
+
+ Ramps linearly from 0 to lambda_max over first warmup_frac of training.
+ """
+ if lambda_max == 0.0:
+ return 0.0
+ warmup_steps = int(total_steps * warmup_frac)
+ if warmup_steps <= 0:
+ return lambda_max
+ return lambda_max * min(step / warmup_steps, 1.0)
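
A quick sanity check of the cosine schedule's endpoints and midpoint. This is a
minimal sketch, not part of the commit: the import path assumes the repo root is
on PYTHONPATH, and the values (T = 1000, tau_init = 1.0, tau_final = 0.1) are
illustrative.

    from src.training.schedulers import tau_schedule

    T = 1000
    tau_schedule(0, T, tau_init=1.0, tau_final=0.1)       # 1.0   -> starts at tau_init
    tau_schedule(T // 2, T, tau_init=1.0, tau_final=0.1)  # ~0.55 -> midpoint, (tau_init + tau_final) / 2
    tau_schedule(T, T, tau_init=1.0, tau_final=0.1)       # 0.1   -> ends at tau_final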
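
Similarly for the sparsity ramp; the step values and lambda_max = 0.01 below
are illustrative. With the default warmup_frac = 0.2 and T = 1000, warmup ends
at step 200 and λ is held at lambda_max afterwards.

    from src.training.schedulers import lambda_schedule

    T = 1000
    lambda_schedule(0, T, lambda_max=0.01)    # 0.0   -> start of the linear ramp
    lambda_schedule(100, T, lambda_max=0.01)  # 0.005 -> halfway through warmup (step 100 of 200)
    lambda_schedule(500, T, lambda_max=0.01)  # 0.01  -> past warmup, clamped at lambda_max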