diff options
Diffstat (limited to 'src/training/schedulers.py')
| -rw-r--r-- | src/training/schedulers.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/src/training/schedulers.py b/src/training/schedulers.py new file mode 100644 index 0000000..7cda3b4 --- /dev/null +++ b/src/training/schedulers.py @@ -0,0 +1,35 @@ +"""Schedule functions for temperature (τ), sparsity (λ), and learning rate. + +All schedules are deterministic functions of the current step. +See CLAUDE.md §3.1 for exact formulas. +""" + +from __future__ import annotations + +import math + + +def tau_schedule(step: int, total_steps: int, tau_init: float, tau_final: float) -> float: + """Cosine annealing for Gumbel-Sigmoid temperature. + + τ(t) = τ_f + 0.5(τ_i - τ_f)(1 + cos(πt/T)) + + Starts at tau_init, ends at tau_final. + """ + if total_steps <= 0: + return tau_final + progress = min(step / total_steps, 1.0) + return tau_final + 0.5 * (tau_init - tau_final) * (1 + math.cos(math.pi * progress)) + + +def lambda_schedule(step: int, total_steps: int, lambda_max: float, warmup_frac: float = 0.2) -> float: + """Linear ramp for sparsity coefficient. + + Ramps linearly from 0 to lambda_max over first warmup_frac of training. + """ + if lambda_max == 0.0: + return 0.0 + warmup_steps = int(total_steps * warmup_frac) + if warmup_steps <= 0: + return lambda_max + return lambda_max * min(step / warmup_steps, 1.0) |
