summaryrefslogtreecommitdiff
path: root/ep_run/asym_probe.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/asym_probe.py')
-rw-r--r--ep_run/asym_probe.py922
1 files changed, 922 insertions, 0 deletions
diff --git a/ep_run/asym_probe.py b/ep_run/asym_probe.py
new file mode 100644
index 0000000..1b61354
--- /dev/null
+++ b/ep_run/asym_probe.py
@@ -0,0 +1,922 @@
+"""Matrix-free asymmetry probe for the equilibrium-transformer block Jacobian.
+
+The state Jacobian J = dF/dz is never materialized. We estimate the growth of
+T = (S + mu I)^-1 A, where S=(J+J^T)/2 and A=(J-J^T)/2, using autograd JVP/VJP
+products at the relaxed fixed point.
+"""
+import argparse
+import glob
+import math
+import os
+import pickle
+import time
+import warnings
+
# Ignore every UserWarning, then register the two cuBLAS message patterns in
# the same order as before so the resulting filter list is unchanged.
warnings.filterwarnings("ignore", category=UserWarning)
for _cublas_pattern in (".*cuBLAS.*", ".*CUBLAS.*"):
    warnings.filterwarnings("ignore", message=_cublas_pattern, category=UserWarning)
+
+from dataclasses import dataclass, field
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.sparse.linalg import LinearOperator, gmres, minres
+
+import lt_ep_train as L
+from lt_ep_train import EQBlock, bptt_step, ce, ep_step, relax
+
+
def parse_args():
    """Parse the probe's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace holding checkpoint/data paths, block dimensions,
        relaxation settings, linear-solver budgets and diagnostic switches.
    """
    parser = argparse.ArgumentParser()

    # One (flag, add_argument kwargs) pair per option, registered in this order.
    option_table = [
        ("--ckpt", {"default": "runs/ep_clean.pt"}),
        ("--data", {"default": "data/tinystories_bpe"}),
        ("--gelu", {"default": "erf"}),
        ("--C", {"type": int, "default": 512}),
        ("--H", {"type": int, "default": 16}),
        ("--Mm", {"type": int, "default": 256}),
        ("--T", {"type": int, "default": 256}),
        ("--B", {"type": int, "default": 8}),
        ("--T1", {"type": int, "default": 150}),
        ("--T2", {"type": int, "default": 20}),
        ("--eps", {"type": float, "default": 0.1}),
        ("--beta", {"type": float, "default": 0.02}),
        ("--t1max", {"type": int, "default": 2000}),
        ("--relax-chunk", {"type": int, "default": 50}),
        ("--res-est", {"type": float, "default": 1e-4}),
        ("--t2sel", {"type": int, "default": 40}),
        ("--holo", {"type": int, "default": 2}),
        ("--hr", {"type": float, "default": 0.02}),
        ("--seed", {"type": int, "default": 0}),
        ("--trace-probes", {"type": int, "default": 4}),
        ("--mu-scale", {"type": float, "default": 1e-3}),
        (
            "--mu",
            {
                "type": float,
                "default": -1.0,
                "help": "override mu; negative means estimate from trace",
            },
        ),
        ("--solve-iters", {"type": int, "default": 80}),
        ("--solve-tol", {"type": float, "default": 1e-5}),
        ("--adjoint-iters", {"type": int, "default": 200}),
        ("--adjoint-tol", {"type": float, "default": 1e-5}),
        (
            "--adjoint-mu",
            {
                "type": float,
                "default": 1e-4,
                "help": "fallback Tikhonov mu for J^T+muI if GMRES stalls",
            },
        ),
        ("--rho-iters", {"type": int, "default": 20}),
        ("--rho-restarts", {"type": int, "default": 3}),
        ("--sigma-iters", {"type": int, "default": 8}),
        ("--sigma-restarts", {"type": int, "default": 2}),
        ("--arnoldi-k", {"type": int, "default": 12}),
        (
            "--skiprho",
            {
                "action": argparse.BooleanOptionalAction,
                "default": True,
                "help": "skip rho/sigma spectral probes and run only exact-adjoint gradient comparison",
            },
        ),
        (
            "--diag",
            {
                "action": "store_true",
                "help": "run EP/exact-adjoint diagnostic suite and exit",
            },
        ),
        # Hidden options: accepted on the command line but omitted from --help.
        ("--noplot", {"action": "store_true", "help": argparse.SUPPRESS}),
        ("--lr", {"type": float, "default": None, "help": argparse.SUPPRESS}),
        ("--device", {"default": "cuda", "choices": ["cuda", "cpu"]}),
        ("--tf32", {"action": "store_true"}),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
+
+
def resolve_ckpt_path(path):
    """Turn a checkpoint path into a concrete path string.

    Absolute paths are returned unchanged. A relative path resolves against
    the current working directory when that file exists there, and otherwise
    against the directory containing this script.
    """
    candidate = Path(path)
    if not candidate.is_absolute():
        under_cwd = Path.cwd() / candidate
        if under_cwd.exists():
            candidate = under_cwd
        else:
            candidate = Path(__file__).resolve().parent / candidate
    return str(candidate)
+
+
def require_cuda_if_requested(device):
    """Abort with exit status 2 when a CUDA run was requested but is impossible.

    For any device other than ``"cuda"`` this is a no-op. When CUDA is usable,
    device 0 is selected. Otherwise a short environment report is printed and
    ``SystemExit(2)`` is raised.
    """
    if device != "cuda":
        return

    if torch.cuda.is_available() and torch.cuda.device_count() > 0:
        torch.cuda.set_device(0)
        return

    visible_env = os.environ.get("CUDA_VISIBLE_DEVICES")
    device_nodes = glob.glob("/dev/nvidia*")
    absent = [
        node
        for node in ("/dev/nvidiactl", "/dev/nvidia-uvm", "/dev/nvidia0")
        if not os.path.exists(node)
    ]
    report = (
        "ERROR: CUDA unavailable; requested GPU0 run cannot start.",
        f"CUDA_VISIBLE_DEVICES={visible_env!r}",
        f"torch={torch.__version__} torch.version.cuda={torch.version.cuda}",
        f"torch.cuda.is_available()={torch.cuda.is_available()} device_count={torch.cuda.device_count()}",
        f"/dev/nvidia*={' '.join(device_nodes) if device_nodes else 'MISSING'}",
        f"missing CUDA device nodes={' '.join(absent) if absent else 'none'}",
    )
    for line in report:
        print(line, flush=True)
    raise SystemExit(2)
+
+
def build_block(cfg, dev):
    """Construct the EQBlock and load checkpoint weights into it.

    Side effects: sets ``L.dev``, ``L.DD`` and ``L.vocab`` on the imported
    ``lt_ep_train`` module and seeds torch with ``cfg.seed``.

    Args:
        cfg: parsed options; reads ``data``, ``seed``, ``C``, ``H``, ``Mm``,
            ``T``, ``gelu`` and ``ckpt``.
        dev: torch device the checkpoint tensors are mapped and copied to.

    Returns:
        ``(blk, ck)``: the configured block and the loaded checkpoint object.
    """
    # Same construction and checkpoint-copy path as resreg_probe.py.
    L.dev = dev
    L.DD = Path(cfg.data)
    # Close the metadata file deterministically instead of leaking the handle.
    # NOTE(review): pickle must only be used on trusted files; meta.pkl is
    # presumably produced by the project's own data preparation -- confirm.
    with open(L.DD / "meta.pkl", "rb") as meta_file:
        L.vocab = pickle.load(meta_file)["vocab_size"]
    torch.manual_seed(cfg.seed)
    blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=1.0, attn_mode="thick")
    blk.qknorm = True
    blk.fnoise = 0.0
    blk._cstep = None
    blk.navg = 1
    blk.li_avg = 0
    blk.track = True
    blk.nbrake = 0.0
    blk.gelu = cfg.gelu
    ck = torch.load(cfg.ckpt, map_location=dev)
    with torch.no_grad():
        # Pairs are matched positionally; zip stops at the shorter sequence.
        for p, w in zip(blk.allp, ck["allp"]):
            p.copy_(w.to(dev))
    return blk, ck
+
+
@torch.no_grad()
def residuals(blk, z, xin, eps):
    """Return two convergence measures for state ``z``, both scaled by ``||z||``.

    Returns:
        ``(step_rel, force_rel)``: the norm of the change produced by
        ``relax(blk, z, xin, 1, eps)`` and the norm of ``blk.tforce(z, xin)``,
        each divided by ``||z|| + 1e-12``.
    """
    stepped = relax(blk, z, xin, 1, eps)
    scale = z.norm().item() + 1e-12
    step_change = (stepped - z).norm().item()
    force_size = blk.tforce(z, xin).norm().item()
    return step_change / scale, force_size / scale
+
+
def relax_to_fixed_point(blk, xin, cfg):
    """Relax from ``xin`` until the step residual meets ``cfg.res_est``.

    Runs ``cfg.T1`` steps first, then keeps adding chunks of at most
    ``cfg.relax_chunk`` steps while the relative step residual exceeds
    ``cfg.res_est`` and the total stays below ``cfg.t1max``.

    Returns:
        ``(z, steps, step_rel, force_rel)`` with ``z`` detached.
    """
    state = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
    total = cfg.T1
    step_rel, force_rel = residuals(blk, state, xin, cfg.eps)
    while step_rel > cfg.res_est and total < cfg.t1max:
        extra = min(cfg.relax_chunk, cfg.t1max - total)
        state = relax(blk, state, xin, extra, cfg.eps)
        total += extra
        step_rel, force_rel = residuals(blk, state, xin, cfg.eps)
        print(f"relax steps={total:4d} step_res={step_rel:.3e} force_res={force_rel:.3e}", flush=True)
    return state.detach(), total, step_rel, force_rel
+
+
def dot(a, b):
    """Inner product of two tensors after flattening each to 1-D."""
    left = torch.flatten(a)
    right = torch.flatten(b)
    return torch.dot(left, right)
+
+
def norm(a):
    """Euclidean norm of a tensor treated as one flat vector (0-d tensor)."""
    flat = torch.flatten(a)
    return torch.linalg.vector_norm(flat)
+
+
def block_param_list(blk):
    """Return ``blk.block``'s parameters as a list.

    Uses ``blk.block.parameters()`` when that attribute exists; otherwise
    ``blk.block`` itself is treated as the iterable of parameters.
    """
    inner = blk.block
    source = inner.parameters() if hasattr(inner, "parameters") else inner
    return list(source)
+
+
def flat_grad_by_param_id(grads, params):
    """Flatten per-parameter gradients into one CPU float64 vector.

    Args:
        grads: mapping from ``id(param)`` to a gradient tensor, or ``None``.
            A missing or ``None`` entry contributes zeros of the parameter's
            size.
        params: iterable of parameter tensors that fixes the concatenation
            order.

    Returns:
        1-D float64 CPU tensor. An empty ``params`` yields an empty vector
        rather than the error ``torch.cat`` raises on an empty list.
    """
    pieces = []
    for p in params:
        g = None if grads is None else grads.get(id(p))
        if g is None:
            piece = torch.zeros(p.numel(), device="cpu", dtype=torch.float64)
        else:
            piece = g.detach().to(device="cpu", dtype=torch.float64).reshape(-1)
        pieces.append(piece)
    if not pieces:
        return torch.zeros(0, device="cpu", dtype=torch.float64)
    return torch.cat(pieces)
+
+
def set_param_requires_grad(blk, value):
    """Set ``requires_grad`` to ``value`` on every tensor in ``blk.allp``."""
    for tensor in blk.allp:
        tensor.requires_grad_(value)
+
+
def cos(a, b):
    """Cosine similarity between two tensors, returned as a Python float.

    Both inputs are flattened first, so tensors of any matching element count
    are accepted (``torch.dot`` alone only takes 1-D inputs). ``1e-20`` is
    added to the denominator so zero vectors do not divide by zero.
    """
    a_flat = a.reshape(-1)
    b_flat = b.reshape(-1)
    denom = torch.linalg.vector_norm(a_flat) * torch.linalg.vector_norm(b_flat) + 1e-20
    return (torch.dot(a_flat, b_flat) / denom).item()
+
+
def rel_diff(a, b):
    """Return ``||a - b|| / (||b|| + 1e-20)`` as a Python float."""
    gap = torch.linalg.vector_norm((a - b).reshape(-1))
    reference = torch.linalg.vector_norm(b.reshape(-1)) + 1e-20
    return (gap / reference).item()
+
+
def unit_rand(shape, dev, dtype):
    """Draw a standard-normal tensor of ``shape`` and scale it to unit norm."""
    sample = torch.randn(shape, device=dev, dtype=dtype)
    length = torch.linalg.vector_norm(sample.reshape(-1))
    return sample / (length + 1e-30)
+
+
@dataclass
class SolveLog:
    """Accumulates per-solve statistics and formats a one-line summary."""

    # Relative residuals, solver info codes and iteration counts, one entry
    # per recorded solve (parallel lists).
    residuals: list = field(default_factory=list)
    infos: list = field(default_factory=list)
    iters: list = field(default_factory=list)

    def add(self, rel_res, info, nit):
        """Record one solve: relative residual, info code, iteration count."""
        self.residuals += [float(rel_res)]
        self.infos += [int(info)]
        self.iters += [int(nit)]

    def summary(self):
        """Return min/median/max residual and iteration statistics as text."""
        if len(self.residuals) == 0:
            return "solve residuals: none"
        res = np.asarray(self.residuals, dtype=np.float64)
        its = np.asarray(self.iters, dtype=np.int64)
        failed = len([code for code in self.infos if code != 0])
        pieces = [
            f"solve residuals: count={len(res)} min={res.min():.3e} ",
            f"median={np.median(res):.3e} max={res.max():.3e} ",
            f"iters median={np.median(its):.0f} max={its.max()} nonzero_info={failed}",
        ]
        return "".join(pieces)
+
+
class Operators:
    """Matrix-free products and solves built on J = d tforce / dz at ``zstar``.

    With S = (J + J^T) / 2 and A = (J - J^T) / 2, this class exposes J v,
    J^T v, S v, A v, (S + mu I) v, iterative solves with S + mu I (MINRES) and
    J^T + mu I (GMRES), and the composite T = (S + mu I)^-1 A.

    Solves go through NumPy float32 vectors, so solutions are at most float32
    accurate regardless of ``zstar``'s dtype.
    """

    def __init__(self, blk, zstar, xin, cfg, mu):
        """Store detached copies of the linearization point and solver config.

        Args:
            blk: object providing ``tforce(z, xin)``.
            zstar: state at which the Jacobian is taken.
            xin: clamp input passed unchanged to ``tforce``.
            cfg: options; ``solve_tol`` and ``solve_iters`` are read by
                ``solve_s``.
            mu: default shift used by ``smu`` / ``solve_s``.
        """
        self.blk = blk
        self.zstar = zstar.detach()
        self.xin = xin.detach()
        self.shape = tuple(zstar.shape)
        self.n = zstar.numel()
        self.dev = zstar.device
        self.dtype = zstar.dtype
        self.cfg = cfg
        self.mu = float(mu)
        self.solve_log = SolveLog()

    def f(self, z):
        """Evaluate ``blk.tforce`` at ``z`` with the stored clamp input."""
        return self.blk.tforce(z, self.xin)

    def jv(self, v):
        """Return J v via ``torch.autograd.functional.jvp`` (detached)."""
        with torch.enable_grad():
            _, out = torch.autograd.functional.jvp(
                self.f, self.zstar, v.contiguous(), create_graph=False, strict=False
            )
        return out.detach()

    def jtv(self, v):
        """Return J^T v via a vector-Jacobian product (detached)."""
        with torch.enable_grad():
            # Fresh leaf per call so no graph is shared between products.
            z = self.zstar.detach().requires_grad_(True)
            fz = self.f(z)
            (g,) = torch.autograd.grad(
                fz, z, grad_outputs=v.contiguous(), create_graph=False, retain_graph=False
            )
        return g.detach()

    def sv(self, v):
        """Return S v = (J v + J^T v) / 2."""
        jv = self.jv(v)
        jtv = self.jtv(v)
        return 0.5 * (jv + jtv)

    def av(self, v):
        """Return A v = (J v - J^T v) / 2."""
        jv = self.jv(v)
        jtv = self.jtv(v)
        return 0.5 * (jv - jtv)

    def smu(self, v, mu=None):
        """Return (S + mu I) v; ``mu`` defaults to the instance shift."""
        m = self.mu if mu is None else float(mu)
        return self.sv(v) + m * v

    def _from_numpy(self, x):
        """Convert a flat NumPy vector to a tensor of the state's shape."""
        x = np.asarray(x, dtype=np.float32)
        return torch.from_numpy(x).to(device=self.dev, dtype=self.dtype).view(self.shape)

    def _to_numpy(self, x):
        """Convert a tensor to a flat float32 NumPy vector on the CPU."""
        return x.detach().reshape(-1).float().cpu().numpy()

    def solve_s(self, rhs, mu=None, tag=""):
        """Solve (S + mu I) x = rhs with MINRES.

        Args:
            rhs: right-hand side tensor of the state's shape.
            mu: shift; defaults to the instance shift.
            tag: when non-empty, a one-line report is printed under this label.

        Returns:
            ``(x, rel_res, info, iters)`` where ``rel_res`` is the residual
            recomputed with torch, relative to ``||rhs||``.
        """
        m = self.mu if mu is None else float(mu)
        b = self._to_numpy(rhs)
        counter = {"n": 0}

        def matvec(x_np):
            x = self._from_numpy(x_np)
            y = self.smu(x, m)
            return self._to_numpy(y)

        def cb(_x):
            counter["n"] += 1

        Aop = LinearOperator((self.n, self.n), matvec=matvec, dtype=np.dtype("float32"))
        try:
            x_np, info = minres(
                Aop, b, rtol=self.cfg.solve_tol, maxiter=self.cfg.solve_iters, callback=cb, check=False
            )
        except TypeError:
            # Older SciPy releases name the tolerance keyword ``tol``; this
            # mirrors the fallback already used for GMRES below.
            counter["n"] = 0
            x_np, info = minres(
                Aop, b, tol=self.cfg.solve_tol, maxiter=self.cfg.solve_iters, callback=cb, check=False
            )
        x = self._from_numpy(x_np).detach()
        rel = (norm(self.smu(x, m) - rhs) / (norm(rhs) + 1e-30)).item()
        self.solve_log.add(rel, info, counter["n"])
        if tag:
            print(f"solve {tag}: mu={m:.3e} rel_res={rel:.3e} iters={counter['n']} info={info}", flush=True)
        return x, rel, info, counter["n"]

    def solve_jt_gmres(self, rhs, tol, maxiter, mu=0.0, tag="adjoint"):
        """Solve (J^T + mu I) x = rhs with restarted GMRES.

        Args:
            rhs: right-hand side tensor of the state's shape.
            tol: relative tolerance passed to GMRES.
            maxiter: iteration budget; the restart length is
                ``min(50, maxiter)``, at least 1.
            mu: optional Tikhonov shift added to J^T.
            tag: label used in the printed report.

        Returns:
            ``(x, rel_res, info, iters)`` where ``rel_res`` is the residual
            recomputed with torch, relative to ``||rhs||``.
        """
        m = float(mu)
        b = self._to_numpy(rhs)
        counter = {"n": 0}
        restart = max(1, min(50, int(maxiter)))

        def matvec(x_np):
            x = self._from_numpy(x_np)
            y = self.jtv(x)
            if m != 0.0:
                y = y + m * x
            return self._to_numpy(y)

        def cb(_arg):
            counter["n"] += 1

        Aop = LinearOperator((self.n, self.n), matvec=matvec, dtype=np.dtype("float32"))
        try:
            x_np, info = gmres(
                Aop,
                b,
                rtol=tol,
                atol=0.0,
                restart=restart,
                maxiter=int(maxiter),
                callback=cb,
                callback_type="legacy",
            )
        except TypeError:
            # Older SciPy: no ``rtol`` keyword, use ``tol`` instead.
            counter["n"] = 0
            x_np, info = gmres(Aop, b, tol=tol, restart=restart, maxiter=int(maxiter), callback=cb)
        x = self._from_numpy(x_np).detach()
        rel = (norm(self.jtv(x) + m * x - rhs) / (norm(rhs) + 1e-30)).item()
        print(f"GMRES {tag}: mu={m:.3e} rel_res={rel:.3e} iters={counter['n']} info={info}", flush=True)
        return x, rel, info, counter["n"]

    def t(self, v):
        """Return T v = (S + mu I)^-1 A v using the instance shift."""
        rhs = self.av(v)
        x, _, _, _ = self.solve_s(rhs)
        return x

    def tt(self, u):
        """Return T^T u = -A (S + mu I)^-1 u (S symmetric, A skew)."""
        y, _, _, _ = self.solve_s(u)
        return -self.av(y)
+
+
def estimate_trace_s(op, probes):
    """Estimate tr(S) / n with random +/-1 probe vectors.

    Each probe draws ``r`` with entries in {-1, +1} and evaluates
    ``r . (S r) / n``.

    Args:
        op: object exposing ``shape``, ``dev``, ``dtype``, ``n`` and ``sv``.
        probes: number of probe vectors; must be at least 1.

    Returns:
        ``(mean, std)`` over the probes; ``std`` is 0.0 for a single probe.

    Raises:
        ValueError: if ``probes`` is smaller than 1. The mean of zero probes
            would be NaN and would silently poison the shift derived from it.
    """
    if probes < 1:
        raise ValueError(f"probes must be >= 1, got {probes}")
    vals = []
    for i in range(probes):
        # Integer draw in {0, 1}, cast to the state dtype, then mapped to {-1, +1}.
        r = torch.randint(0, 2, op.shape, device=op.dev, dtype=torch.int8).to(op.dtype)
        r = r.mul_(2).sub_(1)
        sr = op.sv(r)
        vals.append((dot(r, sr) / op.n).item())
        print(f"trace probe {i}: tr(S)/n={vals[-1]:+.6e}", flush=True)
    return float(np.mean(vals)), float(np.std(vals) if len(vals) > 1 else 0.0)
+
+
def sensitivity_probe(op, mu):
    """Report how the solution of (S + mu I) x = A v reacts to changes in mu.

    Solves once at ``mu`` and once each at ``0.1 * mu`` and ``10 * mu``
    (clamped at zero) for a random unit ``v``, then prints the relative change
    in the solution and the solve residuals.

    Returns:
        ``(base_residual, rows)`` where each row is
        ``(scale, mu_scaled, rel_dx, residual)``.
    """
    direction = unit_rand(op.shape, op.dev, op.dtype)
    rhs = op.av(direction)
    base_x, base_res, _, _ = op.solve_s(rhs, mu=mu, tag="sensitivity/base")

    rows = []
    for factor in (0.1, 10.0):
        shifted_mu = max(mu * factor, 0.0)
        shifted_x, shifted_res, _, _ = op.solve_s(rhs, mu=shifted_mu, tag=f"sensitivity/mu_x{factor:g}")
        change = (norm(shifted_x - base_x) / (norm(base_x) + 1e-30)).item()
        rows.append((factor, shifted_mu, change, shifted_res))

    fragments = [f"mu_x{scale:g}: rel_dx={dx:.3e} rel_res={rr:.3e}" for scale, _, dx, rr in rows]
    print("solve sensitivity: " + " ".join(fragments), flush=True)
    return base_res, rows
+
+
def power_rho(op, cfg):
    """Estimate the dominant growth of ``op.t`` by restarted power iteration.

    Runs ``cfg.rho_restarts`` restarts of up to ``cfg.rho_iters`` iterations
    each. The estimate of a restart is ``||T v||`` at its last iteration, and
    the largest such value across restarts is returned (0.0 if none).
    """
    best_growth = 0.0
    best_trace = None
    for restart in range(cfg.rho_restarts):
        vec = unit_rand(op.shape, op.dev, op.dtype)
        trace = []
        for step in range(cfg.rho_iters):
            image = op.t(vec)
            growth = norm(image).item()
            rayleigh = (dot(vec, image) / (dot(vec, vec) + 1e-30)).item()
            trace.append((growth, rayleigh))
            # Stop this restart on collapse or blow-up; the entry is still kept.
            if growth <= 1e-30 or not math.isfinite(growth):
                break
            vec = (image / growth).detach()
            print(
                f"rho restart={restart} iter={step + 1:02d} growth={growth:.6e} rayleigh={rayleigh:+.6e}",
                flush=True,
            )
        if trace and trace[-1][0] > best_growth:
            best_growth = trace[-1][0]
            best_trace = trace
    if best_trace:
        tail = best_trace[-6:]
        trend = " ".join(f"{g:.3g}" for g, _ in tail)
        rtrend = " ".join(f"{rq:+.3g}" for _, rq in tail)
        print(f"rho power trend last={trend}", flush=True)
        print(f"rho Rayleigh trend last={rtrend}", flush=True)
    return best_growth
+
+
def arnoldi_rho(op, k):
    """Estimate the spectral radius of ``op.t`` from a k-step Arnoldi run.

    Builds an orthonormal Krylov basis with Gram-Schmidt, fills the Hessenberg
    matrix, and returns the largest eigenvalue modulus of its leading square
    part. Returns ``None`` when ``k <= 0``.
    """
    if k <= 0:
        return None
    basis = [unit_rand(op.shape, op.dev, op.dtype)]
    hess = np.zeros((k + 1, k), dtype=np.float64)
    built = 0
    for col in range(k):
        w = op.t(basis[col])
        # basis holds exactly col + 1 vectors at this point.
        for row, q in enumerate(basis[: col + 1]):
            coeff = dot(q, w).item()
            hess[row, col] = coeff
            w = w - coeff * q
        subdiag = norm(w).item()
        hess[col + 1, col] = subdiag
        built = col + 1
        print(f"arnoldi iter={col + 1:02d} h_next={subdiag:.6e}", flush=True)
        if subdiag < 1e-12:
            # Krylov space is exhausted; stop before dividing by ~0.
            break
        if col + 1 < k:
            basis.append((w / subdiag).detach())
    eig = np.linalg.eigvals(hess[:built, :built])
    rho = float(np.max(np.abs(eig))) if eig.size else float("nan")
    print(f"rho Arnoldi(k={built})={rho:.6e}", flush=True)
    return rho
+
+
def power_sigma(op, cfg):
    """Estimate the largest singular value of T by power iteration on T^T T.

    Each iteration applies ``op.t`` then ``op.tt``. ``sigma`` is ``||T v||``
    for the current unit ``v``, and the largest final ``sigma`` over
    ``cfg.sigma_restarts`` restarts is returned.
    """
    top = 0.0
    for restart in range(cfg.sigma_restarts):
        vec = unit_rand(op.shape, op.dev, op.dtype)
        sigma = 0.0
        for step in range(cfg.sigma_iters):
            forward = op.t(vec)
            sigma = norm(forward).item()
            back = op.tt(forward)
            back_norm = norm(back).item()
            if back_norm <= 1e-30 or not math.isfinite(back_norm):
                break
            vec = (back / back_norm).detach()
            print(f"sigma restart={restart} iter={step + 1:02d} sigma={sigma:.6e}", flush=True)
        top = max(top, sigma)
    return top
+
+
def ce_state_grad(blk, zstar, y):
    """Return the gradient of ``ce(blk, z, y)`` w.r.t. the state, plus the loss.

    Returns:
        ``(ell, loss)``: the detached gradient tensor and the loss as a float.
    """
    with torch.enable_grad():
        z_leaf = zstar.detach().requires_grad_(True)
        loss = ce(blk, z_leaf, y)
        grad_z = torch.autograd.grad(loss, z_leaf)[0]
    return grad_z.detach(), float(loss.detach())
+
+
def solve_exact_adjoint(op, ell, cfg):
    """Solve J^T lambda = -ell, retrying with a Tikhonov shift on a stall.

    The first GMRES attempt uses mu = 0. It counts as stalled when the solver
    reports a non-zero info code, the residual is not finite, or the residual
    exceeds ``max(10 * cfg.adjoint_tol, 1e-4)``. A stalled attempt is repeated
    once with ``mu = max(cfg.adjoint_mu, 1e-8)``.

    Returns:
        ``(lam, rel_res, info, iters, mu_used)`` with ``lam`` detached.
    """
    rhs = -ell.detach()
    tol = cfg.adjoint_tol
    budget = cfg.adjoint_iters
    lam, rel, info, nit = op.solve_jt_gmres(rhs, tol, budget, mu=0.0, tag="J^T lambda=-ell")

    converged = info == 0 and math.isfinite(rel) and rel <= max(10.0 * tol, 1e-4)
    if converged:
        return lam.detach(), rel, info, nit, 0.0

    mu_used = max(float(cfg.adjoint_mu), 1e-8)
    print(f"GMRES stalled; retrying exact-adjoint solve with Tikhonov J^T+muI, mu={mu_used:.3e}", flush=True)
    lam, rel, info, nit = op.solve_jt_gmres(
        rhs, cfg.adjoint_tol, cfg.adjoint_iters, mu=mu_used, tag="(J^T+muI) lambda=-ell"
    )
    return lam.detach(), rel, info, nit, mu_used
+
+
def exact_transpose_grad(blk, idx, zstar, xin0, lam, params):
    """Differentiate ``sum(tforce(z*, xin) * lam)`` with respect to ``params``.

    Side effect: every tensor in ``blk.allp`` has ``requires_grad`` set to True.

    Args:
        blk: object providing ``allp``, ``embed(idx)`` and ``tforce(z, xin)``.
        idx: input passed to ``blk.embed``.
        zstar: state at which the force is evaluated (detached here).
        xin0: clamp value used for the forward evaluation.
        lam: adjoint vector contracted with the force (detached here).
        params: parameters to differentiate with respect to.

    Returns:
        dict mapping ``id(param)`` to its gradient; the value is ``None`` for
        a parameter that does not influence the force.
    """
    for p in blk.allp:
        p.requires_grad_(True)
    with torch.enable_grad():
        # Value stays at the relaxed clamp xin0, while tok/pos receive the same clamp-gradient path as the trainer.
        # The embedding is evaluated once so the added term is exactly zero in
        # value and the forward pass is not repeated.
        emb = blk.embed(idx)
        xin = xin0 + (emb - emb.detach())
        force = blk.tforce(zstar.detach(), xin)
        grads = torch.autograd.grad((force * lam.detach()).sum(), params, allow_unused=True)
    return {id(p): g for p, g in zip(params, grads)}
+
+
def run_ep_step_flat(blk, idx, y, cfg, params, *, beta=None, holo=None, hr=None, t2sel=None, track=None, T2=None):
    """Run ``ep_step`` once and return its gradients as one flat vector.

    Keyword overrides replace the matching ``cfg`` value when not ``None``.
    ``track`` temporarily replaces ``blk.track`` and the previous value is put
    back afterwards (only when the block had a non-None ``track`` before).

    Returns:
        ``(flat_grad, ep_res)``: float64 CPU vector ordered like ``params``
        and the residual reported by ``ep_step`` as a float.
    """
    previous_track = getattr(blk, "track", None)
    if track is not None:
        blk.track = bool(track)
    try:
        set_param_requires_grad(blk, True)
        t2_steps = cfg.T2 if T2 is None else int(T2)
        beta_val = cfg.beta if beta is None else float(beta)
        holo_val = cfg.holo if holo is None else int(holo)
        hr_val = cfg.hr if hr is None else float(hr)
        t2sel_val = cfg.t2sel if t2sel is None else int(t2sel)
        # Mirrors lt_ep_train.ep_step:
        # (blk, idx, y, T1, T2, eps, beta, jacreg, holo, hr, t1max, res_est, t2sel, corr_every, res_gate, resreg).
        positional = (
            blk,
            idx,
            y,
            cfg.T1,
            t2_steps,
            cfg.eps,
            beta_val,
            0.0,
            holo_val,
            hr_val,
            cfg.t1max,
            cfg.res_est,
            t2sel_val,
            1,
            0.0,
            0.0,
        )
        grads, ep_res = ep_step(*positional)
        return flat_grad_by_param_id(grads, params), float(ep_res)
    finally:
        if not (track is None or previous_track is None):
            blk.track = previous_track
+
+
@torch.no_grad()
def fixed_point_step_abs(blk, zstar, xin, eps):
    """Return the norm of the change produced by ``relax(blk, zstar, xin, 1, eps)``."""
    stepped = relax(blk, zstar, xin, 1, eps)
    displacement = stepped - zstar
    return displacement.norm().item()
+
+
def exact_reference_for_batch(blk, idx, y, cfg, label, compute_bptt=True):
    """Compute the exact-adjoint reference gradient for one batch.

    Steps: relax to a fixed point, take the CE gradient at z*, solve
    J^T lambda = -ell, and contract lambda with the force to get the parameter
    gradient ``gt``. Optionally also computes the ``bptt_step`` gradient.

    Parameters are frozen for the Jacobian products and are always re-enabled
    on exit, including when a step raises, so a failed batch cannot leave the
    block frozen for later callers.

    Args:
        blk: the equilibrium block.
        idx, y: batch inputs and targets.
        cfg: options (``eps``, ``res_est``, ``T1`` and the solver settings).
        label: prefix used in printed lines.
        compute_bptt: when True, also store the flattened BPTT gradient under
            ``"gBv"``.

    Returns:
        dict with the batch, parameter list, ``gt``, residual and solver
        statistics, the CE loss, and optionally ``gBv``.
    """
    print(f"--- exact reference: {label} ---", flush=True)
    xin0 = blk.embed(idx).detach()
    zstar, steps, step_res, force_res = relax_to_fixed_point(blk, xin0, cfg)
    step_abs = fixed_point_step_abs(blk, zstar, xin0, cfg.eps)
    print(
        f"{label}: z* residual step_abs={step_abs:.6e} step_rel={step_res:.6e} "
        f"force_rel={force_res:.6e} relax_steps={steps}",
        flush=True,
    )
    if step_res > cfg.res_est:
        print(f"{label}: WARNING step_res={step_res:.3e} > res_est={cfg.res_est:.3e}", flush=True)

    set_param_requires_grad(blk, False)
    try:
        op = Operators(blk, zstar, xin0, cfg, mu=0.0)
        ell, ce_loss = ce_state_grad(blk, zstar, y)
        print(f"{label}: CE(z*)={ce_loss:.6f} ||ell||={norm(ell).item():.6e}", flush=True)
        lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg)
        print(
            f"{label}: adjoint residual={gmres_rel:.3e} iters={gmres_iters} info={gmres_info} "
            f"tikhonov_mu={adj_mu:.3e}",
            flush=True,
        )

        params = block_param_list(blk)
        gt = flat_grad_by_param_id(exact_transpose_grad(blk, idx, zstar, xin0, lam, params), params)
        out = {
            "idx": idx,
            "y": y,
            "params": params,
            "gt": gt,
            "z_step_abs": step_abs,
            "z_step_rel": step_res,
            "z_force_rel": force_res,
            "relax_steps": steps,
            "gmres_rel": gmres_rel,
            "gmres_info": gmres_info,
            "gmres_iters": gmres_iters,
            "adj_mu": adj_mu,
            "ce_loss": ce_loss,
        }
        if compute_bptt:
            set_param_requires_grad(blk, True)
            gB = bptt_step(blk, idx, y, cfg.T1, cfg.eps, 0.0)
            out["gBv"] = flat_grad_by_param_id(gB, params)
    finally:
        # Leave the block trainable regardless of how the computation ended.
        set_param_requires_grad(blk, True)
    return out
+
+
def draw_seeded_train_batch(cfg, seed):
    """Seed torch with ``seed`` and draw one ``"train"`` batch of size B x T."""
    torch.manual_seed(int(seed))
    batch = L.get_batch("train", cfg.B, cfg.T)
    return batch
+
+
def finite_range(vals):
    """Return ``(mean, min, max)`` of the finite, non-None entries of ``vals``.

    Returns ``None`` when no entry survives the filter.
    """
    kept = []
    for value in vals:
        if value is None:
            continue
        number = float(value)
        if math.isfinite(number):
            kept.append(number)
    if len(kept) == 0:
        return None
    mean = float(np.mean(kept))
    lowest = float(np.min(kept))
    highest = float(np.max(kept))
    return mean, lowest, highest
+
+
def read_multi_batch(rows):
    """Summarize in words how cos(g_EP, g_transpose) varies across batches.

    Only rows with a truthy ``"ok"`` are considered; the verdict depends on
    the minimum, the maximum and their spread.
    """
    cosines = [row.get("cos_ep_t") for row in rows if row.get("ok")]
    stats = finite_range(cosines)
    if stats is None:
        return "no successful batches"
    _, lowest, highest = stats
    if lowest > 0.95:
        return "consistently aligned across batches"
    if highest < 0.80:
        return "systematically low across batches"
    if highest - lowest > 0.20:
        return "batch-variance/outlier behavior is material"
    return "mostly systematic with moderate batch variance"
+
+
def read_beta_sweep(rows):
    """Describe in words how the cosine changes from the first to the last beta.

    Only rows with a truthy ``"ok"`` are used, in the order given.
    """
    points = [(row["beta"], row["cos"]) for row in rows if row.get("ok")]
    if len(points) == 0:
        return "no successful beta points"
    first_beta, first_cos = points[0]
    last_beta, last_cos = points[-1]
    peak = max(value for _, value in points)
    converges = last_cos > 0.95 and last_cos - first_cos > 0.10
    if converges:
        return f"finite-beta bias likely: cos improves from beta={first_beta:g} to beta={last_beta:g}"
    if peak < 0.80:
        return "cos stays low as beta shrinks: structural/bug more likely than finite-beta bias"
    if last_cos > first_cos + 0.05:
        return "some finite-beta sensitivity, but not a clean convergence-to-1 result"
    return "no strong beta-to-zero improvement"
+
+
def read_ablation(rows):
    """Say in words which ablation, if any, improves on the FULL config.

    Only rows with a truthy ``"ok"`` are used. An improvement counts as
    material when the best cosine exceeds FULL's by more than 0.05.
    """
    usable = [row for row in rows if row.get("ok")]
    if len(usable) == 0:
        return "no successful ablations"
    full = next((row for row in usable if row["key"] == "full"), None)
    best = max(usable, key=lambda row: row["cos"])
    if full is None:
        return f"best successful config is {best['label']}"
    if best["cos"] - full["cos"] <= 0.05:
        return "no ablation materially improves over FULL"
    suspects = {
        "track_off": "tracking path is suspect: disabling blk.track improved cos",
        "plain": "holomorphic/adaptive path is suspect: plain real EP improved cos",
        "fixed_t2": "adaptive-T2 selection/tracking is suspect: fixed T2 improved cos",
    }
    fallback = f"{best['label']} is the strongest improvement over FULL"
    return suspects.get(best["key"], fallback)
+
+
def print_diagnostic_summary(multi_rows, beta_rows, ablation_rows):
    """Print the final report for the three diagnostics.

    Args:
        multi_rows: per-batch result dicts from the multi-batch diagnostic.
        beta_rows: result dicts from the beta sweep.
        ablation_rows: result dicts from the component ablation.
    """

    def say(text):
        print(text, flush=True)

    def ok_values(key):
        return [row.get(key) for row in multi_rows if row.get("ok")]

    say("")
    say("================ DIAGNOSTIC SUMMARY ================")

    stats_transpose = finite_range(ok_values("cos_ep_t"))
    stats_bptt = finite_range(ok_values("cos_ep_b"))
    if stats_transpose is not None:
        mean, mn, mx = stats_transpose
        say(f"Multi-batch: mean cos(g_EP,g_transpose)={mean:+.6f} range=[{mn:+.6f}, {mx:+.6f}]")
    else:
        say("Multi-batch: no successful batches")
    if stats_bptt is not None:
        mean, mn, mx = stats_bptt
        say(f"Multi-batch: mean cos(g_EP,g_BPTT)={mean:+.6f} range=[{mn:+.6f}, {mx:+.6f}]")
    say(f"Multi-batch read: {read_multi_batch(multi_rows)}")

    say("Beta sweep (beta | cos(g_EP,g_transpose)):")
    if not beta_rows:
        say(" none")
    for row in beta_rows or ():
        if row.get("ok"):
            say(f" {row['beta']:<8g} | {row['cos']:+.6f}")
        else:
            say(f" {row.get('beta', 'n/a')!s:<8} | failed: {row.get('error')}")
    say(f"Beta sweep read: {read_beta_sweep(beta_rows)}")

    say("Ablation (config | cos(g_EP,g_transpose)):")
    if not ablation_rows:
        say(" none")
    for row in ablation_rows or ():
        if row.get("ok"):
            say(f" {row['label']} | {row['cos']:+.6f}")
        else:
            say(f" {row.get('label', 'unknown')} | failed: {row.get('error')}")
    say(f"Ablation read: {read_ablation(ablation_rows)}")
    say("============== END DIAGNOSTIC SUMMARY ==============")
+
+
def run_diagnostics(blk, cfg, ck):
    """Run the three EP-vs-exact-adjoint diagnostics and print a summary.

    1. Multi-batch: six seeded batches (seeds 1000..1005), comparing the EP
       gradient against the exact-transpose and BPTT gradients.
    2. Beta sweep on the seed-1000 batch, with ``hr`` tied to ``beta``.
    3. Component ablation on the seed-1000 batch.

    Every batch / sweep point / ablation is best-effort: an exception is
    recorded in the result row and printed, and the run continues.

    Args:
        blk: the equilibrium block.
        cfg: parsed options.
        ck: loaded checkpoint; only ``step`` and ``best`` are read via ``get``.
    """
    print("=== DIAGNOSTIC MODE ===", flush=True)
    print(f"# ckpt step {ck.get('step')} best {ck.get('best')}", flush=True)
    print(
        "ep_step paths: holo=2,t2sel>0,track=True -> holo_a_track; "
        "holo=2,t2sel>0,track=False -> holo_a_select2; holo>0,t2sel=0 -> holo_a; holo=0 -> plain EP",
        flush=True,
    )
    print("gradient comparison scope: blk.block parameters; readout Wh is excluded", flush=True)

    multi_rows = []
    beta_rows = []
    ablation_rows = []
    seed1000_ref = None

    print("=== DIAGNOSTIC 1: MULTI-BATCH ===", flush=True)
    for i in range(6):
        seed = 1000 + i
        label = f"diag1 batch={i} seed={seed}"
        try:
            idx, y = draw_seeded_train_batch(cfg, seed)
            ref = exact_reference_for_batch(blk, idx, y, cfg, label, compute_bptt=True)
            torch.manual_seed(seed)
            gEPv, ep_res = run_ep_step_flat(blk, idx, y, cfg, ref["params"])
            row = {
                "ok": True,
                "batch": i,
                "seed": seed,
                "cos_ep_t": cos(gEPv, ref["gt"]),
                "cos_ep_b": cos(gEPv, ref["gBv"]),
                "cos_t_b": cos(ref["gt"], ref["gBv"]),
                "z_step_abs": ref["z_step_abs"],
                "z_step_rel": ref["z_step_rel"],
                "z_force_rel": ref["z_force_rel"],
                "ep_res": ep_res,
            }
            multi_rows.append(row)
            print(
                f"{label}: cos(g_EP,g_transpose)={row['cos_ep_t']:+.6f} "
                f"cos(g_EP,g_BPTT)={row['cos_ep_b']:+.6f} "
                f"cos(g_transpose,g_BPTT)={row['cos_t_b']:+.6f} "
                f"z_res_abs={row['z_step_abs']:.6e} z_res_rel={row['z_step_rel']:.6e} ep_res={ep_res:.6e}",
                flush=True,
            )
            if seed == 1000:
                seed1000_ref = ref
        except Exception as err:
            # Deliberate best-effort: record the failure and move on.
            row = {"ok": False, "batch": i, "seed": seed, "error": repr(err)}
            multi_rows.append(row)
            print(f"{label} failed: {err!r}", flush=True)

    multi_stats_t = finite_range([r.get("cos_ep_t") for r in multi_rows if r.get("ok")])
    multi_stats_b = finite_range([r.get("cos_ep_b") for r in multi_rows if r.get("ok")])
    if multi_stats_t is not None and multi_stats_b is not None:
        mt, mint, maxt = multi_stats_t
        mb, minb, maxb = multi_stats_b
        print(
            f"DIAG1 aggregate: cos(g_EP,g_transpose) mean={mt:+.6f} min={mint:+.6f} max={maxt:+.6f}; "
            f"cos(g_EP,g_BPTT) mean={mb:+.6f} min={minb:+.6f} max={maxb:+.6f}",
            flush=True,
        )

    if seed1000_ref is None:
        try:
            idx, y = draw_seeded_train_batch(cfg, 1000)
            # The sweep and the ablation only read idx/y/params/gt, so the BPTT
            # gradient is skipped here: it would be wasted work, and a BPTT or
            # EP failure in diagnostic 1 would otherwise repeat in the fallback.
            seed1000_ref = exact_reference_for_batch(
                blk, idx, y, cfg, "diag seed=1000 fallback", compute_bptt=False
            )
        except Exception as err:
            print(f"seed=1000 reference failed; beta sweep and ablation cannot run: {err!r}", flush=True)

    print("=== DIAGNOSTIC 2: BETA SWEEP ===", flush=True)
    if seed1000_ref is not None:
        for beta in [0.04, 0.02, 0.01, 0.005, 0.002]:
            try:
                torch.manual_seed(1000)
                gEPv, ep_res = run_ep_step_flat(
                    blk,
                    seed1000_ref["idx"],
                    seed1000_ref["y"],
                    cfg,
                    seed1000_ref["params"],
                    beta=beta,
                    hr=beta,
                )
                row = {"ok": True, "beta": beta, "cos": cos(gEPv, seed1000_ref["gt"]), "ep_res": ep_res}
                beta_rows.append(row)
                print(
                    f"beta={beta:g} hr={beta:g}: cos(g_EP,g_transpose)={row['cos']:+.6f} ep_res={ep_res:.6e}",
                    flush=True,
                )
            except Exception as err:
                beta_rows.append({"ok": False, "beta": beta, "error": repr(err)})
                print(f"beta={beta:g} failed: {err!r}", flush=True)
    else:
        print("DIAG2 skipped: seed=1000 reference unavailable", flush=True)

    print("=== DIAGNOSTIC 3: COMPONENT ABLATION ===", flush=True)
    if seed1000_ref is not None:
        ablations = [
            {
                "key": "full",
                "label": "FULL holo=2 track=True t2sel=40",
                "kwargs": {"holo": 2, "track": True, "t2sel": 40, "hr": cfg.hr, "beta": cfg.beta},
            },
            {
                "key": "track_off",
                "label": "holo=2 track=False t2sel=40",
                "kwargs": {"holo": 2, "track": False, "t2sel": 40, "hr": cfg.hr, "beta": cfg.beta},
            },
            {
                "key": "plain",
                "label": "plain EP holo=0 track ignored t2sel=0",
                "kwargs": {
                    "holo": 0,
                    "track": getattr(blk, "track", False),
                    "t2sel": 0,
                    "hr": cfg.hr,
                    "beta": cfg.beta,
                },
            },
            {
                "key": "fixed_t2",
                "label": f"holo=2 track=True t2sel=0 fixed T2={cfg.T2}",
                "kwargs": {"holo": 2, "track": True, "t2sel": 0, "hr": cfg.hr, "beta": cfg.beta, "T2": cfg.T2},
            },
        ]
        for item in ablations:
            try:
                torch.manual_seed(1000)
                gEPv, ep_res = run_ep_step_flat(
                    blk,
                    seed1000_ref["idx"],
                    seed1000_ref["y"],
                    cfg,
                    seed1000_ref["params"],
                    **item["kwargs"],
                )
                row = {
                    "ok": True,
                    "key": item["key"],
                    "label": item["label"],
                    "cos": cos(gEPv, seed1000_ref["gt"]),
                    "ep_res": ep_res,
                }
                ablation_rows.append(row)
                print(f"{item['label']}: cos(g_EP,g_transpose)={row['cos']:+.6f} ep_res={ep_res:.6e}", flush=True)
            except Exception as err:
                row = {"ok": False, "key": item["key"], "label": item["label"], "error": repr(err)}
                ablation_rows.append(row)
                print(f"config {item['label']} failed: {err!r}", flush=True)
    else:
        print("DIAG3 skipped: seed=1000 reference unavailable", flush=True)

    print_diagnostic_summary(multi_rows, beta_rows, ablation_rows)
+
+
def compare_exact_adjoint(blk, idx, y, zstar, xin0, op, cfg):
    """Compare the exact-transpose, BPTT and EP gradients on one batch.

    Solves J^T lambda = -ell at ``zstar``, contracts lambda with the force to
    get ``g_transpose``, then prints norms, pairwise cosines and relative
    differences against the gradients from ``bptt_step`` and ``ep_step``.

    Args:
        blk: the equilibrium block.
        idx, y: batch inputs and targets.
        zstar: relaxed fixed point.
        xin0: detached clamp input used for the relaxation.
        op: ``Operators`` instance used for the adjoint solve.
        cfg: parsed options.
    """
    print("=== exact-adjoint gradient comparison ===", flush=True)
    ell, ce_loss = ce_state_grad(blk, zstar, y)
    print(f"CE(z*)={ce_loss:.6f} ||ell||={norm(ell).item():.6e}", flush=True)
    lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg)
    print(
        f"adjoint solve summary: residual={gmres_rel:.3e} iters={gmres_iters} info={gmres_info} "
        f"tikhonov_mu={adj_mu:.3e}",
        flush=True,
    )

    params = block_param_list(blk)
    gt = flat_grad_by_param_id(exact_transpose_grad(blk, idx, zstar, xin0, lam, params), params)
    print("gradient comparison scope: blk.block parameters; readout Wh is excluded", flush=True)

    # The caller may have frozen the parameters for the Jacobian products, so
    # re-enable gradients explicitly before the two trainer steps, as
    # exact_reference_for_batch does.
    set_param_requires_grad(blk, True)
    gB = bptt_step(blk, idx, y, cfg.T1, cfg.eps, 0.0)
    # Same 16-argument call as run_ep_step_flat; the trailing resreg=0.0 is
    # passed explicitly so both code paths evaluate the same estimator:
    # (blk, idx, y, T1, T2, eps, beta, jacreg, holo, hr, t1max, res_est, t2sel, corr_every, res_gate, resreg).
    gEP, ep_res = ep_step(
        blk,
        idx,
        y,
        cfg.T1,
        cfg.T2,
        cfg.eps,
        cfg.beta,
        0.0,
        cfg.holo,
        cfg.hr,
        cfg.t1max,
        cfg.res_est,
        cfg.t2sel,
        1,
        0.0,
        0.0,
    )
    gBv = flat_grad_by_param_id(gB, params)
    gEPv = flat_grad_by_param_id(gEP, params)

    print(f"EP estimator free-phase residual from ep_step={ep_res:.6e}", flush=True)
    print(
        f"||g_transpose||={norm(gt).item():.6e} ||g_BPTT||={norm(gBv).item():.6e} ||g_EP||={norm(gEPv).item():.6e}",
        flush=True,
    )
    c_t_b = cos(gt, gBv)
    d_t_b = rel_diff(gt, gBv)
    c_ep_t = cos(gEPv, gt)
    d_ep_t = rel_diff(gEPv, gt)
    c_ep_b = cos(gEPv, gBv)
    d_ep_b = rel_diff(gEPv, gBv)
    print(f"cos(g_transpose, g_BPTT)={c_t_b:+.6f} ||g_transpose-g_BPTT||/||g_BPTT||={d_t_b:.6e}", flush=True)
    print(f"cos(g_EP, g_transpose)={c_ep_t:+.6f} ||g_EP-g_transpose||/||g_transpose||={d_ep_t:.6e}", flush=True)
    print(f"cos(g_EP, g_BPTT)={c_ep_b:+.6f} ||g_EP-g_BPTT||/||g_BPTT||={d_ep_b:.6e}", flush=True)
    print("interpretation:", flush=True)
    print(" cos(g_transpose,g_BPTT)~1 AND cos(g_EP,g_transpose)~1 -> our EP IS the exact adjoint; failure is convergence/contraction", flush=True)
    print(" cos(g_transpose,g_BPTT)~1 AND cos(g_EP,g_transpose)<1 -> exact adjoint works, our EP falls short -> implement exact/dyadic", flush=True)
    print(" cos(g_transpose,g_BPTT)<1 -> even exact adjoint != BPTT -> finite-time/convergence, not the adjoint", flush=True)
+
+
def main():
    """Entry point: build the block, relax to z*, then run the probes.

    With ``--diag`` only the diagnostic suite runs. Otherwise one training
    batch is relaxed to a fixed point. With ``--skiprho`` (the default) only
    the exact-adjoint gradient comparison runs; without it the trace-based
    shift, the sensitivity probe and the rho / sigma estimates run first.
    """
    cfg = parse_args()
    cfg.ckpt = resolve_ckpt_path(cfg.ckpt)
    require_cuda_if_requested(cfg.device)
    dev = torch.device("cpu" if cfg.device != "cuda" else "cuda:0")
    use_tf32 = bool(cfg.tf32)
    torch.backends.cuda.matmul.allow_tf32 = use_tf32
    torch.backends.cudnn.allow_tf32 = use_tf32
    print(f"# asym_probe device={dev} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r}", flush=True)
    print(
        f"# ckpt={cfg.ckpt} B={cfg.B} T={cfg.T} C={cfg.C} H={cfg.H} Mm={cfg.Mm} "
        f"attn_mode=thick qknorm=True gelu={cfg.gelu}",
        flush=True,
    )
    blk, ck = build_block(cfg, dev)
    if cfg.diag:
        run_diagnostics(blk, cfg, ck)
        return

    idx, y = L.get_batch("train", cfg.B, cfg.T)
    xin0 = blk.embed(idx).detach()
    zstar, steps, step_res, force_res = relax_to_fixed_point(blk, xin0, cfg)
    print(f"# ckpt step {ck.get('step')} best {ck.get('best')}", flush=True)
    print(f"z* residual: step_res={step_res:.6e} force_res={force_res:.6e} relax_steps={steps}", flush=True)
    if step_res > cfg.res_est:
        print(f"WARNING: fixed-point target not reached: step_res={step_res:.3e} > {cfg.res_est:.3e}", flush=True)
    if step_res > 1e-3 or force_res > 1e-3:
        print(
            "WARNING: relaxed z* residual exceeds 1e-3; do not trust exact-adjoint solves until convergence improves",
            flush=True,
        )

    # Freeze parameters for state Jacobian products. tforce is out-of-place; each
    # VJP re-leafs z* to avoid stale graphs, and xin0 is held detached/fixed.
    set_param_requires_grad(blk, False)
    print(
        "autograd note: using blk.tforce directly; no in-place tforce ops patched; z* is re-leafed per VJP/JVP",
        flush=True,
    )

    unshifted = Operators(blk, zstar, xin0, cfg, mu=0.0)
    if cfg.skiprho:
        compare_exact_adjoint(blk, idx, y, zstar, xin0, unshifted, cfg)
        return

    tr_mean, tr_std = estimate_trace_s(unshifted, cfg.trace_probes)
    # A non-negative --mu wins; otherwise scale the estimated |tr(S)/n|.
    mu = float(cfg.mu) if cfg.mu >= 0 else cfg.mu_scale * max(abs(tr_mean), 1e-12)
    print(f"trace(S)/n estimate={tr_mean:+.6e} std={tr_std:.3e}", flush=True)
    print(f"mu used={mu:.6e} (mu_scale={cfg.mu_scale:g}, solve operator S+muI)", flush=True)

    op = Operators(blk, zstar, xin0, cfg, mu=mu)
    sensitivity_probe(op, mu)
    started = time.time()
    rho_power = power_rho(op, cfg)
    rho_arnoldi = arnoldi_rho(op, cfg.arnoldi_k)
    sigma = power_sigma(op, cfg)
    elapsed = time.time() - started
    rho = max(rho_power, rho_arnoldi if rho_arnoldi is not None else 0.0)
    arnoldi_shown = float("nan") if rho_arnoldi is None else rho_arnoldi
    print(op.solve_log.summary(), flush=True)
    print(
        "non-normal note: power iteration reports dominant growth; Rayleigh trend may be small/oscillatory for skew modes",
        flush=True,
    )
    print(f"rho(S^-1 A)={rho:.6e} power={rho_power:.6e} arnoldi={arnoldi_shown:.6e}", flush=True)
    print(f"||S^-1 A||_2={sigma:.6e}", flush=True)
    print(f"elapsed_operator_seconds={elapsed:.1f}", flush=True)
    contracting = rho < 1.0
    verdict = "higher-order AEP viable" if contracting else "higher-order AEP not viable"
    print(f"VERDICT: rho {'<' if contracting else '>='} 1 => {verdict}", flush=True)
    compare_exact_adjoint(blk, idx, y, zstar, xin0, op, cfg)
+
+
# Run the probe only when this file is executed as a script.
if __name__ == "__main__":
    main()