summaryrefslogtreecommitdiff
path: root/ep_run/stiefel_feedback.py
blob: a5c86ca86678756b14ea2f79d4c92f108472bcc7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Factored feedback subspace: fixed compressor C + per-layer Stiefel expander U_l.

δ_l = α_l · c_t @ U_l^T,  where c_t = e_L @ C^T

C ∈ R^{r×V}: fixed row-orthonormal compressor (CC^T = I_r)
U_l ∈ St(d_l, r): per-layer learnable orthonormal expander
α_l > 0: per-layer scalar gain

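U_l is updated by a Riemannian gradient step on the Stiefel manifold, driven by an
EMA of the cross-covariance G = (1/N) ĝ^T c and followed by a QR retraction.
α_l is updated by a correlation-based least-squares fit, α* = ⟨G, U_l⟩ / mean‖c‖²,
smoothed by an EMA — not by a norm ratio.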
"""
import torch
import torch.nn as nn


@torch.no_grad()
def init_row_orthonormal_C(vocab_size: int, rank: int, device=None, dtype=torch.float32,
                           generator=None):
    """Return a random compressor C ∈ R^{r×V} with orthonormal rows (CC^T = I_r).

    Args:
        vocab_size: V, number of columns of C.
        rank: r, number of rows of C; must satisfy 0 <= r <= V.
        device, dtype: forwarded to ``torch.randn``.
        generator: optional ``torch.Generator`` for reproducible initialisation.

    Returns:
        Contiguous tensor of shape (rank, vocab_size).

    Raises:
        ValueError: if ``rank`` is negative or exceeds ``vocab_size``. In that case
            at most V orthonormal rows exist, and a reduced QR would silently
            return a (V, V) matrix instead of (r, V).
    """
    if rank < 0 or rank > vocab_size:
        raise ValueError(
            f"rank must be in [0, vocab_size={vocab_size}], got rank={rank}"
        )
    g = torch.randn(vocab_size, rank, device=device, dtype=dtype, generator=generator)
    # Reduced QR of a tall Gaussian matrix: q has orthonormal columns, shape (V, r).
    q, _ = torch.linalg.qr(g, mode="reduced")
    # Transposing gives orthonormal rows; make it contiguous for fast matmuls.
    return q.T.contiguous()  # (r, V)


class StiefelFeedbackLayer(nn.Module):
    """Per-layer factored feedback: δ = α · c @ U^T where U ∈ St(d, r).

    Buffers (updated in place by ``update``, not by autograd):
        U:     (d, r) expander with orthonormal columns (U^T U = I_r).
        alpha: scalar gain, initialised to 0.1.
        ema_G: (d, r) exponential moving average of the cross-covariance
               G = (1/N) g_hat^T c.
    """

    def __init__(self, d: int, r: int):
        super().__init__()
        # St(d, r) requires r <= d. With r > d a reduced QR returns a (d, d)
        # matrix that mismatches ema_G and compute_delta, so fail fast here
        # instead of with an opaque shape error later.
        if r > d:
            raise ValueError(
                f"rank r={r} must not exceed layer width d={d} for St(d, r)"
            )
        # U on Stiefel: (d, r), orthonormal columns
        U_init = torch.linalg.qr(torch.randn(d, r), mode="reduced")[0]
        self.register_buffer("U", U_init)
        self.register_buffer("alpha", torch.tensor(0.1))
        self.register_buffer("ema_G", torch.zeros(d, r))

    def compute_delta(self, c: torch.Tensor) -> torch.Tensor:
        """Map compressed error c of shape (..., r) to δ = α · c @ U^T of shape (..., d)."""
        return self.alpha * (c @ self.U.T)

    @torch.no_grad()
    def update(self, g_hat: torch.Tensor, c: torch.Tensor,
               eta_B: float = 3e-5, tau: float = 0.99,
               beta_alpha: float = 0.01, eps: float = 1e-8,
               alpha_min: float = 1e-4, alpha_max: float = 10.0,
               max_step_frob: float = 1.0, frozen: bool = False):
        """Update U and alpha from local signal g_hat and compressed error c.

        Args:
            g_hat: (..., d) local proxy signal (e.g. reconstruction error). All
                leading dims are flattened into N samples.
            c: (..., r) compressed global error with the same N samples.
            eta_B: base step size of the Riemannian step (scaled by alpha).
            tau: EMA decay for ema_G.
            beta_alpha: EMA rate pulling alpha toward the clamped estimate alpha*.
            eps: numerical floor added to denominators.
            alpha_min, alpha_max: clamp range for alpha*.
            max_step_frob: cap on the Frobenius norm of the tangent direction;
                None disables clipping.
            frozen: if True, only accumulate ema_G, don't update U or alpha.

        Returns:
            Diagnostics dict.
            - When frozen: {"G", "alpha" (tensor), "frozen": True}.
            - Otherwise: {"G", "Delta_frob", "alpha_star", "alpha", "rho"}, where
              all values except G are floats.
            - For an empty batch (N == 0) no state is changed and the same keys are
              returned plus "empty": True. In that case alpha_star is the current
              alpha, because no new estimate exists.
        """
        d = g_hat.shape[-1]
        r = c.shape[-1]
        # torch.Size.numel() also handles zero-size leading dims, where
        # reshape(-1, d) would be ambiguous and raise.
        N = g_hat.shape[:-1].numel()

        # Work in the buffers' dtype. This avoids a matmul dtype clash (e.g. bf16
        # activations vs fp32 error) and keeps the N-term reduction precise.
        work_dtype = self.ema_G.dtype
        g_flat = g_hat.reshape(N, d).to(work_dtype)
        c_flat = c.reshape(N, r).to(work_dtype)

        if N == 0:
            # Nothing observed. Do not decay ema_G, move U, or drag alpha toward
            # alpha_min (alpha* would be 0 / eps and clamp to alpha_min).
            G = torch.zeros_like(self.ema_G)
            if frozen:
                return {"G": G, "alpha": self.alpha.clone(), "frozen": True, "empty": True}
            current = self.alpha.item()
            return {
                "G": G,
                "Delta_frob": 0.0,
                "alpha_star": current,
                "alpha": current,
                "rho": 0.0,
                "empty": True,
            }

        # Cross-covariance G = (1/N) g_hat^T @ c
        G = (g_flat.T @ c_flat) / N  # (d, r)

        # EMA
        self.ema_G.mul_(tau).add_(G, alpha=1.0 - tau)

        if frozen:
            return {"G": G, "alpha": self.alpha.clone(), "frozen": True}

        # Tangent projection on Stiefel (Euclidean metric):
        # grad = G - U sym(U^T G)
        UtG = self.U.T @ self.ema_G  # (r, r)
        sym = 0.5 * (UtG + UtG.T)
        Delta = self.ema_G - self.U @ sym  # (d, r)

        # Step clipping
        delta_norm = torch.linalg.norm(Delta, ord="fro")
        if max_step_frob is not None and delta_norm > max_step_frob:
            Delta = Delta * (max_step_frob / (delta_norm + eps))

        # Riemannian ascent step (increases <U, ema_G>) + QR retraction
        U_tilde = self.U + (eta_B * self.alpha) * Delta
        Q, R = torch.linalg.qr(U_tilde, mode="reduced")
        # QR is unique only up to column signs. Force diag(R) > 0 so the
        # retraction is continuous and U does not flip sign between steps.
        s = torch.sign(torch.diag(R))
        s = torch.where(s == 0, torch.ones_like(s), s)
        self.U.copy_(Q * s.unsqueeze(0))

        # Least-squares gain: argmin_a mean||g - a c U^T||^2 = <G, U> / mean||c||^2,
        # using ||c U^T|| = ||c|| because U has orthonormal columns.
        UG = (self.U * G).sum()
        c2_mean = c_flat.square().sum() / N
        alpha_star = (UG / (c2_mean + eps)).clamp(min=alpha_min, max=alpha_max)
        self.alpha.mul_(1.0 - beta_alpha).add_(alpha_star, alpha=beta_alpha)

        return {
            "G": G,
            "Delta_frob": delta_norm.item(),
            "alpha_star": alpha_star.item(),
            "alpha": self.alpha.item(),
            # Cosine between G and U (||U||_F = sqrt(r) on the Stiefel manifold).
            "rho": UG.item() / (torch.linalg.norm(G, ord="fro").item() * (r ** 0.5) + eps),
        }


class StiefelFeedbackSystem(nn.Module):
    """Full feedback system: global C + per-layer StiefelFeedbackLayer.

    A single fixed compressor C (rank x vocab, orthonormal rows) is shared by all
    layers. Each entry of ``layers`` owns its own expander U_l and gain alpha_l.
    Because U_l ∈ St(d_l, r), each width in ``layer_dims`` should be at least
    the effective rank ``self.rank``.
    """

    def __init__(self, vocab_size: int, layer_dims: list[int], rank: int = 128):
        super().__init__()
        self.rank = min(rank, vocab_size)  # can't compress to more dims than vocab
        self.register_buffer("C", init_row_orthonormal_C(vocab_size, self.rank))
        self.layers = nn.ModuleList([
            StiefelFeedbackLayer(d, self.rank) for d in layer_dims
        ])

    def compress_error(self, e_L: torch.Tensor) -> torch.Tensor:
        """Project the output error onto C: e_L (B, T, V) → c = e_L @ C^T, shape (B, T, r)."""
        return e_L @ self.C.T  # (B, T, r)

    def compute_deltas(self, c: torch.Tensor) -> list[torch.Tensor]:
        """Compute per-layer feedback δ_l from compressed error c, one entry per layer."""
        return [layer.compute_delta(c) for layer in self.layers]

    def update_all(self, g_hats: list[torch.Tensor], c: torch.Tensor,
                   frozen: bool = False, **kwargs) -> list[dict]:
        """Update all layers' U and alpha.

        Args:
            g_hats: one local signal per layer, in the same order as ``self.layers``.
            c: compressed global error shared by all layers.
            frozen: forwarded to each layer's ``update``.
            **kwargs: extra hyper-parameters forwarded to each layer's ``update``.

        Returns:
            List of per-layer diagnostics dicts, in layer order.

        Raises:
            ValueError: if the number of signals differs from the number of
                layers. Plain ``zip`` would silently skip the unmatched layers.
        """
        if len(g_hats) != len(self.layers):
            raise ValueError(
                f"expected {len(self.layers)} g_hats (one per layer), got {len(g_hats)}"
            )
        return [
            layer.update(g_hat, c, frozen=frozen, **kwargs)
            for layer, g_hat in zip(self.layers, g_hats)
        ]