summaryrefslogtreecommitdiff
path: root/ep_run/eig_control.py
blob: 6e3f5987ec3daf4163b8c91a2c8b99745fe63722 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""#2 — leading-abscissa control for the ept non-conservative operator (ports the aep-dynamics
'control the LEADING spectral signal, surgically' finding to C512).

Why not jacreg: jacreg penalizes ||J_nc||_F^2 (Hutchinson) — the WHOLE Jacobian norm. That is blunt:
it over-constrains directions that never threaten stability, and when the controller ramps it high it
HIJACKS the task gradient (the known jr-hijack failure). The aep leading-vs-lagging result says the
right knob is the leading SPECTRAL ABSCISSA, not the norm.

What we control: the numerical abscissa  omega(J_nc) = lambda_max( (J_nc + J_nc^T)/2 )  = the 2-norm
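log-norm mu_2(J_nc) = one-sided Lipschitz constant. It upper-bounds the spectral abscissa and governs the
transient growth ||e^{Jt}|| <= e^{omega t}. Since J_F = J_nc - (1+c)I, omega(J_F) = omega(J_nc) - (1+c): its
crossing past (1+c) is where the free phase stops being contractive — a conservative early warning that can
only precede (never lag) the free-phase Hopf, which needs the spectral abscissa itself to cross (1+c).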
Power iteration on the SYMMETRIC PART -> matvec-only (jvp+vjp of nc_force, the same primitives jacreg
and the AEP correction already call), so it scales and is analog-compatible (no eigendecomposition).
One-sided ReLU penalty = a LEADING signal: acts only as the abscissa nears the margin, so unlike jacreg
it does not over-contract / hijack when the operator is already safe.
"""
import torch
from torch.autograd.functional import jvp, vjp


def num_abscissa(blk, zs, cache, iters=3):
    """Estimate the numerical abscissa omega(J_nc) = lambda_max(Sym(J_nc)) at zs, matvec-only.

    Sym(J_nc) = (J_nc + J_nc^T)/2 is applied through one jvp + one vjp of ``blk.nc_force`` (no
    eigendecomposition). Plain power iteration converges to the eigenvalue of largest MAGNITUDE, which
    is the most NEGATIVE one whenever a strongly contracting direction dominates; reporting that as omega
    would switch the penalty off while a smaller positive eigenvalue is crossing the margin. Hence two
    phases, each warm-started from a vector kept in ``cache``:

      1. power-iterate Sym(J_nc)                        -> dominant eigenpair       (cache['v'])
      2. only if its Rayleigh quotient is negative: power-iterate Sym(J_nc) + s*I with
         s = ||Sym(J_nc) v|| (a spectral-radius estimate); the shift moves the spectrum right so the
         iteration targets lambda_max                                             (cache['v_max'])

    lambda = v^T Sym(J_nc) v = v^T J_nc v (Rayleigh quotient). The larger of the phases' quotients is
    returned; a Rayleigh quotient never exceeds lambda_max, so lambda is a lower bound on omega that
    tightens as ``iters`` grows and the cached vectors warm up across calls.

    Args:
        blk: object with ``nc_force(z)`` returning a tensor shaped like ``z`` (required by the vjp).
        zs: point at which J_nc is evaluated (detached here; no graph is kept).
        cache: mutable dict owned by the caller, reused across calls for warm starts.
        iters: power-iteration steps per phase.

    Returns:
        (v, lambda): the unit vector (no grad) whose Rayleigh quotient is lambda, and lambda as a float.
    """
    z = zs.detach()
    f = blk.nc_force

    def sym_mv(u):                                   # Sym(J_nc) u  via  (J u + J^T u) / 2
        return 0.5 * (jvp(f, z, u)[1] + vjp(f, z, u)[1])

    def warm_start(key):                             # cached vector if still compatible, else random
        u = cache.get(key)
        if u is None or u.shape != z.shape or u.dtype != z.dtype or u.device != z.device:
            u = torch.randn_like(z)
        return u / (u.norm() + 1e-12)

    def power(u, shift):                             # power iteration on Sym(J_nc) + shift * I
        for _ in range(iters):
            Au = sym_mv(u) + shift * u
            u = Au / (Au.norm() + 1e-12)
        return u

    with torch.no_grad():
        v = power(warm_start('v'), 0.0)
        cache['v'] = v
        Sv = sym_mv(v)
        lam = float((v * Sv).sum() / (v * v).sum())
        if lam < 0.0:
            # Dominant eigenvalue is negative: it may hide a smaller positive lambda_max.
            # ||Sym v|| (v unit) is >= |lam| and <= the spectral radius -> usable shift.
            shift = float(Sv.norm() / (v.norm() + 1e-12))
            w = power(warm_start('v_max'), shift)
            cache['v_max'] = w
            lam_w = float((w * sym_mv(w)).sum() / (w * w).sum())
            if lam_w > lam:
                v, lam = w, lam_w
    return v, lam


def eig_penalty(blk, zs, eigreg, margin, cache, iters=3):
    """Grads of the one-sided leading-abscissa penalty  R = eigreg * relu(omega(J_nc) - margin)^2.

    omega is the Rayleigh quotient v^T J_nc v at the (detached, fixed) vector v returned by
    ``num_abscissa``, so only that leading direction is differentiated — not the whole Jacobian.

    Args:
        blk: object with ``nc_force(z)`` and ``block``, an iterable of the parameter tensors to
            differentiate.
        zs: point at which J_nc is evaluated (detached).
        eigreg: penalty weight.
        margin: abscissa level below which the penalty is off.
        cache: warm-start dict forwarded to ``num_abscissa``.
        iters: power-iteration steps forwarded to ``num_abscissa``.

    Returns:
        ({id(p): grad}, omega): grads keyed by ``id`` of each parameter that received one, and omega
        (float) for logging as the controller signal. Grads are empty when omega <= margin.
    """
    v, lam0 = num_abscissa(blk, zs, cache, iters)
    if lam0 <= margin:                                     # below the stability margin: leading signal off
        return {}, lam0
    # blk.block is walked twice (autograd.grad, then zip): materialize it once so a one-shot iterator
    # (e.g. a parameters() generator) is not exhausted before the zip. Frozen params are skipped
    # because autograd.grad raises on inputs that do not require grad.
    params = [p for p in blk.block if p.requires_grad]
    if not params:                                         # nothing trainable: autograd.grad rejects empty inputs
        return {}, lam0
    z = zs.detach()
    with torch.enable_grad():
        Jv = jvp(blk.nc_force, z, v, create_graph=True)[1]           # differentiable in theta (nc_force params)
        lam = (v * Jv).sum() / (v * v).sum()                        # numerical abscissa, v fixed
        R = eigreg * torch.relu(lam - margin) ** 2
        gr = torch.autograd.grad(R, params, allow_unused=True)
    return {id(p): g for p, g in zip(params, gr) if g is not None}, lam0