diff options
| -rw-r--r-- | models/srm/srm_aol_v1.py | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/models/srm/srm_aol_v1.py b/models/srm/srm_aol_v1.py index 01584df..c4e2719 100644 --- a/models/srm/srm_aol_v1.py +++ b/models/srm/srm_aol_v1.py @@ -43,18 +43,25 @@ from models.sparse_embedding import CastedSparseEmbedding # ============================================================================= -# 1-Lipschitz primitives -# Lipschitz computations done in float32 (then cast back) so the bound stays -# exact under bf16 forward dtype. +# Approximately 1-Lipschitz primitives +# Normalization (AOL) and orthogonalization (Cayley) are computed in float32, +# then cast to forward dtype (bf16 by default). The bound is *exact in fp32* +# but only approximate after cast — bf16 rounding introduces a small error +# that accumulates over n_aol_layers matmuls. Empirically the margin to the +# theoretical κ-bound is large (~5×), so this is fine in practice, but the +# guarantee is not strict. For applications where strictness matters, run +# the bounded operators in float32. # ============================================================================= class AOLLinear(nn.Module): - """1-Lipschitz linear layer via AOL (Prach & Lampert 2022) rescaling. + """≤1-Lipschitz linear layer via AOL (Prach & Lampert 2022) rescaling. Given W ∈ R^(out × in), let A = W^T W (symmetric PSD). Define D_jj = 1 / √(Σ_i |A_ij| + eps); set W̃ = W · diag(D). - Then ||W̃ x||_2 ≤ ||x||_2 for all x (per Prach & Lampert Theorem 1). - Bias is unconstrained — shift only, doesn't affect Lipschitz w.r.t. input. + Then ||W̃ x||_2 ≤ ||x||_2 in float32 (Prach & Lampert Theorem 1). + Bound is approximate (not exact) under bf16 due to rounding in W·diag(D) + and the subsequent matmul. Bias is unconstrained (shift only, doesn't + affect Lipschitz w.r.t. input). """ def __init__(self, in_dim: int, out_dim: int, bias: bool = True, cast_to: torch.dtype = torch.bfloat16, eps: float = 1e-6): @@ -102,10 +109,14 @@ class AOLBlock(nn.Module): class CayleyOrthogonal(nn.Module): - """Orthogonal matrix Q ∈ R^(d × d) via Cayley transform. - - Q = (I - S)(I + S)^(-1) where S = (A - A^T)/2 is skew-symmetric ⇒ Q^T Q = I. - Solve done in float32 for numerical stability. + """Approximately orthogonal Q ∈ R^(d × d) via Cayley transform. + + Q = (I + S)^{-1}(I - S) where S = (A - A^T)/2 is skew-symmetric. + Since (I+S) and (I-S) commute (both polynomials in S), the form is also + Q = (I - S)(I + S)^{-1}. Q^T Q = I exactly in float32 — approximate + after cast to bf16. Solve done in float32 for numerical stability. + NOTE: torch.linalg.solve may not be fullgraph-compile friendly. Test + before enabling torch.compile / FSDP. """ def __init__(self, dim: int, cast_to: torch.dtype = torch.bfloat16): super().__init__() @@ -129,11 +140,12 @@ class BlockGain(nn.Module): L row entries (P-normalized): [(1/√η) · a_LH, a_LL], sum = κ Parameterized via softmax × κ ⇒ exact equality (saturation). """ - def __init__(self, kappa: float = 0.9, eta: float = 1.0, init_diag: float = 1.0): + def __init__(self, kappa: float = 0.9, eta: float = 1.0, init_diag: float = 3.0): super().__init__() self.kappa = kappa self.eta = eta - # Initialize softmax favoring diagonal (a_HH, a_LL): minimal cross coupling at start + # init_diag=3.0 → softmax([3, 0]) ≈ [0.953, 0.047] ⇒ ~5% cross-coupling at start + # (init_diag=1.0 was too weak — gave 27% cross-coupling; init_diag=3.0 truly minimal) self.logits_H = nn.Parameter(torch.tensor([init_diag, 0.0])) self.logits_L = nn.Parameter(torch.tensor([0.0, init_diag])) |
