summaryrefslogtreecommitdiff
path: root/models/srm/srm_aol_v1.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/srm/srm_aol_v1.py')
-rw-r--r--models/srm/srm_aol_v1.py36
1 files changed, 24 insertions, 12 deletions
diff --git a/models/srm/srm_aol_v1.py b/models/srm/srm_aol_v1.py
index 01584df..c4e2719 100644
--- a/models/srm/srm_aol_v1.py
+++ b/models/srm/srm_aol_v1.py
@@ -43,18 +43,25 @@ from models.sparse_embedding import CastedSparseEmbedding
# =============================================================================
-# 1-Lipschitz primitives
-# Lipschitz computations done in float32 (then cast back) so the bound stays
-# exact under bf16 forward dtype.
+# Approximately 1-Lipschitz primitives
+# Normalization (AOL) and orthogonalization (Cayley) are computed in float32,
+# then cast to forward dtype (bf16 by default). The bound is *exact in fp32*
+# but only approximate after cast — bf16 rounding introduces a small error
+# that accumulates over n_aol_layers matmuls. Empirically the margin to the
+# theoretical κ-bound is large (~5×), so this is fine in practice, but the
+# guarantee is not strict. For applications where strictness matters, run
+# the bounded operators in float32.
# =============================================================================
class AOLLinear(nn.Module):
- """1-Lipschitz linear layer via AOL (Prach & Lampert 2022) rescaling.
+ """≤1-Lipschitz linear layer via AOL (Prach & Lampert 2022) rescaling.
Given W ∈ R^(out × in), let A = W^T W (symmetric PSD).
Define D_jj = 1 / √(Σ_i |A_ij| + eps); set W̃ = W · diag(D).
- Then ||W̃ x||_2 ≤ ||x||_2 for all x (per Prach & Lampert Theorem 1).
- Bias is unconstrained — shift only, doesn't affect Lipschitz w.r.t. input.
+ Then ||W̃ x||_2 ≤ ||x||_2 in float32 (Prach & Lampert Theorem 1).
+ Bound is approximate (not exact) under bf16 due to rounding in W·diag(D)
+ and the subsequent matmul. Bias is unconstrained (shift only, doesn't
+ affect Lipschitz w.r.t. input).
"""
def __init__(self, in_dim: int, out_dim: int, bias: bool = True,
cast_to: torch.dtype = torch.bfloat16, eps: float = 1e-6):
@@ -102,10 +109,14 @@ class AOLBlock(nn.Module):
class CayleyOrthogonal(nn.Module):
- """Orthogonal matrix Q ∈ R^(d × d) via Cayley transform.
-
- Q = (I - S)(I + S)^(-1) where S = (A - A^T)/2 is skew-symmetric ⇒ Q^T Q = I.
- Solve done in float32 for numerical stability.
+ """Approximately orthogonal Q ∈ R^(d × d) via Cayley transform.
+
+ Q = (I + S)^{-1}(I - S) where S = (A - A^T)/2 is skew-symmetric.
+ Since (I+S) and (I-S) commute (both polynomials in S), the form is also
+ Q = (I - S)(I + S)^{-1}. Q^T Q = I exactly in float32 — approximate
+ after cast to bf16. Solve done in float32 for numerical stability.
+ NOTE: torch.linalg.solve may not be fullgraph-compile friendly. Test
+ before enabling torch.compile / FSDP.
"""
def __init__(self, dim: int, cast_to: torch.dtype = torch.bfloat16):
super().__init__()
@@ -129,11 +140,12 @@ class BlockGain(nn.Module):
L row entries (P-normalized): [(1/√η) · a_LH, a_LL], sum = κ
Parameterized via softmax × κ ⇒ exact equality (saturation).
"""
- def __init__(self, kappa: float = 0.9, eta: float = 1.0, init_diag: float = 1.0):
+ def __init__(self, kappa: float = 0.9, eta: float = 1.0, init_diag: float = 3.0):
super().__init__()
self.kappa = kappa
self.eta = eta
- # Initialize softmax favoring diagonal (a_HH, a_LL): minimal cross coupling at start
+ # init_diag=3.0 → softmax([3, 0]) ≈ [0.953, 0.047] ⇒ ~5% cross-coupling at start
+ # (init_diag=1.0 was too weak — gave 27% cross-coupling; init_diag=3.0 truly minimal)
self.logits_H = nn.Parameter(torch.tensor([init_diag, 0.0]))
self.logits_L = nn.Parameter(torch.tensor([0.0, init_diag]))