"""Local-learning variants of Linear supporting BP / FA / DFA / sign-sym methods.

LocalLinear is a drop-in replacement for nn.Linear that selects its backward
computation based on the `method` argument:

    bp:        standard autograd (nn.Linear behavior).
    fa:        custom autograd; backward uses a fixed random matrix B in place
               of W.T (Lillicrap-style Feedback Alignment, per projection).
               With fa_grape=True, B is instead nudged toward W by
               `LocalLinear.grape_align_step`.
    sign_sym:  custom autograd; backward uses a rescaled sign(W) in place of
               W.T (Xiao 2018).
    dfa:       forward uses normal autograd (so upstream params like
               embeddings / LayerNorm still get BP gradients). Input is cached
               during forward. After loss.backward(), call
               `apply_dfa_update(model, e_L)` to OVERWRITE LocalLinear .grad
               with the DFA-computed update. LocalLinear weights thus receive
               direct projection updates while non-LocalLinear params
               (embeddings, LN) retain BP gradients (pragmatic hybrid).
    dfa_block: like dfa, but the error fed to each layer is a random
               projection of a block's output error; see
               `initialize_dfa_block_targets` / `apply_dfa_block_update`.

For DFA, call `initialize_dfa_targets(model, target_dim)` once after model
construction, then each training step:

    1. standard forward
    2. compute loss, loss.backward() (fills BP .grad on everything)
    3. compute e_L = dL/dlogits analytically (for LM: softmax(logits)-onehot)
    4. call `apply_dfa_update(model, e_L)` to overwrite LocalLinear .grad
    5. optimizer.step()
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def _flatten_leading(t):
    """Collapse every leading dim of `t`, giving shape (rows, t.shape[-1])."""
    return t.reshape(-1, t.shape[-1])


def _linear_weight_grad(grad_out, x):
    """Weight gradient of y = x @ W.T: sum over rows n of outer(grad_out_n, x_n).

    Returns a tensor of shape (grad_out.shape[-1], x.shape[-1]), i.e.
    (out_features, in_features).
    """
    return _flatten_leading(grad_out).t() @ _flatten_leading(x)


def _overwrite_grad(param, grad):
    """Replace `param.grad` with `grad` (allocating a copy when .grad is None)."""
    if param.grad is None:
        param.grad = grad.clone()
    else:
        param.grad.copy_(grad)


class LinearFA(torch.autograd.Function):
    """Linear forward, FA backward (replace W.T with fixed random B)."""

    @staticmethod
    def forward(ctx, x, W, B, bias):
        # W is not needed in backward: the input gradient uses B instead.
        ctx.save_for_backward(x, B)
        out = x @ W.t()
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, B = ctx.saved_tensors
        grad_x = grad_W = grad_bias = None
        if ctx.needs_input_grad[0]:
            # True BP would use W here (shape out x in); FA substitutes the
            # fixed random B of the same shape.
            grad_x = grad_out @ B
        if ctx.needs_input_grad[1]:
            grad_W = _linear_weight_grad(grad_out, x)
        # needs_input_grad[3] is False when bias is None.
        if ctx.needs_input_grad[3]:
            grad_bias = _flatten_leading(grad_out).sum(dim=0)
        # B is fixed (or updated outside autograd by GrAPE): no gradient.
        return grad_x, grad_W, None, grad_bias


class LinearSignSym(torch.autograd.Function):
    """Linear forward, sign-symmetric backward with rescaling.

    B = sign(W) · ||W||_F / sqrt(numel(W))

    The rescale matches sign(W)'s magnitude to W's typical element magnitude,
    avoiding the 50x gradient blowup that pure sign(W) caused.
    """

    @staticmethod
    def forward(ctx, x, W, bias):
        ctx.save_for_backward(x, W)
        out = x @ W.t()
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, W = ctx.saved_tensors
        grad_x = grad_W = grad_bias = None
        if ctx.needs_input_grad[0]:
            # Rescaled sign: scale so that ||B||_F ≈ ||W||_F.
            scale = W.norm() / (W.numel() ** 0.5 + 1e-8)
            grad_x = grad_out @ (torch.sign(W) * scale)
        if ctx.needs_input_grad[1]:
            grad_W = _linear_weight_grad(grad_out, x)
        if ctx.needs_input_grad[2]:
            grad_bias = _flatten_leading(grad_out).sum(dim=0)
        return grad_x, grad_W, grad_bias


def _make_fa_feedback(out_features, in_features, init_std, mode, sparse_k):
    """Build the initial FA feedback matrix B, shape (out_features, in_features).

    Raises:
        ValueError: unknown `mode`, or `sparse_k` < 1 in "sparse" mode.
    """
    B = torch.empty(out_features, in_features)
    if mode == "gaussian":
        nn.init.normal_(B, mean=0.0, std=init_std)
    elif mode == "orthogonal":
        # Haar (semi-)orthogonal, scaled to match BP grad norm. BP backward
        # grad_x = grad_out @ W has norm ~ sqrt(in) * std(W) * ||grad_out||,
        # while an orthogonal B preserves the norm; scale to match BP's
        # natural shrinkage.
        nn.init.orthogonal_(B)
        B.mul_((in_features ** 0.5) * init_std)
    elif mode == "ortho_he":
        # He-style factor for the backward signal: the (semi-)orthogonal
        # matrix is scaled by sqrt(2 / out_features).
        nn.init.orthogonal_(B)
        B.mul_((2.0 / out_features) ** 0.5)
    elif mode == "sparse":
        k = sparse_k if sparse_k is not None else max(1, in_features // 16)
        if k < 1:
            raise ValueError(f"fa_sparse_k must be >= 1, got {sparse_k}")
        # A row cannot hold more non-zeros than it has columns.
        k = min(k, in_features)
        B.zero_()
        for i in range(out_features):
            idx = torch.randperm(in_features)[:k]
            signs = torch.randint(0, 2, (k,)).float() * 2 - 1
            B[i, idx] = signs
        # A Gaussian B row has expected squared norm in_features * init_std^2;
        # a ±1 row with k non-zeros has squared norm k. Scale to match.
        B.mul_(init_std * (in_features / k) ** 0.5)
    else:
        raise ValueError(f"Unknown fa_init_mode: {mode}")
    return B


class LocalLinear(nn.Module):
    """nn.Linear drop-in with method-dispatched backward.

    method: "bp", "fa", "sign_sym", "dfa" or "dfa_block" (see module docstring).

    fa_init_mode (only used when method='fa'):
        gaussian:   B ~ N(0, init_std) (Lillicrap default)
        orthogonal: B = Haar orthogonal × sqrt(in_features) · init_std
                    (#1: JL-isometric, scaled to match BP grad norm)
        ortho_he:   B = Haar orthogonal × sqrt(2 / out_features)
                    (#2: He-init for backward signal)
        sparse:     B with k non-zeros per row (k = fa_sparse_k, default
                    max(1, in_features // 16), capped at in_features),
                    signs ±1, scaled (#4: structured sparse)

    fa_grape (only with method='fa'): if True, B is updated by
        `grape_align_step` via cosine alignment to the JVP Jacobian estimate
        Ĵ = (1/N) Σ (W p_i) p_iᵀ. Implements GrAPE (Caillon et al., ICLR 2026)
        per-layer. Forward only — no Wᵀ transport. Ignored for other methods.
    fa_grape_n_probe: number of random probes N per grape_align_step call.
    """

    def __init__(self, in_features, out_features, bias=False, method="bp",
                 init_std=0.02, fa_init_mode="gaussian", fa_sparse_k=None,
                 fa_grape=False, fa_grape_n_probe=32):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.method = method
        self._fa_grape = (method == "fa") and fa_grape
        self._fa_grape_n_probe = fa_grape_n_probe

        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.normal_(self.weight, mean=0.0, std=init_std)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        if method == "fa":
            B = _make_fa_feedback(out_features, in_features, init_std,
                                  fa_init_mode, fa_sparse_k)
            if self._fa_grape:
                # GrAPE updates B in place (not via the optimizer); stored as
                # a requires_grad=False Parameter so it can be updated in place.
                self.B = nn.Parameter(B, requires_grad=False)
            else:
                self.register_buffer("B", B)

        if method == "dfa":
            # B_dfa shape (out_features, target_dim); set via initialize_dfa_targets.
            self.register_buffer("B_dfa", None)
            self._dfa_cached_input = None

        if method == "dfa_block":
            # B_dfa_block: (out_features, d_block) — projects block-output-error
            # to layer output. Set via initialize_dfa_block_targets.
            self.register_buffer("B_dfa_block", None)
            self._dfa_block_cached_input = None

    def set_dfa_target_dim(self, target_dim, init_std=0.02):
        """Create the random DFA projection B_dfa of shape (out_features, target_dim).

        B_dfa is placed on the weight's device with the default dtype.
        """
        assert self.method == "dfa"
        B = torch.empty(self.out_features, target_dim, device=self.weight.device)
        nn.init.normal_(B, mean=0.0, std=init_std)
        self.B_dfa = B

    @torch.no_grad()
    def grape_align_step(self, lr_b=0.01, normalize_columns=False):
        """GrAPE: update B toward the JVP Jacobian estimate via cosine alignment.

        For linear y = W x: J = W. JVP at random p: g = W p.
        Estimate Ĵ = (1/N) Σ g_i p_iᵀ → W. Forward only (uses W in a forward
        matmul, no Wᵀ transport). No-op unless constructed with fa_grape=True.

        Takes one gradient step of size `lr_b` on (1 - cos(B, Ĵ)). That gradient
        is orthogonal to B, so the step rotates B (and slightly grows ||B||)
        rather than matching W's overall scale.

        normalize_columns: paper (Eq. 6) column-normalizes B. Off by default
            here; the stated rationale is that per-linear FA should let B track
            W's row magnitudes, which column normalization prevents.
        """
        if not self._fa_grape:
            return
        N = self._fa_grape_n_probe
        device, dtype = self.weight.device, self.weight.dtype
        # Random Gaussian perturbation p ~ N(0, I) in input space.
        p = torch.randn(N, self.in_features, device=device, dtype=dtype)
        # Forward: g = p @ W^T (one matrix multiply).
        g = F.linear(p, self.weight)  # (N, out_features)
        # Jacobian estimate: Ĵ = (1/N) g^T @ p, shape (out, in).
        J_hat = (g.t() @ p) / N

        B_norm = self.B.norm()
        J_norm = J_hat.norm()
        if B_norm < 1e-8 or J_norm < 1e-8:
            return
        cos_val = (self.B * J_hat).sum() / (B_norm * J_norm)
        # ∂(1 - cos)/∂B = -Ĵ/(||B||·||Ĵ||) + cos · B/||B||²
        grad = -J_hat / (B_norm * J_norm) + cos_val * self.B / (B_norm ** 2)
        self.B.add_(grad, alpha=-lr_b)
        if normalize_columns:
            col_norms = self.B.norm(dim=0, keepdim=True).clamp_min(1e-8)
            self.B.div_(col_norms)

    def forward(self, x):
        if self.method == "bp":
            return F.linear(x, self.weight, self.bias)
        if self.method == "fa":
            return LinearFA.apply(x, self.weight, self.B, self.bias)
        if self.method == "sign_sym":
            return LinearSignSym.apply(x, self.weight, self.bias)
        if self.method == "dfa":
            # Cache input for the later manual DFA update (overwrites BP .grad).
            # Only the most recent call's input is kept.
            self._dfa_cached_input = x.detach()
            return F.linear(x, self.weight, self.bias)
        if self.method == "dfa_block":
            self._dfa_block_cached_input = x.detach()
            return F.linear(x, self.weight, self.bias)
        raise ValueError(f"Unknown method: {self.method}")

    @torch.no_grad()
    def dfa_compute_grad(self, e_L):
        """Set self.weight.grad and self.bias.grad from global error e_L.

        e_L shape (..., target_dim); its leading dims must flatten to the same
        row count as the cached forward input.
        delta = e_L @ B_dfa.T, shape (..., out_features).
        ΔW = sum_n delta_n outer input_n, where inputs are cached from forward.

        Runs under no_grad so a graph-attached e_L (e.g. softmax(logits) - onehot
        computed from live logits) cannot leak autograd history into .grad.
        Clears the cached input afterwards.

        NOTE(review): this SUMS over rows, whereas apply_dfa_block_update
        divides by the row count. If e_L is the raw softmax(logits) - onehot of
        a mean-reduced loss, the update is N× the mean-loss scale — confirm the
        intended convention with the caller.
        """
        assert self.method == "dfa"
        assert self._dfa_cached_input is not None, "DFA forward not called or cache cleared"
        assert self.B_dfa is not None, "DFA target_dim not set (call initialize_dfa_targets)"
        delta_flat = _flatten_leading(e_L @ self.B_dfa.t())  # (N, out_features)
        _overwrite_grad(self.weight, _linear_weight_grad(delta_flat, self._dfa_cached_input))
        if self.bias is not None:
            _overwrite_grad(self.bias, delta_flat.sum(dim=0))
        self._dfa_cached_input = None

    def extra_repr(self):
        return f"in={self.in_features}, out={self.out_features}, method={self.method}"


def initialize_dfa_targets(model, target_dim):
    """Must be called once after model construction and device placement, for DFA mode."""
    for module in model.modules():
        if isinstance(module, LocalLinear) and module.method == "dfa":
            module.set_dfa_target_dim(target_dim)


def apply_dfa_update(model, e_L):
    """Iterate over all LocalLinear(dfa) modules and populate their .grad from e_L."""
    for module in model.modules():
        if isinstance(module, LocalLinear) and module.method == "dfa":
            module.dfa_compute_grad(e_L)


def initialize_dfa_block_targets(model, d_block, init_std=0.02):
    """For dfa_block mode: each LocalLinear gets a random B_dfa_block of shape (out, d_block)."""
    for module in model.modules():
        if isinstance(module, LocalLinear) and module.method == "dfa_block":
            B = torch.empty(module.out_features, d_block, device=module.weight.device)
            nn.init.normal_(B, mean=0.0, std=init_std)
            module.B_dfa_block = B


def apply_dfa_block_update(block, block_output_error):
    """Apply DFA-within-block updates to all LocalLinear(dfa_block) in `block`.

    block_output_error: (..., d_block) — gradient at the block's output
    (e.g. (B, T, d_block)). It is detached before use.
    Each linear's grad: ΔW = (err @ B_dfa_block.T)^T @ cached_input / N,
    where N is the number of flattened rows (bias grad is likewise averaged).
    Clears each module's cached input afterwards.
    """
    err_flat = _flatten_leading(block_output_error.detach())  # (N, d_block)
    n_rows = max(err_flat.size(0), 1)
    for module in block.modules():
        if not (isinstance(module, LocalLinear) and module.method == "dfa_block"):
            continue
        assert module.B_dfa_block is not None, "Call initialize_dfa_block_targets first"
        assert module._dfa_block_cached_input is not None, "Forward not called"
        delta = err_flat @ module.B_dfa_block.t()  # (N, out_features)
        grad_W = _linear_weight_grad(delta, module._dfa_block_cached_input) / n_rows
        _overwrite_grad(module.weight, grad_W)
        if module.bias is not None:
            _overwrite_grad(module.bias, delta.sum(dim=0) / n_rows)
        module._dfa_block_cached_input = None