1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
|
"""#2 — leading-abscissa control for the ept non-conservative operator (ports the aep-dynamics
'control the LEADING spectral signal, surgically' finding to C512).
Why not jacreg: jacreg penalizes ||J_nc||_F^2 (Hutchinson estimate) — the WHOLE Jacobian norm. That is
blunt: it over-constrains directions that never threaten stability, and when the controller ramps it
high it HIJACKS the task gradient (the known jr-hijack failure). The aep leading-vs-lagging result says
the right knob is the leading SPECTRAL ABSCISSA, not the norm.
What we control: the numerical abscissa omega(J_nc) = lambda_max( (J_nc + J_nc^T)/2 ), i.e. the 2-norm
logarithmic norm mu_2(J_nc), i.e. the one-sided Lipschitz constant of the linearization. It upper-bounds
the spectral abscissa and bounds transient growth (||e^{J_nc t}||_2 <= e^{omega t} for t >= 0). Since
J_F = J_nc - (1+c)I, omega(J_nc) < 1+c keeps the free phase contractive, so omega crossing past (1+c) is
the necessary precondition for (the early warning of) the free-phase Hopf; the instability itself
requires the spectral abscissa (<= omega) to reach 1+c.
The estimate is a Krylov iteration (power-iteration / Lanczos family) on the SYMMETRIC PART, so it is
matvec-only (jvp + vjp of nc_force, the same primitives jacreg and the AEP correction already call): it
scales and is analog-compatible (no dense Jacobian, no eigendecomposition of J_nc).
One-sided ReLU penalty = a LEADING signal: it acts only once the abscissa exceeds the margin, so unlike
jacreg it does not over-contract / hijack when the operator is already safe.
"""
import math

import torch
from torch.autograd.functional import jvp, vjp
def num_abscissa(blk, zs, cache, iters=3):
    """Estimate the numerical abscissa omega(J_nc) = lambda_max(Sym(J_nc)) at zs, matrix-free.

    Sym(J_nc) = (J_nc + J_nc^T)/2 is applied through ``jvp`` (J_nc v) and ``vjp`` (J_nc^T v) of
    ``blk.nc_force``, so no Jacobian is ever formed.

    Why a Lanczos (Krylov) estimate rather than plain power iteration: Sym(J_nc) is symmetric but in
    general INDEFINITE, and power iteration converges to the eigenvalue of largest MAGNITUDE. When a
    strongly contracting direction dominates (lambda_min << 0 < lambda_max), power iteration reports
    that negative value and hides the positive lambda_max the penalty exists to catch. The Krylov
    space spanned by the same matvecs contains every power iterate, so (in exact arithmetic) its
    largest Ritz value is at least the power-iteration Rayleigh quotient, never exceeds lambda_max,
    and always targets the ALGEBRAICALLY largest eigenvalue.

    Args:
        blk: object with a callable ``nc_force(z)``; its output is assumed to have z's shape
            (required by the vjp below) — TODO confirm against callers.
        zs: state at which J_nc is linearized; detached here.
        cache: dict carrying the warm-start vector under 'v' across calls (read, then overwritten).
        iters: number of Krylov expansion steps; uses iters + 1 Sym(J_nc) matvecs plus one jvp for
            the final Rayleigh quotient. Negative values are treated as 0.

    Returns:
        (v, lam): v is the unit-norm Ritz vector (computed without grad, same shape/dtype as zs);
        lam = v^T J_nc v / v^T v as a Python float. This equals v^T Sym(J_nc) v because the
        antisymmetric part of J_nc contributes nothing to a quadratic form.
    """
    z = zs.detach()
    force = blk.nc_force

    def sym_matvec(u):
        # Sym(J_nc) u = (J_nc u + J_nc^T u) / 2
        return 0.5 * (jvp(force, z, u)[1] + vjp(force, z, u)[1])

    def inner(a, b):
        return (a * b).sum()

    start = cache.get('v')
    if start is None or start.shape != z.shape or start.dtype != z.dtype or start.device != z.device:
        start = torch.randn_like(z)  # cold start, or the cached vector no longer matches zs

    steps = max(iters, 0)
    # Breakdown tolerance: a residual this small relative to ||Sym q|| is rounding noise, i.e. the
    # Krylov space is (numerically) invariant and further steps add no information.
    tol = torch.finfo(z.dtype).eps ** 0.5
    with torch.no_grad():
        basis = [start / (start.norm() + 1e-12)]
        alphas, betas = [], []  # diagonal / off-diagonal of the Lanczos tridiagonal T
        for step in range(steps + 1):
            w = sym_matvec(basis[-1])
            alphas.append(float(inner(basis[-1], w)))
            if step == steps:
                break
            w_norm = float(w.norm())
            # Full re-orthogonalization: the basis is tiny, and this keeps T faithful in low precision.
            for q in basis:
                w = w - inner(q, w) * q
            beta = float(w.norm())
            if beta <= tol * w_norm:
                break
            betas.append(beta)
            basis.append(w / beta)

        # Largest eigenpair of the small tridiagonal T (float64 on CPU: cheap and well-conditioned).
        T = torch.diag(torch.tensor(alphas, dtype=torch.float64))
        if betas:
            off = torch.tensor(betas, dtype=torch.float64)
            T = T + torch.diag(off, 1) + torch.diag(off, -1)
        _, ritz_vecs = torch.linalg.eigh(T)
        coeffs = ritz_vecs[:, -1].tolist()  # eigh sorts ascending -> last column = largest Ritz value

        v = torch.zeros_like(z)
        for c, q in zip(coeffs, basis):
            v = v + c * q
        v = v / (v.norm() + 1e-12)
        lam = float(inner(v, jvp(force, z, v)[1]) / inner(v, v))  # v^T J_nc v
    cache['v'] = v
    return v, lam
def eig_penalty(blk, zs, eigreg, margin, cache, iters=3):
    """Grads of the one-sided leading-abscissa penalty R = eigreg * relu(omega(J_nc) - margin)^2.

    omega is first estimated without grad by ``num_abscissa``, which also returns the leading vector v.
    Only when omega exceeds ``margin`` is lambda = v^T J_nc v recomputed differentiably w.r.t. the
    parameters reached through ``blk.nc_force``, with v held fixed. When v is an exact eigenvector of
    Sym(J_nc), this gives the first-order eigenvalue derivative, so no derivative through v is needed.

    Args:
        blk: object exposing ``nc_force`` and ``block``, an iterable of parameter tensors to
            differentiate (presumably the parameters nc_force uses — TODO confirm).
        zs: state at which J_nc is linearized; detached, so no gradient flows into it.
        eigreg: penalty weight.
        margin: abscissa threshold at or below which the penalty is off.
        cache: warm-start dict forwarded to ``num_abscissa``.
        iters: iteration budget forwarded to ``num_abscissa``.

    Returns:
        ({id(p): grad}, omega): grads keyed by parameter id, omitting parameters R does not reach.
        omega is the no-grad float estimate, logged as the controller signal. The dict is empty when
        omega <= margin or omega is not finite.
    """
    v, lam0 = num_abscissa(blk, zs, cache, iters)
    # Below the margin the leading signal is off. A NaN/inf estimate is also skipped so a broken
    # linearization cannot inject non-finite gradients; it is still returned for logging.
    if not math.isfinite(lam0) or lam0 <= margin:
        return {}, lam0
    params = list(blk.block)  # materialize once: iterated twice below (grad, then zip)
    z = zs.detach()
    with torch.enable_grad():  # caller may be under no_grad; R must carry a graph to the params
        Jv = jvp(blk.nc_force, z, v, create_graph=True)[1]  # differentiable in nc_force's params
        lam = (v * Jv).sum() / (v * v).sum()  # numerical abscissa, v fixed
        R = eigreg * torch.relu(lam - margin) ** 2
    grads = torch.autograd.grad(R, params, allow_unused=True)
    return {id(p): g for p, g in zip(params, grads) if g is not None}, lam0
|