diff options
Diffstat (limited to 'ep_run/stiefel_feedback.py')
| -rw-r--r-- | ep_run/stiefel_feedback.py | 126 |
1 files changed, 126 insertions, 0 deletions
# Reconstructed from the diff of ep_run/stiefel_feedback.py
# (new file mode 100644, index 0000000..a5c86ca).
"""Factored feedback subspace: fixed compressor C + per-layer Stiefel expander U_l.

    δ_l = α_l · c_t @ U_l^T,  where c_t = e_L @ C^T

C ∈ R^{r×V}: fixed row-orthonormal compressor (CC^T = I_r)
U_l ∈ St(d_l, r): per-layer learnable orthonormal expander
α_l > 0: per-layer scalar gain

U_l updated via Riemannian gradient on Stiefel with EMA + QR retraction.
α_l updated via correlation-based least-squares, not norm-ratio.
"""
import torch
import torch.nn as nn


@torch.no_grad()
def init_row_orthonormal_C(vocab_size: int, rank: int, device=None, dtype=torch.float32):
    """Return a random C ∈ R^{r×V} with orthonormal rows (CC^T = I_r).

    A Gaussian (V, r) matrix is orthonormalised with a reduced QR and the
    Q factor is transposed.

    Args:
        vocab_size: V, the number of columns of C.
        rank: r, the number of rows of C; must not exceed ``vocab_size``.
        device: forwarded to ``torch.randn``.
        dtype: forwarded to ``torch.randn``.

    Returns:
        A contiguous tensor of shape (rank, vocab_size).

    Raises:
        ValueError: if ``rank > vocab_size``. r orthonormal rows cannot exist
            in fewer than r dimensions; reduced QR would silently hand back a
            (V, V) matrix instead of the requested (r, V).
    """
    if rank > vocab_size:
        raise ValueError(
            f"rank ({rank}) must not exceed vocab_size ({vocab_size})"
        )
    g = torch.randn(vocab_size, rank, device=device, dtype=dtype)
    q, _ = torch.linalg.qr(g, mode="reduced")  # (V, r), orthonormal columns
    return q.T.contiguous()  # (r, V), orthonormal rows


class StiefelFeedbackLayer(nn.Module):
    """Per-layer factored feedback: δ = α · c @ U^T where U ∈ St(d, r).

    State is held in buffers (not parameters), so it is updated only through
    :meth:`update`:

    * ``U``      (d, r) matrix with orthonormal columns.
    * ``alpha``  scalar gain, initialised to 0.1.
    * ``ema_G``  (d, r) exponential moving average of the cross-covariance
      between the local signal and the compressed error.
    """

    def __init__(self, d: int, r: int):
        """Create a layer of width ``d`` fed by a rank-``r`` compressed error.

        Raises:
            ValueError: if ``r > d``. A (d, r) matrix cannot have r orthonormal
                columns in that case; reduced QR would return a (d, d) factor
                whose shape disagrees with ``ema_G`` and with ``c``.
        """
        super().__init__()
        if r > d:
            raise ValueError(
                f"rank r ({r}) must not exceed layer width d ({d})"
            )
        # U on Stiefel: (d, r), orthonormal columns
        U_init = torch.linalg.qr(torch.randn(d, r), mode="reduced")[0]
        self.register_buffer("U", U_init)
        self.register_buffer("alpha", torch.tensor(0.1))
        self.register_buffer("ema_G", torch.zeros(d, r))

    def compute_delta(self, c: torch.Tensor) -> torch.Tensor:
        """c: (B, T, r) → δ: (B, T, d), computed as ``alpha * (c @ U^T)``."""
        return self.alpha * (c @ self.U.T)

    @torch.no_grad()
    def update(self, g_hat: torch.Tensor, c: torch.Tensor,
               eta_B: float = 3e-5, tau: float = 0.99,
               beta_alpha: float = 0.01, eps: float = 1e-8,
               alpha_min: float = 1e-4, alpha_max: float = 10.0,
               max_step_frob: float = 1.0, frozen: bool = False):
        """Update U and alpha from local signal g_hat and compressed error c.

        Args:
            g_hat: (B, T, d) — local proxy signal (e.g. reconstruction error).
            c: (B, T, r) — compressed global error.
            eta_B: step size for U; the effective step is ``eta_B * alpha``.
            tau: EMA decay applied to ``ema_G``.
            beta_alpha: EMA rate with which ``alpha`` moves towards its target.
            eps: small constant guarding the divisions.
            alpha_min: lower clamp for the alpha target.
            alpha_max: upper clamp for the alpha target.
            max_step_frob: Frobenius-norm cap on the tangent step; ``None``
                disables clipping.
            frozen: if True, only accumulate ema_G, don't update U or alpha.

        Returns:
            A diagnostics dict. When ``frozen`` it holds ``"G"``, ``"alpha"``
            (a tensor copy) and ``"frozen": True``. Otherwise it holds ``"G"``
            plus Python floats ``"Delta_frob"`` (tangent-step norm measured
            *before* clipping), ``"alpha_star"``, ``"alpha"`` and ``"rho"``
            (⟨G, U⟩ divided by ``||G||_F * sqrt(r)``).
        """
        B, T, d = g_hat.shape
        r = c.shape[-1]
        N = B * T

        g_flat = g_hat.reshape(N, d)
        c_flat = c.reshape(N, r)

        # Cross-covariance G = (1/N) g_hat^T @ c; max() guards an empty batch.
        G = (g_flat.T @ c_flat) / max(N, 1)  # (d, r)

        # The EMA is accumulated even when frozen.
        self.ema_G.mul_(tau).add_(G, alpha=1.0 - tau)

        if frozen:
            return {"G": G, "alpha": self.alpha.clone(), "frozen": True}

        # Tangent projection on Stiefel: Δ = G − U · sym(U^T G)
        UtG = self.U.T @ self.ema_G  # (r, r)
        sym = 0.5 * (UtG + UtG.T)
        Delta = self.ema_G - self.U @ sym  # (d, r)

        # Step clipping (the unclipped norm is what gets reported).
        delta_norm = torch.linalg.norm(Delta, ord="fro")
        if max_step_frob is not None and delta_norm > max_step_frob:
            Delta = Delta * (max_step_frob / (delta_norm + eps))

        # Riemannian step + QR retraction
        U_tilde = self.U + (eta_B * self.alpha) * Delta
        Q, R = torch.linalg.qr(U_tilde, mode="reduced")
        # Sign fix: make diag(R) positive so the Q factor is chosen
        # consistently; a zero diagonal entry is treated as positive.
        s = torch.sign(torch.diag(R))
        s = torch.where(s == 0, torch.ones_like(s), s)
        self.U.copy_(Q * s.unsqueeze(0))

        # Correlation-based alpha update: α* = <G, U> / (mean ||c||^2 + eps).
        # Uses the freshly retracted U and the instantaneous G (not the EMA).
        inner = (self.U * G).sum()
        c2_mean = c_flat.square().sum() / max(N, 1)
        alpha_star = inner / (c2_mean + eps)
        alpha_star = alpha_star.clamp(min=alpha_min, max=alpha_max)
        self.alpha.mul_(1.0 - beta_alpha).add_(alpha_star, alpha=beta_alpha)

        g_norm = torch.linalg.norm(G, ord="fro").item()
        return {
            "G": G,
            "Delta_frob": delta_norm.item(),
            "alpha_star": alpha_star.item(),
            "alpha": self.alpha.item(),
            "rho": inner.item() / (g_norm * (r ** 0.5) + eps),
        }


class StiefelFeedbackSystem(nn.Module):
    """Full feedback system: global C + per-layer StiefelFeedbackLayer."""

    def __init__(self, vocab_size: int, layer_dims: list[int], rank: int = 128):
        """Build the shared compressor and one feedback layer per entry of ``layer_dims``.

        The effective rank (stored as ``self.rank``) is ``rank`` clamped to
        ``vocab_size`` and to the smallest layer width: C cannot have more
        orthonormal rows than columns, and every layer shares the same
        compressed error so each U_l needs r ≤ d_l.
        """
        super().__init__()
        dims = list(layer_dims)
        # Can't compress to more dims than vocab, nor expand from more dims
        # than the narrowest layer has.
        self.rank = min([rank, vocab_size, *dims])
        self.register_buffer("C", init_row_orthonormal_C(vocab_size, self.rank))
        self.layers = nn.ModuleList([
            StiefelFeedbackLayer(d, self.rank) for d in dims
        ])

    def compress_error(self, e_L: torch.Tensor) -> torch.Tensor:
        """e_L: (B, T, V) → c: (B, T, r)"""
        return e_L @ self.C.T  # (B, T, r)

    def compute_deltas(self, c: torch.Tensor) -> list[torch.Tensor]:
        """Compute per-layer feedback δ_l from compressed error c."""
        return [layer.compute_delta(c) for layer in self.layers]

    def update_all(self, g_hats: list[torch.Tensor], c: torch.Tensor,
                   frozen: bool = False, **kwargs) -> list[dict]:
        """Update all layers' U and alpha.

        ``g_hats`` is paired with the layers in order using ``zip``, so
        pairing stops at the shorter of the two. Extra keyword arguments are
        forwarded to :meth:`StiefelFeedbackLayer.update`.

        Returns:
            One diagnostics dict per updated layer, in layer order.
        """
        return [
            layer.update(g_hat, c, frozen=frozen, **kwargs)
            for layer, g_hat in zip(self.layers, g_hats)
        ]
