summaryrefslogtreecommitdiff
path: root/src/training/schedulers.py
blob: 7cda3b4c95dde4e7f4c45171bcdb8abd5126224c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""Schedule functions for temperature (τ), sparsity (λ), and learning rate.

All schedules are deterministic functions of the current step.
See CLAUDE.md §3.1 for exact formulas.
"""

from __future__ import annotations

import math


def tau_schedule(step: int, total_steps: int, tau_init: float, tau_final: float) -> float:
    """Return the Gumbel-Sigmoid temperature for ``step`` under cosine annealing.

    τ(t) = τ_f + 0.5(τ_i - τ_f)(1 + cos(πt/T))

    The value equals ``tau_init`` at step 0 and ``tau_final`` at
    ``total_steps``; later steps hold at ``tau_final``. A non-positive
    ``total_steps`` short-circuits to ``tau_final``. Note that ``step`` is
    only clamped from above.
    """
    if total_steps > 0:
        frac = step / total_steps
        # Hold the schedule at its end value once training runs past T.
        if frac > 1.0:
            frac = 1.0
        span = tau_init - tau_final
        cosine_term = 1 + math.cos(math.pi * frac)
        return tau_final + 0.5 * span * cosine_term
    return tau_final


def lambda_schedule(step: int, total_steps: int, lambda_max: float, warmup_frac: float = 0.2) -> float:
    """Linear ramp for the sparsity coefficient λ.

    Ramps linearly from 0 at step 0 to ``lambda_max`` over the first
    ``int(total_steps * warmup_frac)`` steps, then holds at ``lambda_max``.

    Args:
        step: Current training step.
        total_steps: Total number of training steps.
        lambda_max: Coefficient value reached at the end of warmup.
        warmup_frac: Fraction of ``total_steps`` spent ramping up.

    Returns:
        The sparsity coefficient for ``step``. The ramp fraction is clamped
        to [0, 1], so the result always lies between 0 and ``lambda_max``.
    """
    if lambda_max == 0.0:
        return 0.0
    warmup_steps = int(total_steps * warmup_frac)
    if warmup_steps <= 0:
        # No warmup window (e.g. warmup_frac == 0 or a very short run):
        # apply the full coefficient from the start.
        return lambda_max
    # Clamp below as well as above: a negative step would otherwise give a
    # negative coefficient, flipping the sign of the sparsity penalty.
    ramp = min(max(step / warmup_steps, 0.0), 1.0)
    return lambda_max * ramp