diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/eig_control.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/eig_control.py')
| -rw-r--r-- | ep_run/eig_control.py | 50 |
1 files changed, 50 insertions, 0 deletions
# diff --git a/ep_run/eig_control.py b/ep_run/eig_control.py new file mode 100644 index 0000000..6e3f598 --- /dev/null +++ b/ep_run/eig_control.py @@ -0,0 +1,50 @@
"""#2 — leading-abscissa control for the ept non-conservative operator (ports the aep-dynamics
'control the LEADING spectral signal, surgically' finding to C512).

Why not jacreg: jacreg penalizes ||J_nc||_F^2 (Hutchinson) — the WHOLE Jacobian norm. That is blunt:
it over-constrains directions that never threaten stability, and when the controller ramps it high it
HIJACKS the task gradient (the known jr-hijack failure). The aep leading-vs-lagging result says the
right knob is the leading SPECTRAL ABSCISSA, not the norm.

What we control: the numerical abscissa omega(J_nc) = lambda_max( (J_nc + J_nc^T)/2 ) = the 2-norm
log-norm mu_2(J_nc) = one-sided Lipschitz constant. It upper-bounds the spectral abscissa, governs the
transient growth ||e^{Jt}||, and its crossing past (1+c) IS the free-phase Hopf (J_F = J_nc - (1+c)I).
Power iteration on the SYMMETRIC PART -> matvec-only (jvp+vjp of nc_force, the same primitives jacreg
and the AEP correction already call), so it scales and is analog-compatible (no eigendecomposition).
One-sided ReLU penalty = a LEADING signal: acts only as the abscissa nears the margin, so unlike jacreg
it does not over-contract / hijack when the operator is already safe.

Caveat: un-shifted power iteration converges to the eigenvalue of LARGEST MAGNITUDE of the symmetric
part, which equals lambda_max only when the most positive eigenvalue dominates. Pass `shift > 0` to
iterate on Sym(J_nc) + shift*I instead (same eigenvectors, spectrum moved right) when the symmetric part
may have a larger-magnitude negative eigenvalue. The default shift=0.0 is the plain iteration.
"""
import torch
from torch.autograd.functional import jvp, vjp


def _usable_start(v, z):
    """Return True when cached vector `v` can warm-start power iteration at `z`.

    Besides matching shape/dtype/device, `v` must have a finite, non-zero norm: a NaN/inf/zero vector
    would otherwise be re-used forever, because normalising it can never repair it.
    """
    if v is None or v.shape != z.shape or v.dtype != z.dtype or v.device != z.device:
        return False
    nrm = v.norm()  # NaN if any entry is NaN, inf if any entry is inf
    return bool(torch.isfinite(nrm)) and bool(nrm > 0)


def num_abscissa(blk, zs, cache, iters=3, shift=0.0):
    """Power-iterate Sym(J_nc)=(J_nc+J_nc^T)/2 at zs for the leading eigenpair.

    Args:
        blk: object exposing `nc_force`, a callable mapping a tensor shaped like `zs` to a tensor of the
            same shape (required because the jvp and vjp results are added together).
        zs: evaluation point; it is detached, so no gradient flows back into it.
        cache: dict used as warm-start storage; key 'v' is read on entry and overwritten on exit.
        iters: number of power-iteration steps.
        shift: optional scalar added as shift*I to the symmetric part during the iteration only (see the
            module docstring). It does not enter the returned Rayleigh quotient.

    Returns:
        (v_detached, lambda_float) where lambda = v^T Sym(J_nc) v = v^T J_nc v is the Rayleigh quotient
        at the final iterate, i.e. the numerical-abscissa estimate. lambda is NaN if nc_force produces
        non-finite values; the returned/cached v stays finite in that case.
    """
    z = zs.detach()
    v = cache.get('v')
    if not _usable_start(v, z):
        v = torch.randn_like(z)  # cold start (or a poisoned cache entry): fresh random direction
    v = v / (v.norm() + 1e-12)
    with torch.no_grad():
        for _ in range(iters):
            Sv = 0.5 * (jvp(blk.nc_force, z, v)[1] + vjp(blk.nc_force, z, v)[1])  # Sym(J_nc) v
            if shift:
                Sv = Sv + shift * v  # (Sym(J_nc) + shift*I) v
            nrm = Sv.norm()
            if not (bool(torch.isfinite(nrm)) and bool(nrm > 0)):
                # Zero or non-finite matvec: normalising would give a zero/NaN iterate and poison the
                # warm-start cache, so keep the last valid direction and stop iterating.
                break
            v = Sv / (nrm + 1e-12)
        lam = float((v * jvp(blk.nc_force, z, v)[1]).sum() / (v * v).sum())  # v^T J_nc v
    cache['v'] = v
    return v, lam


def eig_penalty(blk, zs, eigreg, margin, cache, iters=3, shift=0.0):
    """Grads of the one-sided leading-abscissa penalty R = eigreg * relu(omega(J_nc) - margin)^2.

    Args:
        blk: object exposing `nc_force` (see `num_abscissa`) and `block`, the tensors handed to
            `torch.autograd.grad` — presumably the parameters nc_force depends on; verify against caller.
        zs: evaluation point (detached here).
        eigreg: penalty weight.
        margin: stability margin; the penalty is active only when omega > margin.
        cache: warm-start dict forwarded to `num_abscissa`.
        iters: power-iteration steps forwarded to `num_abscissa`.
        shift: spectral shift forwarded to `num_abscissa`.

    Returns:
        ({id(p): grad}, omega) — omega logged as the controller signal. The grads dict is empty when
        omega is at or below the margin, and also when omega is NaN, so a diverged estimate never injects
        NaN gradients (the NaN omega is still returned for the caller to see).
    """
    v, lam0 = num_abscissa(blk, zs, cache, iters, shift)
    if not lam0 > margin:  # below the margin, or NaN: leading signal off
        return {}, lam0
    z = zs.detach()
    with torch.enable_grad():
        Jv = jvp(blk.nc_force, z, v, create_graph=True)[1]  # differentiable in theta (nc_force params)
        lam = (v * Jv).sum() / (v * v).sum()  # numerical abscissa, v fixed
        R = eigreg * torch.relu(lam - margin) ** 2
        gr = torch.autograd.grad(R, blk.block, allow_unused=True)
    return {id(p): g for p, g in zip(blk.block, gr) if g is not None}, lam0
