summaryrefslogtreecommitdiff
path: root/ep_run/eig_control.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/eig_control.py')
-rw-r--r--ep_run/eig_control.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/ep_run/eig_control.py b/ep_run/eig_control.py
new file mode 100644
index 0000000..6e3f598
--- /dev/null
+++ b/ep_run/eig_control.py
@@ -0,0 +1,50 @@
+"""#2 — leading-abscissa control for the ept non-conservative operator (ports the aep-dynamics
+'control the LEADING spectral signal, surgically' finding to C512).
+
+Why not jacreg: jacreg penalizes ||J_nc||_F^2 (Hutchinson) — the WHOLE Jacobian norm. That is blunt:
+it over-constrains directions that never threaten stability, and when the controller ramps it high it
+HIJACKS the task gradient (the known jr-hijack failure). The aep leading-vs-lagging result says the
+right knob is the leading SPECTRAL ABSCISSA, not the norm.
+
+What we control: the numerical abscissa omega(J_nc) = lambda_max( (J_nc + J_nc^T)/2 ) = the 2-norm
+log-norm mu_2(J_nc) = one-sided Lipschitz constant. It upper-bounds the spectral abscissa, governs the
+transient growth ||e^{Jt}||, and its crossing past (1+c) IS the free-phase Hopf (J_F = J_nc - (1+c)I).
+Power iteration on the SYMMETRIC PART -> matvec-only (jvp+vjp of nc_force, the same primitives jacreg
+and the AEP correction already call), so it scales and is analog-compatible (no eigendecomposition).
+One-sided ReLU penalty = a LEADING signal: acts only as the abscissa nears the margin, so unlike jacreg
+it does not over-contract / hijack when the operator is already safe.
+"""
+import torch
+from torch.autograd.functional import jvp, vjp
+
+
def num_abscissa(blk, zs, cache, iters=3):
    """Estimate the numerical abscissa lambda_max(Sym(J_nc)), Sym(J_nc) = (J_nc + J_nc^T)/2, at zs.

    Matvec-only: every product with Sym(J_nc) is one jvp plus one vjp of ``blk.nc_force`` at the detached
    state; the Rayleigh quotient uses v^T Sym(J_nc) v == v^T J_nc v, i.e. a single jvp.

    Pass 1 is plain power iteration warm-started from ``cache['v']``. Plain power iteration locks onto the
    eigenvalue of largest MAGNITUDE, so when its Rayleigh quotient comes out negative the dominant eigenvalue
    is lambda_min and says nothing about lambda_max. In that case pass 2 iterates on Sym(J_nc) + s*I with
    s = ||Sym(J_nc) v|| taken from pass 1 (a lower bound on the spectral radius), which moves the spectrum
    to (approximately) non-negative values so the leading ALGEBRAIC eigenvalue dominates. Both Rayleigh
    quotients are lower bounds on lambda_max, so the larger one is returned.

    Args:
        blk: object exposing ``nc_force``, called with the state tensor as its only argument.
        zs: state tensor at which the Jacobian is taken (detached here).
        cache: dict carrying warm starts between calls. ``'v'`` holds the pass-1 (plain) vector so it keeps
            tracking the dominant direction; ``'v_top'`` holds the pass-2 (shifted) vector when that pass ran.
        iters: power-iteration steps per pass.

    Returns:
        (v, lam): v is the detached direction whose Rayleigh quotient is lam (a Python float).
    """
    z = zs.detach()
    tiny = 1e-12

    def sym_matvec(u):
        # Sym(J_nc) u = (J_nc u + J_nc^T u) / 2.
        return 0.5 * (jvp(blk.nc_force, z, u)[1] + vjp(blk.nc_force, z, u)[1])

    def rayleigh(u):
        # u^T J_nc u / u^T u; the antisymmetric part of J_nc contributes nothing to the quadratic form.
        return float((u * jvp(blk.nc_force, z, u)[1]).sum() / (u * u).sum())

    def warm_start(key):
        # Reuse the cached direction only if it still matches z and is a usable (finite, non-zero) vector;
        # a zero/NaN vector would otherwise reproduce itself forever and turn every estimate into NaN.
        u = cache.get(key)
        usable = (
            u is not None
            and u.shape == z.shape
            and u.dtype == z.dtype
            and u.device == z.device
            and bool(torch.isfinite(u).all())
            and float(u.norm()) > tiny
        )
        if not usable:
            u = torch.randn_like(z)
            u = u / (u.norm() + tiny)
        return u

    def power_iterate(u, shift):
        # `iters` steps of u <- normalize((Sym + shift*I) u). Returns (u, last finite ||Sym u||).
        sym_norm = z.new_zeros(())
        for _ in range(iters):
            su = sym_matvec(u)
            su_norm = su.norm()
            w = su if shift is None else su + shift * u
            w_norm = w.norm()
            # A vanishing or non-finite image cannot be normalised: keep the previous direction instead of
            # collapsing to the zero vector (tensor-side select, so no host sync inside the loop).
            ok = torch.isfinite(w_norm) & (w_norm > tiny)
            u = torch.where(ok, w / (w_norm + tiny), u)
            sym_norm = torch.where(torch.isfinite(su_norm), su_norm, sym_norm)
        return u, sym_norm

    with torch.no_grad():
        v, sym_norm = power_iterate(warm_start('v'), None)
        lam = rayleigh(v)
        cache['v'] = v
        if iters > 0 and lam < 0:
            # Dominant eigenvalue is negative: re-run with the spectrum shifted up to expose lambda_max.
            u, _ = power_iterate(warm_start('v_top'), sym_norm)
            lam_top = rayleigh(u)
            cache['v_top'] = u
            if lam_top > lam:
                v, lam = u, lam_top
    return v, lam
+
+
def eig_penalty(blk, zs, eigreg, margin, cache, iters=3):
    """Grads of the one-sided leading-abscissa penalty R = eigreg * relu(omega(J_nc) - margin)^2.

    The direction v and the logged value omega come from ``num_abscissa``; R is then rebuilt with v held
    fixed so that it is differentiable only through ``blk.nc_force``.

    Args:
        blk: object exposing ``nc_force`` and ``block``; ``block`` is iterated as the tensors to
            differentiate with respect to (presumably the nc_force parameters -- confirm against caller).
        zs: state tensor at which the Jacobian is taken (detached here).
        eigreg: penalty weight.
        margin: threshold on omega above which the penalty acts.
        cache: warm-start dict forwarded to ``num_abscissa``.
        iters: power-iteration steps forwarded to ``num_abscissa``.

    Returns:
        ({id(p): grad}, omega): omega is the float estimate from ``num_abscissa``, logged as the controller
        signal. The dict is empty when omega is not a finite value strictly above ``margin``; entries whose
        grad is None (tensor unused by R) are omitted.
    """
    v, lam0 = num_abscissa(blk, zs, cache, iters)
    # Act only on a finite estimate strictly above the margin. Written as a chained comparison so that a
    # NaN or +inf estimate switches the penalty off instead of back-propagating NaN/inf gradients
    # (`lam0 <= margin` is False for NaN and would fall through).
    if not (margin < lam0 < float('inf')):
        return {}, lam0
    z = zs.detach()
    # Materialise once: the collection is consumed twice below (autograd.grad and zip).
    params = list(blk.block)
    with torch.enable_grad():
        Jv = jvp(blk.nc_force, z, v, create_graph=True)[1]  # keeps the graph through nc_force
        lam = (v * Jv).sum() / (v * v).sum()  # Rayleigh quotient with v treated as a constant
        R = eigreg * torch.relu(lam - margin) ** 2
        grads = torch.autograd.grad(R, params, allow_unused=True)
    return {id(p): g for p, g in zip(params, grads) if g is not None}, lam0