summaryrefslogtreecommitdiff
path: root/ep_run/local_layers.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/local_layers.py')
-rw-r--r--ep_run/local_layers.py305
1 files changed, 305 insertions, 0 deletions
diff --git a/ep_run/local_layers.py b/ep_run/local_layers.py
new file mode 100644
index 0000000..db73fb8
--- /dev/null
+++ b/ep_run/local_layers.py
@@ -0,0 +1,305 @@
+"""Local-learning variants of Linear supporting BP / FA / DFA / sign-sym methods.
+
+LocalLinear is a drop-in replacement for nn.Linear that selects its backward
+computation based on the `method` argument:
+
+ bp: standard autograd (nn.Linear behavior)
+ fa: custom autograd, backward uses a fixed random matrix B in place of W.T
+ (Lillicrap-style Feedback Alignment, per projection)
+ sign_sym: custom autograd, backward uses sign(W) in place of W.T (Xiao 2018)
+ dfa: forward uses normal autograd (so upstream params like embeddings /
+ LayerNorm still get BP gradients). Input is cached during forward.
+ After loss.backward(), call `apply_dfa_update(model, e_L)` to
+ OVERWRITE LocalLinear .grad with DFA-computed update. LocalLinear
+ weights thus receive direct projection updates while non-LocalLinear
+ params (embeddings, LN) retain BP gradients (pragmatic hybrid).
+
+For DFA, call `initialize_dfa_targets(model, target_dim)` once after model
+construction, then each training step:
+ 1. standard forward
+ 2. compute loss, loss.backward() (fills BP .grad on everything)
+ 3. compute e_L = dL/dlogits analytically (for LM: softmax(logits)-onehot)
+ 4. call `apply_dfa_update(model, e_L)` to overwrite LocalLinear .grad
+ 5. optimizer.step()
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LinearFA(torch.autograd.Function):
+ """Linear forward, FA backward (replace W.T with fixed random B)."""
+
+ @staticmethod
+ def forward(ctx, x, W, B, bias):
+ ctx.save_for_backward(x, W, B)
+ ctx.has_bias = bias is not None
+ out = x @ W.t()
+ if bias is not None:
+ out = out + bias
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ x, W, B = ctx.saved_tensors
+ # True BP would use W here (shape out x in). FA replaces with random B of same shape.
+ grad_x = grad_out @ B
+ # grad_W is standard outer product of grad_out and x (summed over leading dims)
+ grad_W = grad_out.reshape(-1, grad_out.shape[-1]).t() @ x.reshape(-1, x.shape[-1])
+ grad_B = None # B is fixed random
+ grad_bias = None
+ if ctx.has_bias:
+ grad_bias = grad_out.reshape(-1, grad_out.shape[-1]).sum(dim=0)
+ return grad_x, grad_W, grad_B, grad_bias
+
+
+class LinearSignSym(torch.autograd.Function):
+ """Linear forward, sign-symmetric backward with rescaling.
+
+ B = sign(W) · ||W||_F / sqrt(numel(W))
+ The rescale matches sign(W)'s magnitude to W's typical element magnitude,
+ avoiding the 50x gradient blowup that pure sign(W) caused.
+ """
+
+ @staticmethod
+ def forward(ctx, x, W, bias):
+ ctx.save_for_backward(x, W)
+ ctx.has_bias = bias is not None
+ out = x @ W.t()
+ if bias is not None:
+ out = out + bias
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ x, W = ctx.saved_tensors
+ # Rescaled sign: scale so that ||B||_F ≈ ||W||_F
+ scale = W.norm() / (W.numel() ** 0.5 + 1e-8)
+ sign_W_scaled = torch.sign(W) * scale
+ grad_x = grad_out @ sign_W_scaled
+ grad_W = grad_out.reshape(-1, grad_out.shape[-1]).t() @ x.reshape(-1, x.shape[-1])
+ grad_bias = None
+ if ctx.has_bias:
+ grad_bias = grad_out.reshape(-1, grad_out.shape[-1]).sum(dim=0)
+ return grad_x, grad_W, grad_bias
+
+
+class LocalLinear(nn.Module):
+ """nn.Linear drop-in with method-dispatched backward (bp/fa/dfa/sign_sym/dfa_block).
+
+ fa_init_mode (only used when method='fa'):
+ gaussian: B ~ N(0, init_std) (Lillicrap default, current)
+ orthogonal: B = Haar orthogonal × scale (#1: JL-isometric, scaled to match BP grad norm)
+ ortho_he: B = Haar orthogonal × sqrt(2/out) (#2: He-init for backward signal)
+ sparse: B with k non-zeros per row, signs ±1, scaled (#4: structured sparse)
+
+ fa_grape (only with method='fa'): if True, B is updated each step via cosine alignment
+ to the rank-1 JVP Jacobian estimate Ĵ = (W p) p^T. Implements GrAPE
+ (Caillon et al., ICLR 2026) per-layer. Forward only — no W^T transport.
+ """
+
+ def __init__(self, in_features, out_features, bias=False, method="bp", init_std=0.02,
+ fa_init_mode="gaussian", fa_sparse_k=None, fa_grape=False, fa_grape_n_probe=32):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.method = method
+ self._fa_grape = (method == "fa") and fa_grape
+ self._fa_grape_n_probe = fa_grape_n_probe
+
+ self.weight = nn.Parameter(torch.empty(out_features, in_features))
+ nn.init.normal_(self.weight, mean=0.0, std=init_std)
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_features))
+ else:
+ self.register_parameter("bias", None)
+
+ if method == "fa":
+ B = torch.empty(out_features, in_features)
+ if fa_init_mode == "gaussian":
+ nn.init.normal_(B, mean=0.0, std=init_std)
+ elif fa_init_mode == "orthogonal":
+ # Haar orthogonal (semi-orthogonal for non-square), scaled to match BP grad norm.
+ # BP backward: grad_x = grad_out @ W has norm ~ sqrt(in) * std(W) * ||grad_out||.
+ # Pure orthogonal preserves norm (|| || stays). Scale to match BP's natural shrinkage.
+ nn.init.orthogonal_(B)
+ scale = (in_features ** 0.5) * init_std
+ B.mul_(scale)
+ elif fa_init_mode == "ortho_he":
+ # He init for backward: variance = 2/out_features (matches ReLU-friendly backward)
+ nn.init.orthogonal_(B)
+ scale = (2.0 / out_features) ** 0.5
+ B.mul_(scale)
+ elif fa_init_mode == "sparse":
+ # k non-zero entries per row, signs ±1, scaled so row L2 norm matches Gaussian B
+ k = fa_sparse_k if fa_sparse_k is not None else max(1, in_features // 16)
+ B.zero_()
+ for i in range(out_features):
+ idx = torch.randperm(in_features)[:k]
+ signs = (torch.randint(0, 2, (k,)).float() * 2 - 1)
+ B[i, idx] = signs
+ # Scale to match Gaussian B's row variance: row_var(Gaussian) = in_features * init_std^2
+ # row_var(sparse) = k (after scale 1). To match: scale = init_std * sqrt(in_features/k)
+ B.mul_(init_std * (in_features / k) ** 0.5)
+ else:
+ raise ValueError(f"Unknown fa_init_mode: {fa_init_mode}")
+ # GrAPE: B is updated via JVP-cosine alignment (not via standard optimizer);
+ # store as Parameter with requires_grad=False so we can update in-place.
+ if self._fa_grape:
+ self.B = nn.Parameter(B, requires_grad=False)
+ else:
+ self.register_buffer("B", B)
+
+ if method == "dfa":
+ # B_dfa shape (out_features, target_dim); set via initialize_dfa_targets
+ self.register_buffer("B_dfa", None)
+ self._dfa_cached_input = None
+
+ if method == "dfa_block":
+ # B_dfa_block: (out_features, d_block) — projects block-output-error to layer output
+ # Set via initialize_dfa_block_targets
+ self.register_buffer("B_dfa_block", None)
+ self._dfa_block_cached_input = None
+
+ def set_dfa_target_dim(self, target_dim, init_std=0.02):
+ assert self.method == "dfa"
+ B = torch.empty(self.out_features, target_dim, device=self.weight.device)
+ nn.init.normal_(B, mean=0.0, std=init_std)
+ self.B_dfa = B
+
+ @torch.no_grad()
+ def grape_align_step(self, lr_b=0.01, normalize_columns=False):
+ """GrAPE: update B toward rank-1 JVP Jacobian estimate via cosine alignment.
+
+ For linear y = W x: J = W. JVP at random p: g = W p. Estimate Ĵ = (1/N) Σ g_i p_i^T → W.
+ Forward only (uses W in forward computation, no W^T transport).
+
+ normalize_columns: paper (Eq. 6) does column-normalize B. We default off because
+ our per-linear FA needs magnitude match (B → W in magnitude AND direction);
+ column-norm prevents matching W's row magnitudes.
+ """
+ if not self._fa_grape:
+ return
+ N = self._fa_grape_n_probe
+ device, dtype = self.weight.device, self.weight.dtype
+ # Random Gaussian perturbation p ~ N(0, I) in input space
+ p = torch.randn(N, self.in_features, device=device, dtype=dtype)
+ # Forward: g = p @ W^T (one matrix multiply)
+ g = F.linear(p, self.weight) # (N, out_features)
+ # Rank-1 Jacobian estimate: Ĵ = (1/N) g^T @ p, shape (out, in)
+ J_hat = (g.t() @ p) / N
+ # Cosine alignment gradient
+ B_norm = self.B.norm()
+ J_norm = J_hat.norm()
+ if B_norm < 1e-8 or J_norm < 1e-8:
+ return
+ cos_val = (self.B * J_hat).sum() / (B_norm * J_norm)
+ # ∂(1 - cos)/∂B = -Ĵ/(||B||·||Ĵ||) + cos · B/||B||²
+ grad = -J_hat / (B_norm * J_norm) + cos_val * self.B / (B_norm ** 2)
+ self.B.add_(grad, alpha=-lr_b)
+ if normalize_columns:
+ col_norms = self.B.norm(dim=0, keepdim=True).clamp_min(1e-8)
+ self.B.div_(col_norms)
+
+ def forward(self, x):
+ if self.method == "bp":
+ return F.linear(x, self.weight, self.bias)
+ if self.method == "fa":
+ return LinearFA.apply(x, self.weight, self.B, self.bias)
+ if self.method == "sign_sym":
+ return LinearSignSym.apply(x, self.weight, self.bias)
+ if self.method == "dfa":
+ # Cache input for later manual DFA update (will overwrite BP .grad)
+ self._dfa_cached_input = x.detach()
+ return F.linear(x, self.weight, self.bias)
+ if self.method == "dfa_block":
+ self._dfa_block_cached_input = x.detach()
+ return F.linear(x, self.weight, self.bias)
+ raise ValueError(f"Unknown method: {self.method}")
+
+ def dfa_compute_grad(self, e_L):
+ """Set self.weight.grad and self.bias.grad from global error e_L.
+
+ e_L shape (..., target_dim). delta = e_L @ B_dfa.T, shape (..., out_features).
+ ΔW = sum_n delta_n outer input_n, where inputs are cached from forward.
+ """
+ assert self.method == "dfa"
+ assert self._dfa_cached_input is not None, "DFA forward not called or cache cleared"
+ assert self.B_dfa is not None, "DFA target_dim not set (call initialize_dfa_targets)"
+
+ delta = e_L @ self.B_dfa.t() # (..., out_features)
+ delta_flat = delta.reshape(-1, self.out_features)
+ inp_flat = self._dfa_cached_input.reshape(-1, self.in_features)
+
+ grad_W = delta_flat.t() @ inp_flat # (out_features, in_features)
+
+ if self.weight.grad is None:
+ self.weight.grad = grad_W.clone()
+ else:
+ self.weight.grad.copy_(grad_W)
+
+ if self.bias is not None:
+ grad_b = delta_flat.sum(dim=0)
+ if self.bias.grad is None:
+ self.bias.grad = grad_b.clone()
+ else:
+ self.bias.grad.copy_(grad_b)
+
+ self._dfa_cached_input = None
+
+ def extra_repr(self):
+ return f"in={self.in_features}, out={self.out_features}, method={self.method}"
+
+
+def initialize_dfa_targets(model, target_dim):
+ """Must be called once after model construction and device placement, for DFA mode."""
+ for module in model.modules():
+ if isinstance(module, LocalLinear) and module.method == "dfa":
+ module.set_dfa_target_dim(target_dim)
+
+
+def apply_dfa_update(model, e_L):
+ """Iterate over all LocalLinear(dfa) modules and populate their .grad from e_L."""
+ for module in model.modules():
+ if isinstance(module, LocalLinear) and module.method == "dfa":
+ module.dfa_compute_grad(e_L)
+
+
+def initialize_dfa_block_targets(model, d_block, init_std=0.02):
+ """For dfa_block mode: each LocalLinear gets a random B_dfa_block of shape (out, d_block)."""
+ for module in model.modules():
+ if isinstance(module, LocalLinear) and module.method == "dfa_block":
+ B = torch.empty(module.out_features, d_block, device=module.weight.device)
+ nn.init.normal_(B, mean=0.0, std=init_std)
+ module.B_dfa_block = B
+
+
+def apply_dfa_block_update(block, block_output_error):
+ """Apply DFA-within-block updates to all LocalLinear(dfa_block) in `block`.
+
+ block_output_error: (B, T, d_block) — gradient at the block's output.
+ Each linear's grad: ΔW = (block_output_error @ B_dfa_block.T)^T @ cached_input
+ """
+ err = block_output_error.detach()
+ err_flat = err.reshape(-1, err.size(-1)) # (N, d_block)
+ N = err_flat.size(0)
+ for module in block.modules():
+ if isinstance(module, LocalLinear) and module.method == "dfa_block":
+ assert module.B_dfa_block is not None, "Call initialize_dfa_block_targets first"
+ assert module._dfa_block_cached_input is not None, "Forward not called"
+ # delta: (N, out_features) = err_flat @ B_dfa_block.T
+ delta = err_flat @ module.B_dfa_block.t()
+ inp_flat = module._dfa_block_cached_input.reshape(-1, module.in_features)
+ grad_W = (delta.t() @ inp_flat) / max(N, 1)
+ if module.weight.grad is None:
+ module.weight.grad = grad_W.clone()
+ else:
+ module.weight.grad.copy_(grad_W)
+ if module.bias is not None:
+ grad_b = delta.sum(dim=0) / max(N, 1)
+ if module.bias.grad is None:
+ module.bias.grad = grad_b.clone()
+ else:
+ module.bias.grad.copy_(grad_b)
+ module._dfa_block_cached_input = None