summaryrefslogtreecommitdiff
path: root/ep_run/local_layers.py
blob: db73fb8584a62ae4bc412bcc6516c2d3586b5671 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
"""Local-learning variants of Linear supporting BP / FA / DFA / sign-sym methods.

LocalLinear is a drop-in replacement for nn.Linear that selects its backward
computation based on the `method` argument:

  bp:       standard autograd (nn.Linear behavior)
  fa:       custom autograd, backward uses a fixed random matrix B in place of W.T
            (Lillicrap-style Feedback Alignment, per projection)
  sign_sym: custom autograd, backward uses sign(W), rescaled by ||W||_F / sqrt(numel(W)),
            in place of W.T (Xiao et al. 2018)
  dfa:      forward uses normal autograd (so upstream params like embeddings /
            LayerNorm still get BP gradients). Input is cached during forward.
            After loss.backward(), call `apply_dfa_update(model, e_L)` to
            OVERWRITE LocalLinear .grad with DFA-computed update. LocalLinear
            weights thus receive direct projection updates while non-LocalLinear
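            params (embeddings, LN) retain BP gradients (pragmatic hybrid).
  dfa_block: like dfa (normal autograd forward, input cached), but the error is the
            gradient at the enclosing block's output: call
            `initialize_dfa_block_targets(model, d_block)` once, then
            `apply_dfa_block_update(block, block_output_error)` to OVERWRITE .grad.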

For DFA, call `initialize_dfa_targets(model, target_dim)` once after model
construction, then each training step:
  1. standard forward
  2. compute loss, loss.backward() (fills BP .grad on everything)
  3. compute e_L = dL/dlogits analytically (for LM: softmax(logits)-onehot)
  4. call `apply_dfa_update(model, e_L)` to overwrite LocalLinear .grad
  5. optimizer.step()
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearFA(torch.autograd.Function):
    """Linear forward, Feedback Alignment backward (replace W.T with fixed random B).

    Forward computes ``x @ W.T (+ bias)``, i.e. the same result as ``F.linear``.
    Backward returns the usual backprop gradients for ``W`` and ``bias``, but the
    error sent to ``x`` goes through the fixed matrix ``B`` instead of ``W``.
    ``B`` has the same shape as ``W``: ``(out_features, in_features)``.
    ``B`` never receives a gradient.

    Only gradients that autograd actually needs are computed. Backward matmuls run
    in ``grad_out``'s dtype, so a lower-precision ``grad_out`` does not trigger a
    dtype-mismatch error against ``B``/``x``. That happens e.g. under autocast.
    """

    @staticmethod
    def forward(ctx, x, W, B, bias):
        # W itself is not needed in backward (FA replaces it by B), so don't save it.
        ctx.save_for_backward(x, B)
        ctx.has_bias = bias is not None
        out = x @ W.t()
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, B = ctx.saved_tensors
        needs_x, needs_W, _, needs_bias = ctx.needs_input_grad
        grad_x = grad_W = grad_bias = None
        # Collapse all leading (batch / sequence) dims into one row axis.
        g_flat = grad_out.reshape(-1, grad_out.shape[-1])
        if needs_x:
            # True BP would use W here (shape out x in). FA substitutes random B of same shape.
            grad_x = grad_out @ B.to(grad_out.dtype)
        if needs_W:
            # Standard outer product of grad_out and x, summed over all rows.
            grad_W = g_flat.t() @ x.reshape(-1, x.shape[-1]).to(grad_out.dtype)
        if ctx.has_bias and needs_bias:
            grad_bias = g_flat.sum(dim=0)
        # grad for B is always None: it is a fixed feedback matrix.
        return grad_x, grad_W, None, grad_bias


class LinearSignSym(torch.autograd.Function):
    """Linear forward, sign-symmetric backward with rescaling.

    The error sent to ``x`` goes through
        B = sign(W) · ||W||_F / sqrt(numel(W))
    instead of through ``W``. The rescale matches sign(W)'s magnitude to W's
    typical element magnitude, giving ||B||_F ≈ ||W||_F when W has no exact zeros.
    This avoids the 50x gradient blowup that pure sign(W) caused. Weight and bias
    gradients are the standard backprop ones.

    Only gradients that autograd actually needs are computed. Backward matmuls run
    in ``grad_out``'s dtype, so a lower-precision ``grad_out`` (e.g. under
    autocast) does not trigger a dtype-mismatch error.
    """

    @staticmethod
    def forward(ctx, x, W, bias):
        ctx.save_for_backward(x, W)
        ctx.has_bias = bias is not None
        out = x @ W.t()
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, W = ctx.saved_tensors
        needs_x, needs_W, needs_bias = ctx.needs_input_grad
        grad_x = grad_W = grad_bias = None
        # Collapse all leading (batch / sequence) dims into one row axis.
        g_flat = grad_out.reshape(-1, grad_out.shape[-1])
        if needs_x:
            # Rescaled sign: scale so that ||B||_F ≈ ||W||_F
            scale = W.norm() / (W.numel() ** 0.5 + 1e-8)
            sign_W_scaled = (torch.sign(W) * scale).to(grad_out.dtype)
            grad_x = grad_out @ sign_W_scaled
        if needs_W:
            grad_W = g_flat.t() @ x.reshape(-1, x.shape[-1]).to(grad_out.dtype)
        if ctx.has_bias and needs_bias:
            grad_bias = g_flat.sum(dim=0)
        return grad_x, grad_W, grad_bias


class LocalLinear(nn.Module):
    """nn.Linear drop-in with method-dispatched backward (bp/fa/dfa/sign_sym/dfa_block).

    method:
      bp:        plain F.linear (standard autograd)
      fa:        LinearFA with a fixed feedback matrix ``self.B`` of shape (out, in)
      sign_sym:  LinearSignSym (rescaled sign(W) feedback)
      dfa:       plain F.linear; the input is cached for ``dfa_compute_grad``
                 (feedback buffer ``self.B_dfa`` is set by ``set_dfa_target_dim``)
      dfa_block: plain F.linear; the input is cached in ``_dfa_block_cached_input``
                 (feedback buffer ``self.B_dfa_block`` starts as None)
    Any other value is only rejected (ValueError) on the first forward call.

    fa_init_mode (only used when method='fa'):
      gaussian:    B ~ N(0, init_std)             (Lillicrap default, current)
      orthogonal:  B = Haar orthogonal × scale    (#1: JL-isometric, scaled to match BP grad norm)
      ortho_he:    B = Haar orthogonal × sqrt(2/out)  (#2: He-init for backward signal)
      sparse:      B with k non-zeros per row, signs ±1, scaled (#4: structured sparse).
                   k = fa_sparse_k (default max(1, in_features // 16)), capped at
                   in_features; an explicit k < 1 raises ValueError.

    fa_grape (only with method='fa'): if True, B is updated each step via cosine alignment
      to the rank-1 JVP Jacobian estimate Ĵ = (W p) p^T. Implements GrAPE
      (Caillon et al., ICLR 2026) per-layer. Forward only — no W^T transport.
    """

    def __init__(self, in_features, out_features, bias=False, method="bp", init_std=0.02,
                 fa_init_mode="gaussian", fa_sparse_k=None, fa_grape=False, fa_grape_n_probe=32):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.method = method
        self._fa_grape = (method == "fa") and fa_grape
        self._fa_grape_n_probe = fa_grape_n_probe

        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.normal_(self.weight, mean=0.0, std=init_std)

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        if method == "fa":
            B = torch.empty(out_features, in_features)
            if fa_init_mode == "gaussian":
                nn.init.normal_(B, mean=0.0, std=init_std)
            elif fa_init_mode == "orthogonal":
                # Haar orthogonal (semi-orthogonal for non-square), scaled to match BP grad norm.
                # BP backward: grad_x = grad_out @ W has norm ~ sqrt(in) * std(W) * ||grad_out||.
                # Pure orthogonal preserves norm. Scale to match BP's natural shrinkage.
                nn.init.orthogonal_(B)
                scale = (in_features ** 0.5) * init_std
                B.mul_(scale)
            elif fa_init_mode == "ortho_he":
                # He init for backward: variance = 2/out_features (matches ReLU-friendly backward)
                nn.init.orthogonal_(B)
                scale = (2.0 / out_features) ** 0.5
                B.mul_(scale)
            elif fa_init_mode == "sparse":
                # k non-zero entries per row, signs ±1, scaled so row L2 norm matches Gaussian B
                k = fa_sparse_k if fa_sparse_k is not None else max(1, in_features // 16)
                if k < 1:
                    raise ValueError(f"fa_sparse_k must be >= 1, got {fa_sparse_k}")
                # randperm(in_features)[:k] yields at most in_features indices, so a larger
                # k would not match the k signs drawn below. Cap it (k = in -> dense ±1 B).
                k = min(k, in_features)
                B.zero_()
                for i in range(out_features):
                    idx = torch.randperm(in_features)[:k]
                    signs = (torch.randint(0, 2, (k,)).float() * 2 - 1)
                    B[i, idx] = signs
                # Scale to match Gaussian B's row variance: row_var(Gaussian) = in_features * init_std^2
                # row_var(sparse) = k (after scale 1). To match: scale = init_std * sqrt(in_features/k)
                B.mul_(init_std * (in_features / k) ** 0.5)
            else:
                raise ValueError(f"Unknown fa_init_mode: {fa_init_mode}")
            # GrAPE: B is updated via JVP-cosine alignment (not via standard optimizer);
            # store as Parameter with requires_grad=False so we can update in-place.
            if self._fa_grape:
                self.B = nn.Parameter(B, requires_grad=False)
            else:
                self.register_buffer("B", B)

        if method == "dfa":
            # B_dfa shape (out_features, target_dim); set via initialize_dfa_targets
            self.register_buffer("B_dfa", None)
            self._dfa_cached_input = None

        if method == "dfa_block":
            # B_dfa_block: (out_features, d_block) — projects block-output-error to layer output
            # Set via initialize_dfa_block_targets
            self.register_buffer("B_dfa_block", None)
            self._dfa_block_cached_input = None

    def set_dfa_target_dim(self, target_dim, init_std=0.02):
        """Draw the DFA feedback matrix ``B_dfa`` ~ N(0, init_std), shape (out, target_dim).

        Allocated on the weight's device. Assigning to the registered buffer keeps it
        a buffer, so it follows later ``.to()`` calls.
        """
        assert self.method == "dfa"
        B = torch.empty(self.out_features, target_dim, device=self.weight.device)
        nn.init.normal_(B, mean=0.0, std=init_std)
        self.B_dfa = B

    @torch.no_grad()
    def grape_align_step(self, lr_b=0.01, normalize_columns=False):
        """GrAPE: update B toward rank-1 JVP Jacobian estimate via cosine alignment.

        For linear y = W x: J = W. JVP at random p: g = W p. Estimate Ĵ = (1/N) Σ g_i p_i^T → W.
        Forward only (uses W in forward computation, no W^T transport).
        No-op unless constructed with method='fa' and fa_grape=True, or if B or Ĵ is ~0.

        normalize_columns: paper (Eq. 6) does column-normalize B. We default off because
          our per-linear FA needs magnitude match (B → W in magnitude AND direction);
          column-norm prevents matching W's row magnitudes.
        """
        if not self._fa_grape:
            return
        N = self._fa_grape_n_probe
        device, dtype = self.weight.device, self.weight.dtype
        # Random Gaussian perturbation p ~ N(0, I) in input space
        p = torch.randn(N, self.in_features, device=device, dtype=dtype)
        # Forward: g = p @ W^T (one matrix multiply)
        g = F.linear(p, self.weight)  # (N, out_features)
        # Rank-1 Jacobian estimate: Ĵ = (1/N) g^T @ p, shape (out, in)
        J_hat = (g.t() @ p) / N
        # Cosine alignment gradient
        B_norm = self.B.norm()
        J_norm = J_hat.norm()
        if B_norm < 1e-8 or J_norm < 1e-8:
            return
        cos_val = (self.B * J_hat).sum() / (B_norm * J_norm)
        # ∂(1 - cos)/∂B = -Ĵ/(||B||·||Ĵ||) + cos · B/||B||²
        grad = -J_hat / (B_norm * J_norm) + cos_val * self.B / (B_norm ** 2)
        self.B.add_(grad, alpha=-lr_b)
        if normalize_columns:
            col_norms = self.B.norm(dim=0, keepdim=True).clamp_min(1e-8)
            self.B.div_(col_norms)

    def forward(self, x):
        if self.method == "bp":
            return F.linear(x, self.weight, self.bias)
        if self.method == "fa":
            return LinearFA.apply(x, self.weight, self.B, self.bias)
        if self.method == "sign_sym":
            return LinearSignSym.apply(x, self.weight, self.bias)
        if self.method == "dfa":
            # Cache input for later manual DFA update (will overwrite BP .grad)
            self._dfa_cached_input = x.detach()
            return F.linear(x, self.weight, self.bias)
        if self.method == "dfa_block":
            self._dfa_block_cached_input = x.detach()
            return F.linear(x, self.weight, self.bias)
        raise ValueError(f"Unknown method: {self.method}")

    def dfa_compute_grad(self, e_L):
        """Set self.weight.grad and self.bias.grad from global error e_L.

        e_L shape (..., target_dim). delta = e_L @ B_dfa.T, shape (..., out_features).
        ΔW = sum_n delta_n outer input_n, where inputs are cached from forward.
        The update is summed (not averaged) over rows. It OVERWRITES any existing .grad,
        is stored in the parameter's dtype, and the cached input is cleared afterwards.
        The number of rows in e_L must match the cached input.

        Raises:
            AssertionError: wrong method, forward not run since the last update,
                or B_dfa not initialised.
        """
        assert self.method == "dfa"
        assert self._dfa_cached_input is not None, "DFA forward not called or cache cleared"
        assert self.B_dfa is not None, "DFA target_dim not set (call initialize_dfa_targets)"

        # e_L is often derived from logits that carry a graph; without detaching, the
        # stored .grad would itself require grad and keep that graph alive.
        e_L = e_L.detach()
        # Compute in e_L's dtype. The cached input / B_dfa may differ (e.g. bf16
        # activations vs an fp32 error), and matmul requires matching dtypes.
        delta = e_L @ self.B_dfa.to(e_L.dtype).t()  # (..., out_features)
        delta_flat = delta.reshape(-1, self.out_features)
        inp_flat = self._dfa_cached_input.reshape(-1, self.in_features).to(delta.dtype)

        # (out_features, in_features), cast back because .grad must match the param dtype.
        grad_W = (delta_flat.t() @ inp_flat).to(self.weight.dtype)

        if self.weight.grad is None:
            self.weight.grad = grad_W
        else:
            self.weight.grad.copy_(grad_W)

        if self.bias is not None:
            grad_b = delta_flat.sum(dim=0).to(self.bias.dtype)
            if self.bias.grad is None:
                self.bias.grad = grad_b
            else:
                self.bias.grad.copy_(grad_b)

        self._dfa_cached_input = None

    def extra_repr(self):
        return f"in={self.in_features}, out={self.out_features}, method={self.method}"


def initialize_dfa_targets(model, target_dim, init_std=0.02):
    """Give every ``LocalLinear(method="dfa")`` in ``model`` a random DFA feedback matrix.

    Must be called once after model construction and device placement, for DFA mode.
    Each matrix has shape (out_features, target_dim) and is allocated on its layer's
    weight device.

    Args:
        model: module whose submodules (including itself) are scanned.
        target_dim: last-dimension size of the global error later passed to
            ``apply_dfa_update``.
        init_std: std of the Gaussian feedback entries. The default matches the
            previous hard-coded value and ``initialize_dfa_block_targets``.
    """
    for module in model.modules():
        if isinstance(module, LocalLinear) and module.method == "dfa":
            module.set_dfa_target_dim(target_dim, init_std=init_std)


def apply_dfa_update(model, e_L):
    """Overwrite ``.grad`` of every ``LocalLinear(method="dfa")`` in ``model`` from ``e_L``.

    Layers are visited in ``model.modules()`` order. Each one delegates to its own
    ``dfa_compute_grad``, which uses (and then clears) the input cached during that
    layer's last forward pass.
    """
    dfa_layers = (
        layer
        for layer in model.modules()
        if isinstance(layer, LocalLinear) and layer.method == "dfa"
    )
    for layer in dfa_layers:
        layer.dfa_compute_grad(e_L)


def initialize_dfa_block_targets(model, d_block, init_std=0.02):
    """Give every ``LocalLinear(method="dfa_block")`` a random feedback matrix.

    Each matrix ``B_dfa_block`` has shape (out_features, d_block), entries drawn from
    N(0, init_std), and is allocated on the layer's weight device.
    """
    for layer in model.modules():
        if not isinstance(layer, LocalLinear) or layer.method != "dfa_block":
            continue
        feedback = torch.empty(layer.out_features, d_block, device=layer.weight.device)
        nn.init.normal_(feedback, mean=0.0, std=init_std)
        layer.B_dfa_block = feedback


def apply_dfa_block_update(block, block_output_error):
    """Apply DFA-within-block updates to all LocalLinear(dfa_block) in `block`.

    block_output_error: gradient at the block's output, last dim d_block
    (e.g. (B, T, d_block)). It is detached, so it may still carry a graph.

    With the error flattened to N rows, each such layer gets
        delta = err @ B_dfa_block.T                 # (N, out_features)
        grad  = delta.T @ cached_input / N          # (out_features, in_features)
        bias  = delta.sum(0) / N
    The result OVERWRITES any existing .grad and is stored in the parameter's dtype.
    The layer's cached input is cleared afterwards.

    Raises:
        AssertionError: targets not initialised, or forward not run since the last update.
    """
    err = block_output_error.detach()
    err_flat = err.reshape(-1, err.size(-1))  # (N, d_block)
    n_rows = max(err_flat.size(0), 1)  # guard against division by zero on empty input
    for module in block.modules():
        if not (isinstance(module, LocalLinear) and module.method == "dfa_block"):
            continue
        assert module.B_dfa_block is not None, "Call initialize_dfa_block_targets first"
        assert module._dfa_block_cached_input is not None, "Forward not called"
        # Compute in the error's dtype. B_dfa_block / cached input may differ
        # (e.g. bf16 activations vs an fp32 error), and matmul requires matching dtypes.
        delta = err_flat @ module.B_dfa_block.to(err_flat.dtype).t()  # (N, out_features)
        inp_flat = module._dfa_block_cached_input.reshape(-1, module.in_features).to(delta.dtype)
        # .grad must match the parameter's dtype.
        grad_W = ((delta.t() @ inp_flat) / n_rows).to(module.weight.dtype)
        if module.weight.grad is None:
            module.weight.grad = grad_W
        else:
            module.weight.grad.copy_(grad_W)
        if module.bias is not None:
            grad_b = (delta.sum(dim=0) / n_rows).to(module.bias.dtype)
            if module.bias.grad is None:
                module.bias.grad = grad_b
            else:
                module.bias.grad.copy_(grad_b)
        module._dfa_block_cached_input = None