diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/local_layers.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/local_layers.py')
| -rw-r--r-- | ep_run/local_layers.py | 305 |
1 files changed, 305 insertions, 0 deletions
# File: ep_run/local_layers.py
"""Local-learning variants of Linear supporting BP / FA / DFA / sign-sym methods.

LocalLinear is a drop-in replacement for nn.Linear that selects its backward
computation based on the `method` argument:

  bp:        standard autograd (nn.Linear behavior)
  fa:        custom autograd, backward uses a fixed random matrix B in place of W.T
             (Lillicrap-style Feedback Alignment, per projection)
  sign_sym:  custom autograd, backward uses sign(W) in place of W.T (Xiao 2018)
  dfa:       forward uses normal autograd (so upstream params like embeddings /
             LayerNorm still get BP gradients). Input is cached during forward.
             After loss.backward(), call `apply_dfa_update(model, e_L)` to
             OVERWRITE LocalLinear .grad with DFA-computed update. LocalLinear
             weights thus receive direct projection updates while non-LocalLinear
             params (embeddings, LN) retain BP gradients (pragmatic hybrid).
  dfa_block: like dfa, but the projected error is the gradient at the output of
             the enclosing block; see `initialize_dfa_block_targets` and
             `apply_dfa_block_update`.

For DFA, call `initialize_dfa_targets(model, target_dim)` once after model
construction, then each training step:
  1. standard forward
  2. compute loss, loss.backward()  (fills BP .grad on everything)
  3. compute e_L = dL/dlogits analytically (for LM: softmax(logits)-onehot)
  4. call `apply_dfa_update(model, e_L)` to overwrite LocalLinear .grad
  5. optimizer.step()
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def _flatten_leading(t):
    """Collapse every leading dimension of `t` into one, keeping the last dim."""
    return t.reshape(-1, t.shape[-1])


def _assign_grad(param, grad):
    """Overwrite `param.grad` with `grad`, allocating the grad tensor if absent.

    `grad` is moved to the parameter's dtype/device first, because assigning a
    `.grad` of a different dtype raises in PyTorch.
    """
    grad = grad.to(dtype=param.dtype, device=param.device)
    if param.grad is None:
        param.grad = grad.clone()
    else:
        param.grad.copy_(grad)


def _make_fa_feedback(out_features, in_features, init_std, mode, sparse_k):
    """Build the fixed FA feedback matrix B of shape (out_features, in_features).

    Args:
        out_features, in_features: shape of the matching weight matrix.
        init_std: std of the forward weights; used to scale B in every mode
            except "ortho_he".
        mode: one of "gaussian", "orthogonal", "ortho_he", "sparse".
        sparse_k: non-zeros per row for "sparse"; None selects
            max(1, in_features // 16).

    Raises:
        ValueError: unknown `mode`, or a sparse `k` outside [1, in_features].
    """
    B = torch.empty(out_features, in_features)
    if mode == "gaussian":
        nn.init.normal_(B, mean=0.0, std=init_std)
    elif mode == "orthogonal":
        # Haar orthogonal (semi-orthogonal for non-square), scaled to match BP grad norm.
        # BP backward: grad_x = grad_out @ W has norm ~ sqrt(in) * std(W) * ||grad_out||.
        # A pure orthogonal map leaves ||grad_out|| unchanged, so scale it to
        # reproduce BP's natural shrinkage.
        nn.init.orthogonal_(B)
        B.mul_((in_features ** 0.5) * init_std)
    elif mode == "ortho_he":
        # He init for backward: variance = 2/out_features (matches ReLU-friendly backward)
        nn.init.orthogonal_(B)
        B.mul_((2.0 / out_features) ** 0.5)
    elif mode == "sparse":
        # k non-zero entries per row, signs +/-1, scaled so row L2 norm matches Gaussian B
        k = sparse_k if sparse_k is not None else max(1, in_features // 16)
        if not 1 <= k <= in_features:
            # k == 0 would divide by zero below; k > in_features cannot be
            # placed in a row of length in_features.
            raise ValueError(
                f"fa_sparse_k must be in [1, in_features={in_features}], got {k}"
            )
        B.zero_()
        # Row-by-row sampling is kept (rather than vectorized) so the RNG
        # stream, and hence B for a given seed, is unchanged.
        for row in range(out_features):
            idx = torch.randperm(in_features)[:k]
            signs = torch.randint(0, 2, (k,)).float() * 2 - 1
            B[row, idx] = signs
        # Expected squared row norm: Gaussian = in_features * init_std^2,
        # sparse (unit entries) = k. To match: scale = init_std * sqrt(in_features/k)
        B.mul_(init_std * (in_features / k) ** 0.5)
    else:
        raise ValueError(f"Unknown fa_init_mode: {mode}")
    return B


class LinearFA(torch.autograd.Function):
    """Linear forward, FA backward (replace W.T with fixed random B)."""

    @staticmethod
    def forward(ctx, x, W, B, bias):
        # Backward needs only x (for grad_W) and B (for grad_x); W is not used.
        ctx.save_for_backward(x, B)
        ctx.has_bias = bias is not None
        out = x @ W.t()
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, B = ctx.saved_tensors
        need_x, need_W = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        grad_x = grad_W = grad_bias = None
        if need_x:
            # True BP would use W here (shape out x in). FA replaces with random B of same shape.
            grad_x = grad_out @ B
        grad_out_flat = _flatten_leading(grad_out)
        if need_W:
            # grad_W is standard outer product of grad_out and x (summed over leading dims)
            grad_W = grad_out_flat.t() @ _flatten_leading(x)
        if ctx.has_bias and ctx.needs_input_grad[3]:
            grad_bias = grad_out_flat.sum(dim=0)
        # B is fixed random, so it never receives a gradient.
        return grad_x, grad_W, None, grad_bias


class LinearSignSym(torch.autograd.Function):
    """Linear forward, sign-symmetric backward with rescaling.

    B = sign(W) · ||W||_F / sqrt(numel(W))
    The rescale matches sign(W)'s magnitude to W's typical element magnitude,
    avoiding the 50x gradient blowup that pure sign(W) caused.
    """

    @staticmethod
    def forward(ctx, x, W, bias):
        ctx.save_for_backward(x, W)
        ctx.has_bias = bias is not None
        out = x @ W.t()
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, W = ctx.saved_tensors
        need_x, need_W = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        grad_x = grad_W = grad_bias = None
        if need_x:
            # Rescaled sign: scale so that ||B||_F ≈ ||W||_F
            scale = W.norm() / (W.numel() ** 0.5 + 1e-8)
            grad_x = grad_out @ (torch.sign(W) * scale)
        grad_out_flat = _flatten_leading(grad_out)
        if need_W:
            grad_W = grad_out_flat.t() @ _flatten_leading(x)
        if ctx.has_bias and ctx.needs_input_grad[2]:
            grad_bias = grad_out_flat.sum(dim=0)
        return grad_x, grad_W, grad_bias


class LocalLinear(nn.Module):
    """nn.Linear drop-in with method-dispatched backward (bp/fa/dfa/sign_sym/dfa_block).

    fa_init_mode (only used when method='fa'):
      gaussian:   B ~ N(0, init_std)  (Lillicrap default, current)
      orthogonal: B = Haar orthogonal × scale  (#1: JL-isometric, scaled to match BP grad norm)
      ortho_he:   B = Haar orthogonal × sqrt(2/out)  (#2: He-init for backward signal)
      sparse:     B with k non-zeros per row, signs ±1, scaled  (#4: structured sparse)

    fa_grape (only with method='fa'): if True, B is updated each step via cosine alignment
      to the rank-1 JVP Jacobian estimate Ĵ = (W p) p^T. Implements GrAPE
      (Caillon et al., ICLR 2026) per-layer. Forward only — no W^T transport.
    """

    def __init__(self, in_features, out_features, bias=False, method="bp", init_std=0.02,
                 fa_init_mode="gaussian", fa_sparse_k=None, fa_grape=False, fa_grape_n_probe=32):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.method = method
        # GrAPE is only meaningful for FA; the flag is forced off for other methods.
        self._fa_grape = (method == "fa") and fa_grape
        self._fa_grape_n_probe = fa_grape_n_probe

        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.normal_(self.weight, mean=0.0, std=init_std)

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        if method == "fa":
            B = _make_fa_feedback(out_features, in_features, init_std, fa_init_mode, fa_sparse_k)
            # GrAPE: B is updated via JVP-cosine alignment (not via standard optimizer);
            # store as Parameter with requires_grad=False so we can update in-place.
            if self._fa_grape:
                self.B = nn.Parameter(B, requires_grad=False)
            else:
                self.register_buffer("B", B)

        if method == "dfa":
            # B_dfa shape (out_features, target_dim); set via initialize_dfa_targets
            self.register_buffer("B_dfa", None)
            self._dfa_cached_input = None

        if method == "dfa_block":
            # B_dfa_block: (out_features, d_block) — projects block-output-error to layer output
            # Set via initialize_dfa_block_targets
            self.register_buffer("B_dfa_block", None)
            self._dfa_block_cached_input = None

    def set_dfa_target_dim(self, target_dim, init_std=0.02):
        """Create the random DFA feedback matrix B_dfa of shape (out_features, target_dim).

        B_dfa is placed on the weight's device and cast to the weight's dtype, so
        it stays usable when the model was converted (e.g. `.double()`) before
        this call.
        """
        assert self.method == "dfa"
        B = torch.empty(self.out_features, target_dim, device=self.weight.device)
        nn.init.normal_(B, mean=0.0, std=init_std)
        # Sample in the default dtype (same RNG stream as before), then cast.
        self.B_dfa = B.to(dtype=self.weight.dtype)

    @torch.no_grad()
    def grape_align_step(self, lr_b=0.01, normalize_columns=False):
        """GrAPE: update B toward rank-1 JVP Jacobian estimate via cosine alignment.

        For linear y = W x: J = W. JVP at random p: g = W p. Estimate Ĵ = (1/N) Σ g_i p_i^T → W.
        Forward only (uses W in forward computation, no W^T transport).

        No-op unless the layer was built with method='fa' and fa_grape=True, or
        when either B or the estimate has (near-)zero norm.

        normalize_columns: paper (Eq. 6) does column-normalize B. We default off because
        our per-linear FA needs magnitude match (B → W in magnitude AND direction);
        column-norm prevents matching W's row magnitudes.
        """
        if not self._fa_grape:
            return
        N = self._fa_grape_n_probe
        device, dtype = self.weight.device, self.weight.dtype
        # Random Gaussian perturbation p ~ N(0, I) in input space
        p = torch.randn(N, self.in_features, device=device, dtype=dtype)
        # Forward: g = p @ W^T (one matrix multiply)
        g = F.linear(p, self.weight)  # (N, out_features)
        # Rank-1 Jacobian estimate: Ĵ = (1/N) g^T @ p, shape (out, in)
        J_hat = (g.t() @ p) / N
        # Cosine alignment gradient
        B_norm = self.B.norm()
        J_norm = J_hat.norm()
        if B_norm < 1e-8 or J_norm < 1e-8:
            return
        cos_val = (self.B * J_hat).sum() / (B_norm * J_norm)
        # ∂(1 - cos)/∂B = -Ĵ/(||B||·||Ĵ||) + cos · B/||B||²
        grad = -J_hat / (B_norm * J_norm) + cos_val * self.B / (B_norm ** 2)
        self.B.add_(grad, alpha=-lr_b)
        if normalize_columns:
            col_norms = self.B.norm(dim=0, keepdim=True).clamp_min(1e-8)
            self.B.div_(col_norms)

    def forward(self, x):
        """Apply the linear map; the backward rule depends on `self.method`.

        Raises:
            ValueError: if `self.method` is not one of the supported methods.
        """
        if self.method == "bp":
            return F.linear(x, self.weight, self.bias)
        if self.method == "fa":
            return LinearFA.apply(x, self.weight, self.B, self.bias)
        if self.method == "sign_sym":
            return LinearSignSym.apply(x, self.weight, self.bias)
        if self.method == "dfa":
            # Cache input for later manual DFA update (will overwrite BP .grad)
            self._dfa_cached_input = x.detach()
            return F.linear(x, self.weight, self.bias)
        if self.method == "dfa_block":
            self._dfa_block_cached_input = x.detach()
            return F.linear(x, self.weight, self.bias)
        raise ValueError(f"Unknown method: {self.method}")

    def dfa_compute_grad(self, e_L):
        """Set self.weight.grad and self.bias.grad from global error e_L.

        e_L shape (..., target_dim). delta = e_L @ B_dfa.T, shape (..., out_features).
        ΔW = sum_n delta_n outer input_n, where inputs are cached from forward.

        e_L is detached and cast to B_dfa's dtype, so no autograd graph is
        attached to the stored grads. The cached input is released afterwards,
        so a new forward pass is required before the next call.
        """
        assert self.method == "dfa"
        assert self._dfa_cached_input is not None, "DFA forward not called or cache cleared"
        assert self.B_dfa is not None, "DFA target_dim not set (call initialize_dfa_targets)"

        err = e_L.detach().to(dtype=self.B_dfa.dtype)
        delta = err @ self.B_dfa.t()  # (..., out_features)
        delta_flat = delta.reshape(-1, self.out_features)
        inp_flat = self._dfa_cached_input.reshape(-1, self.in_features).to(delta_flat.dtype)

        _assign_grad(self.weight, delta_flat.t() @ inp_flat)  # (out_features, in_features)
        if self.bias is not None:
            _assign_grad(self.bias, delta_flat.sum(dim=0))

        self._dfa_cached_input = None

    def extra_repr(self):
        return f"in={self.in_features}, out={self.out_features}, method={self.method}"


def initialize_dfa_targets(model, target_dim):
    """Must be called once after model construction and device placement, for DFA mode."""
    for module in model.modules():
        if isinstance(module, LocalLinear) and module.method == "dfa":
            module.set_dfa_target_dim(target_dim)


def apply_dfa_update(model, e_L):
    """Iterate over all LocalLinear(dfa) modules and populate their .grad from e_L."""
    for module in model.modules():
        if isinstance(module, LocalLinear) and module.method == "dfa":
            module.dfa_compute_grad(e_L)


def initialize_dfa_block_targets(model, d_block, init_std=0.02):
    """For dfa_block mode: each LocalLinear gets a random B_dfa_block of shape (out, d_block).

    The matrix is placed on the weight's device and cast to the weight's dtype.
    """
    for module in model.modules():
        if isinstance(module, LocalLinear) and module.method == "dfa_block":
            B = torch.empty(module.out_features, d_block, device=module.weight.device)
            nn.init.normal_(B, mean=0.0, std=init_std)
            # Sample in the default dtype (same RNG stream as before), then cast.
            module.B_dfa_block = B.to(dtype=module.weight.dtype)


def apply_dfa_block_update(block, block_output_error):
    """Apply DFA-within-block updates to all LocalLinear(dfa_block) in `block`.

    block_output_error: (B, T, d_block) — gradient at the block's output.
    Each linear's grad: ΔW = (block_output_error @ B_dfa_block.T)^T @ cached_input,
    divided by the number of flattened error rows N. Each module's cached input
    is released after its grads are written.
    """
    err = block_output_error.detach()
    err_flat = err.reshape(-1, err.size(-1))  # (N, d_block)
    denom = max(err_flat.size(0), 1)  # guard against an empty batch
    for module in block.modules():
        if isinstance(module, LocalLinear) and module.method == "dfa_block":
            assert module.B_dfa_block is not None, "Call initialize_dfa_block_targets first"
            assert module._dfa_block_cached_input is not None, "Forward not called"
            feedback = module.B_dfa_block
            # delta: (N, out_features) = err_flat @ B_dfa_block.T
            delta = err_flat.to(dtype=feedback.dtype) @ feedback.t()
            inp_flat = module._dfa_block_cached_input.reshape(-1, module.in_features)
            inp_flat = inp_flat.to(delta.dtype)
            _assign_grad(module.weight, (delta.t() @ inp_flat) / denom)
            if module.bias is not None:
                _assign_grad(module.bias, delta.sum(dim=0) / denom)
            module._dfa_block_cached_input = None
