"""Matrix-free asymmetry probe for the equilibrium-transformer block Jacobian.

The state Jacobian J = dF/dz is never materialized. We estimate the growth of
T = (S + mu I)^-1 A, where S = (J + J^T)/2 and A = (J - J^T)/2, using autograd
JVP/VJP products at the relaxed fixed point.

By default (``--skiprho``) only the exact-adjoint gradient comparison
(implicit-function adjoint vs. BPTT vs. EP) is run; pass ``--no-skiprho`` to
also run the rho/sigma spectral probes, or ``--diag`` to run the multi-batch /
beta-sweep / ablation diagnostic suite and exit.
"""

import argparse
import glob
import inspect
import math
import os
import pickle
import time
import warnings

# Installed before the torch/scipy imports below, so UserWarnings emitted while
# importing them are silenced as well.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", message=".*cuBLAS.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*CUBLAS.*", category=UserWarning)

from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from scipy.sparse.linalg import LinearOperator, gmres, minres

import lt_ep_train as L
from lt_ep_train import EQBlock, bptt_step, ce, ep_step, relax


def parse_args():
    """Parse the command line for the probe (model shape, relaxation, solver and probe budgets)."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="runs/ep_clean.pt")
    ap.add_argument("--data", default="data/tinystories_bpe")
    ap.add_argument("--gelu", default="erf")
    ap.add_argument("--C", type=int, default=512)
    ap.add_argument("--H", type=int, default=16)
    ap.add_argument("--Mm", type=int, default=256)
    ap.add_argument("--T", type=int, default=256)
    ap.add_argument("--B", type=int, default=8)
    ap.add_argument("--T1", type=int, default=150)
    ap.add_argument("--T2", type=int, default=20)
    ap.add_argument("--eps", type=float, default=0.1)
    ap.add_argument("--beta", type=float, default=0.02)
    ap.add_argument("--t1max", type=int, default=2000)
    ap.add_argument("--relax-chunk", type=int, default=50)
    ap.add_argument("--res-est", type=float, default=1e-4)
    ap.add_argument("--t2sel", type=int, default=40)
    ap.add_argument("--holo", type=int, default=2)
    ap.add_argument("--hr", type=float, default=0.02)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--trace-probes", type=int, default=4)
    ap.add_argument("--mu-scale", type=float, default=1e-3)
    ap.add_argument("--mu", type=float, default=-1.0, help="override mu; negative means estimate from trace")
    ap.add_argument("--solve-iters", type=int, default=80)
    ap.add_argument("--solve-tol", type=float, default=1e-5)
    ap.add_argument("--adjoint-iters", type=int, default=200)
    ap.add_argument("--adjoint-tol", type=float, default=1e-5)
    ap.add_argument(
        "--adjoint-mu", type=float, default=1e-4, help="fallback Tikhonov mu for J^T+muI if GMRES stalls"
    )
    ap.add_argument("--rho-iters", type=int, default=20)
    ap.add_argument("--rho-restarts", type=int, default=3)
    ap.add_argument("--sigma-iters", type=int, default=8)
    ap.add_argument("--sigma-restarts", type=int, default=2)
    ap.add_argument("--arnoldi-k", type=int, default=12)
    ap.add_argument(
        "--skiprho",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="skip rho/sigma spectral probes and run only exact-adjoint gradient comparison",
    )
    ap.add_argument("--diag", action="store_true", help="run EP/exact-adjoint diagnostic suite and exit")
    ap.add_argument("--noplot", action="store_true", help=argparse.SUPPRESS)
    ap.add_argument("--lr", type=float, default=None, help=argparse.SUPPRESS)
    ap.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    ap.add_argument("--tf32", action="store_true")
    return ap.parse_args()


def resolve_ckpt_path(path):
    """Resolve a checkpoint path.

    Absolute paths are returned unchanged; a relative path is tried against the
    current working directory first and otherwise resolved next to this script.
    """
    p = Path(path)
    if p.is_absolute():
        return str(p)
    cwd_path = Path.cwd() / p
    if cwd_path.exists():
        return str(cwd_path)
    return str(Path(__file__).resolve().parent / p)


def require_cuda_if_requested(device):
    """Select GPU0 when ``device == "cuda"``; otherwise print a CUDA diagnostic and exit with status 2."""
    if device != "cuda":
        return
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    ok = torch.cuda.is_available() and torch.cuda.device_count() > 0
    if ok:
        torch.cuda.set_device(0)
        return
    print("ERROR: CUDA unavailable; requested GPU0 run cannot start.", flush=True)
    print(f"CUDA_VISIBLE_DEVICES={visible!r}", flush=True)
    print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda}", flush=True)
    print(
        f"torch.cuda.is_available()={torch.cuda.is_available()} device_count={torch.cuda.device_count()}",
        flush=True,
    )
    nodes = glob.glob("/dev/nvidia*")
    required = ["/dev/nvidiactl", "/dev/nvidia-uvm", "/dev/nvidia0"]
    missing = [p for p in required if not os.path.exists(p)]
    print(f"/dev/nvidia*={' '.join(nodes) if nodes else 'MISSING'}", flush=True)
    print(f"missing CUDA device nodes={' '.join(missing) if missing else 'none'}", flush=True)
    raise SystemExit(2)


def build_block(cfg, dev):
    """Construct the EQBlock, configure its runtime flags, and load ``ck["allp"]`` into ``blk.allp``.

    Also sets the trainer module globals ``L.dev``, ``L.DD`` and ``L.vocab``
    (the latter read from ``<data>/meta.pkl``).

    Returns:
        (blk, ck): the block and the raw checkpoint dict.
    """
    # Same construction and checkpoint-copy path as resreg_probe.py.
    L.dev = dev
    L.DD = Path(cfg.data)
    with open(L.DD / "meta.pkl", "rb") as fh:
        L.vocab = pickle.load(fh)["vocab_size"]
    torch.manual_seed(cfg.seed)
    blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=1.0, attn_mode="thick")
    blk.qknorm = True
    blk.fnoise = 0.0
    blk._cstep = None
    blk.navg = 1
    blk.li_avg = 0
    blk.track = True
    blk.nbrake = 0.0
    blk.gelu = cfg.gelu
    ck = torch.load(cfg.ckpt, map_location=dev)
    params = list(blk.allp)
    saved = list(ck["allp"])
    # zip() truncates silently; make a count mismatch visible instead of leaving
    # trailing parameters at their random initialization without notice.
    if len(saved) != len(params):
        print(
            f"WARNING: checkpoint allp has {len(saved)} tensors but block has {len(params)}; "
            f"only the first {min(len(saved), len(params))} are copied",
            flush=True,
        )
    with torch.no_grad():
        for i, (p, w) in enumerate(zip(params, saved)):
            if tuple(p.shape) != tuple(w.shape):
                print(
                    f"WARNING: checkpoint tensor {i} shape {tuple(w.shape)} != parameter shape {tuple(p.shape)}",
                    flush=True,
                )
            p.copy_(w.to(dev))
    return blk, ck


@torch.no_grad()
def residuals(blk, z, xin, eps):
    """Return (||relax(z, 1 step) - z|| / ||z||, ||tforce(z)|| / ||z||) for state ``z``."""
    z1 = relax(blk, z, xin, 1, eps)
    zn = z.norm().item() + 1e-12
    step_rel = (z1 - z).norm().item() / zn
    force_rel = blk.tforce(z, xin).norm().item() / zn
    return step_rel, force_rel


def relax_to_fixed_point(blk, xin, cfg):
    """Relax from ``xin`` for ``cfg.T1`` steps, then in ``cfg.relax_chunk`` chunks until convergence.

    Continues while the relative step residual exceeds ``cfg.res_est`` and the
    total step count is below ``cfg.t1max``.

    Returns:
        (z_star detached, total steps, relative step residual, relative force residual).
    """
    z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
    step_rel, force_rel = residuals(blk, z, xin, cfg.eps)
    steps = cfg.T1
    while steps < cfg.t1max and step_rel > cfg.res_est:
        chunk = min(cfg.relax_chunk, cfg.t1max - steps)
        z = relax(blk, z, xin, chunk, cfg.eps)
        steps += chunk
        step_rel, force_rel = residuals(blk, z, xin, cfg.eps)
    print(f"relax steps={steps:4d} step_res={step_rel:.3e} force_res={force_rel:.3e}", flush=True)
    return z.detach(), steps, step_rel, force_rel


def dot(a, b):
    """Euclidean inner product of two tensors viewed as flat vectors."""
    return torch.dot(a.reshape(-1), b.reshape(-1))


def norm(a):
    """Euclidean norm of a tensor viewed as a flat vector."""
    return torch.linalg.vector_norm(a.reshape(-1))


def block_param_list(blk):
    """Return the parameters compared in the gradient checks (``blk.block``; readout is excluded)."""
    if hasattr(blk.block, "parameters"):
        return list(blk.block.parameters())
    return list(blk.block)


def flat_grad_by_param_id(grads, params):
    """Flatten gradients keyed by ``id(param)`` into one float64 CPU vector in ``params`` order.

    Missing entries (or ``grads is None``) contribute zeros of the parameter's size.
    """
    flat = []
    for p in params:
        g = grads.get(id(p)) if grads is not None else None
        if g is None:
            g = torch.zeros_like(p, device="cpu", dtype=torch.float64)
        else:
            g = g.detach().to(device="cpu", dtype=torch.float64)
        flat.append(g.reshape(-1))
    return torch.cat(flat)


def set_param_requires_grad(blk, value):
    """Set ``requires_grad`` on every tensor in ``blk.allp``."""
    for p in blk.allp:
        p.requires_grad_(value)


def cos(a, b):
    """Cosine similarity of two vectors (guarded against zero norms)."""
    return (torch.dot(a, b) / (norm(a) * norm(b) + 1e-20)).item()


def rel_diff(a, b):
    """Relative difference ||a - b|| / ||b||."""
    return (norm(a - b) / (norm(b) + 1e-20)).item()


def unit_rand(shape, dev, dtype):
    """Gaussian random tensor of ``shape`` scaled to unit Euclidean norm."""
    v = torch.randn(shape, device=dev, dtype=dtype)
    return v / (norm(v) + 1e-30)


def _tolerance_keyword(solver):
    """Return the relative-tolerance keyword a SciPy Krylov solver accepts.

    SciPy >= 1.12 accepts ``rtol``; older releases only know ``tol``. Detecting
    this up front (instead of catching TypeError around the call) keeps genuine
    TypeErrors raised from inside the user matvec from being swallowed.
    """
    try:
        params = inspect.signature(solver).parameters
    except (TypeError, ValueError):
        return "rtol"
    return "rtol" if "rtol" in params else "tol"


@dataclass
class SolveLog:
    """Accumulates per-solve relative residuals, SciPy ``info`` codes and iteration counts."""

    residuals: list = field(default_factory=list)
    infos: list = field(default_factory=list)
    iters: list = field(default_factory=list)

    def add(self, rel_res, info, nit):
        self.residuals.append(float(rel_res))
        self.infos.append(int(info))
        self.iters.append(int(nit))

    def summary(self):
        """One-line summary: residual min/median/max, iteration median/max, nonzero ``info`` count."""
        if not self.residuals:
            return "solve residuals: none"
        r = np.asarray(self.residuals, dtype=np.float64)
        it = np.asarray(self.iters, dtype=np.int64)
        bad = sum(1 for x in self.infos if x != 0)
        return (
            f"solve residuals: count={len(r)} min={r.min():.3e} "
            f"median={np.median(r):.3e} max={r.max():.3e} "
            f"iters median={np.median(it):.0f} max={it.max()} nonzero_info={bad}"
        )


class Operators:
    """Matrix-free products with J = d tforce / dz at a fixed point z*, and derived operators.

    ``jv``/``jtv`` give J v and J^T v via autograd; ``sv``/``av`` give the
    symmetric and skew parts; ``solve_s`` solves (S + mu I) x = b with MINRES and
    ``solve_jt_gmres`` solves (J^T + mu I) x = b with GMRES. SciPy works in
    float32 numpy; vectors are converted at that boundary.
    """

    def __init__(self, blk, zstar, xin, cfg, mu):
        self.blk = blk
        self.zstar = zstar.detach()
        self.xin = xin.detach()
        self.shape = tuple(zstar.shape)
        self.n = zstar.numel()
        self.dev = zstar.device
        self.dtype = zstar.dtype
        self.cfg = cfg
        self.mu = float(mu)
        self.solve_log = SolveLog()

    def f(self, z):
        """State force F(z) with the clamp input held fixed."""
        return self.blk.tforce(z, self.xin)

    def jv(self, v):
        """Return J v at z* (autograd.functional.jvp)."""
        with torch.enable_grad():
            _, out = torch.autograd.functional.jvp(
                self.f, self.zstar, v.contiguous(), create_graph=False, strict=False
            )
        return out.detach()

    def jtv(self, v):
        """Return J^T v at z* (one VJP on a freshly re-leafed copy of z*)."""
        with torch.enable_grad():
            z = self.zstar.detach().requires_grad_(True)
            fz = self.f(z)
            (g,) = torch.autograd.grad(fz, z, grad_outputs=v.contiguous(), create_graph=False, retain_graph=False)
        return g.detach()

    def sv(self, v):
        """Return S v with S = (J + J^T) / 2."""
        jv = self.jv(v)
        jtv = self.jtv(v)
        return 0.5 * (jv + jtv)

    def av(self, v):
        """Return A v with A = (J - J^T) / 2."""
        jv = self.jv(v)
        jtv = self.jtv(v)
        return 0.5 * (jv - jtv)

    def smu(self, v, mu=None):
        """Return (S + mu I) v; ``mu`` defaults to ``self.mu``."""
        m = self.mu if mu is None else float(mu)
        return self.sv(v) + m * v

    def _from_numpy(self, x):
        x = np.asarray(x, dtype=np.float32)
        return torch.from_numpy(x).to(device=self.dev, dtype=self.dtype).view(self.shape)

    def _to_numpy(self, x):
        return x.detach().reshape(-1).float().cpu().numpy()

    def solve_s(self, rhs, mu=None, tag=""):
        """Solve (S + mu I) x = rhs with MINRES; log and return (x, rel_res, info, iterations).

        The relative residual is recomputed in torch after the solve rather than
        taken from SciPy.
        """
        m = self.mu if mu is None else float(mu)
        b = self._to_numpy(rhs)
        counter = {"n": 0}

        def matvec(x_np):
            x = self._from_numpy(x_np)
            y = self.smu(x, m)
            return self._to_numpy(y)

        def cb(_x):
            counter["n"] += 1

        Aop = LinearOperator((self.n, self.n), matvec=matvec, dtype=np.dtype("float32"))
        tol_kw = _tolerance_keyword(minres)
        x_np, info = minres(
            Aop,
            b,
            maxiter=self.cfg.solve_iters,
            callback=cb,
            check=False,
            **{tol_kw: self.cfg.solve_tol},
        )
        x = self._from_numpy(x_np).detach()
        rel = (norm(self.smu(x, m) - rhs) / (norm(rhs) + 1e-30)).item()
        self.solve_log.add(rel, info, counter["n"])
        if tag:
            print(f"solve {tag}: mu={m:.3e} rel_res={rel:.3e} iters={counter['n']} info={info}", flush=True)
        return x, rel, info, counter["n"]

    def solve_jt_gmres(self, rhs, tol, maxiter, mu=0.0, tag="adjoint"):
        """Solve (J^T + mu I) x = rhs with restarted GMRES; return (x, rel_res, info, iterations).

        Restart length is ``min(50, maxiter)``. On SciPy with ``rtol`` the legacy
        callback type is requested; on older SciPy the defaults are used.
        """
        m = float(mu)
        b = self._to_numpy(rhs)
        counter = {"n": 0}
        restart = max(1, min(50, int(maxiter)))

        def matvec(x_np):
            x = self._from_numpy(x_np)
            y = self.jtv(x)
            if m != 0.0:
                y = y + m * x
            return self._to_numpy(y)

        def cb(_arg):
            counter["n"] += 1

        Aop = LinearOperator((self.n, self.n), matvec=matvec, dtype=np.dtype("float32"))
        if _tolerance_keyword(gmres) == "rtol":
            x_np, info = gmres(
                Aop,
                b,
                rtol=tol,
                atol=0.0,
                restart=restart,
                maxiter=int(maxiter),
                callback=cb,
                callback_type="legacy",
            )
        else:
            x_np, info = gmres(Aop, b, tol=tol, restart=restart, maxiter=int(maxiter), callback=cb)
        x = self._from_numpy(x_np).detach()
        rel = (norm(self.jtv(x) + m * x - rhs) / (norm(rhs) + 1e-30)).item()
        print(f"GMRES {tag}: mu={m:.3e} rel_res={rel:.3e} iters={counter['n']} info={info}", flush=True)
        return x, rel, info, counter["n"]

    def t(self, v):
        """Return T v = (S + mu I)^-1 A v."""
        rhs = self.av(v)
        x, _, _, _ = self.solve_s(rhs)
        return x

    def tt(self, u):
        """Return T^T u = -A (S + mu I)^-1 u (S + mu I is symmetric, A is skew)."""
        y, _, _, _ = self.solve_s(u)
        return -self.av(y)


def estimate_trace_s(op, probes):
    """Hutchinson estimate of tr(S)/n with ``probes`` Rademacher vectors; returns (mean, std)."""
    vals = []
    for i in range(probes):
        r = torch.randint(0, 2, op.shape, device=op.dev, dtype=torch.int8).to(op.dtype)
        r = r.mul_(2).sub_(1)
        sr = op.sv(r)
        vals.append((dot(r, sr) / op.n).item())
        print(f"trace probe {i}: tr(S)/n={vals[-1]:+.6e}", flush=True)
    return float(np.mean(vals)), float(np.std(vals) if len(vals) > 1 else 0.0)


def sensitivity_probe(op, mu):
    """Re-solve one (S + mu I) x = A v system at mu*0.1 and mu*10 and report the relative change in x.

    Returns:
        (base relative residual, [(scale, mu_scaled, rel_dx, rel_res), ...]).
    """
    v = unit_rand(op.shape, op.dev, op.dtype)
    rhs = op.av(v)
    xb, rb, _, _ = op.solve_s(rhs, mu=mu, tag="sensitivity/base")
    rows = []
    for scale in (0.1, 10.0):
        ms = max(mu * scale, 0.0)
        xa, ra, _, _ = op.solve_s(rhs, mu=ms, tag=f"sensitivity/mu_x{scale:g}")
        rel_dx = (norm(xa - xb) / (norm(xb) + 1e-30)).item()
        rows.append((scale, ms, rel_dx, ra))
    print(
        "solve sensitivity: "
        + " ".join(f"mu_x{scale:g}: rel_dx={dx:.3e} rel_res={rr:.3e}" for scale, _, dx, rr in rows),
        flush=True,
    )
    return rb, rows


def power_rho(op, cfg):
    """Power iteration on T; return the largest final growth ||T v|| over restarts."""
    best = 0.0
    best_hist = None
    for r in range(cfg.rho_restarts):
        v = unit_rand(op.shape, op.dev, op.dtype)
        hist = []
        for i in range(cfg.rho_iters):
            w = op.t(v)
            growth = norm(w).item()
            rq = (dot(v, w) / (dot(v, v) + 1e-30)).item()
            hist.append((growth, rq))
            if growth <= 1e-30 or not math.isfinite(growth):
                break
            v = (w / growth).detach()
            print(f"rho restart={r} iter={i + 1:02d} growth={growth:.6e} rayleigh={rq:+.6e}", flush=True)
        if hist and hist[-1][0] > best:
            best = hist[-1][0]
            best_hist = hist
    if best_hist:
        tail = best_hist[-6:]
        trend = " ".join(f"{g:.3g}" for g, _ in tail)
        rtrend = " ".join(f"{rq:+.3g}" for _, rq in tail)
        print(f"rho power trend last={trend}", flush=True)
        print(f"rho Rayleigh trend last={rtrend}", flush=True)
    return best


def arnoldi_rho(op, k):
    """k-step Arnoldi on T; return max |Ritz value| (None if k <= 0, NaN if H is non-finite)."""
    if k <= 0:
        return None
    q = unit_rand(op.shape, op.dev, op.dtype)
    Q = [q]
    H = np.zeros((k + 1, k), dtype=np.float64)
    m = 0
    for j in range(k):
        w = op.t(Q[j])
        for i in range(j + 1):
            hij = dot(Q[i], w).item()
            H[i, j] = hij
            w = w - hij * Q[i]
        hnext = norm(w).item()
        H[j + 1, j] = hnext
        m = j + 1
        print(f"arnoldi iter={j + 1:02d} h_next={hnext:.6e}", flush=True)
        # Stop on an (almost) invariant subspace, and on a non-finite norm so
        # NaN/Inf never seeds the next basis vector.
        if not math.isfinite(hnext) or hnext < 1e-12:
            break
        if j + 1 < k:
            Q.append((w / hnext).detach())
    Hm = H[:m, :m]
    if Hm.size and np.all(np.isfinite(Hm)):
        rho = float(np.max(np.abs(np.linalg.eigvals(Hm))))
    else:
        # np.linalg.eigvals raises on NaN/Inf input; report NaN instead.
        rho = float("nan")
    print(f"rho Arnoldi(k={m})={rho:.6e}", flush=True)
    return rho


def power_sigma(op, cfg):
    """Power iteration on T^T T; return the largest observed ||T v|| (a lower bound on ||T||_2)."""
    best = 0.0
    for r in range(cfg.sigma_restarts):
        v = unit_rand(op.shape, op.dev, op.dtype)
        sigma = 0.0
        for i in range(cfg.sigma_iters):
            u = op.t(v)
            sigma = norm(u).item()
            w = op.tt(u)
            wn = norm(w).item()
            if wn <= 1e-30 or not math.isfinite(wn):
                break
            v = (w / wn).detach()
            print(f"sigma restart={r} iter={i + 1:02d} sigma={sigma:.6e}", flush=True)
        best = max(best, sigma)
    return best


def ce_state_grad(blk, zstar, y):
    """Return (d ce / dz at z*, ce value) for targets ``y``."""
    with torch.enable_grad():
        z = zstar.detach().requires_grad_(True)
        loss = ce(blk, z, y)
        (ell,) = torch.autograd.grad(loss, z)
    return ell.detach(), float(loss.detach())


def solve_exact_adjoint(op, ell, cfg):
    """Solve J^T lambda = -ell; if GMRES stalls, retry once with J^T + mu I (mu = cfg.adjoint_mu).

    Returns:
        (lambda, relative residual, info, iterations, Tikhonov mu used; 0.0 if none).
    """
    rhs = -ell.detach()
    lam, rel, info, nit = op.solve_jt_gmres(rhs, cfg.adjoint_tol, cfg.adjoint_iters, mu=0.0, tag="J^T lambda=-ell")
    stalled = (info != 0) or (not math.isfinite(rel)) or (rel > max(10.0 * cfg.adjoint_tol, 1e-4))
    mu_used = 0.0
    if stalled:
        mu_used = max(float(cfg.adjoint_mu), 1e-8)
        print(f"GMRES stalled; retrying exact-adjoint solve with Tikhonov J^T+muI, mu={mu_used:.3e}", flush=True)
        lam, rel, info, nit = op.solve_jt_gmres(
            rhs, cfg.adjoint_tol, cfg.adjoint_iters, mu=mu_used, tag="(J^T+muI) lambda=-ell"
        )
    return lam.detach(), rel, info, nit, mu_used


def exact_transpose_grad(blk, idx, zstar, xin0, lam, params):
    """Return {id(p): d(lambda . F(z*, xin)) / dp} for ``params`` (the implicit-adjoint gradient).

    Leaves every tensor in ``blk.allp`` with ``requires_grad=True``.
    """
    set_param_requires_grad(blk, True)
    with torch.enable_grad():
        # Value stays at the relaxed clamp xin0, while tok/pos receive the same
        # clamp-gradient path as the trainer. A single embed() call makes the
        # straight-through offset exactly zero in value.
        emb = blk.embed(idx)
        xin = xin0 + (emb - emb.detach())
        force = blk.tforce(zstar.detach(), xin)
        grads = torch.autograd.grad((force * lam.detach()).sum(), params, allow_unused=True)
    return {id(p): g for p, g in zip(params, grads)}


# Sentinel distinguishing "attribute absent" from "attribute is None".
_UNSET = object()


def run_ep_step_flat(blk, idx, y, cfg, params, *, beta=None, holo=None, hr=None, t2sel=None, track=None, T2=None):
    """Run ``ep_step`` with cfg defaults (optionally overridden) and return (flat grad, ep residual).

    A ``track`` override is applied to ``blk.track`` for the call and always
    restored afterwards (including removing it again if it did not exist).
    """
    saved_track = getattr(blk, "track", _UNSET)
    if track is not None:
        blk.track = bool(track)
    try:
        set_param_requires_grad(blk, True)
        # Mirrors lt_ep_train.ep_step:
        # (blk, idx, y, T1, T2, eps, beta, jacreg, holo, hr, t1max, res_est, t2sel, corr_every, res_gate, resreg).
        grads, ep_res = ep_step(
            blk,
            idx,
            y,
            cfg.T1,
            cfg.T2 if T2 is None else int(T2),
            cfg.eps,
            cfg.beta if beta is None else float(beta),
            0.0,
            cfg.holo if holo is None else int(holo),
            cfg.hr if hr is None else float(hr),
            cfg.t1max,
            cfg.res_est,
            cfg.t2sel if t2sel is None else int(t2sel),
            1,
            0.0,
            0.0,
        )
        return flat_grad_by_param_id(grads, params), float(ep_res)
    finally:
        if track is not None:
            if saved_track is _UNSET:
                del blk.track
            else:
                blk.track = saved_track


@torch.no_grad()
def fixed_point_step_abs(blk, zstar, xin, eps):
    """Absolute one-step relaxation residual ||relax(z*, 1) - z*||."""
    return (relax(blk, zstar, xin, 1, eps) - zstar).norm().item()


def exact_reference_for_batch(blk, idx, y, cfg, label, compute_bptt=True):
    """Relax one batch, solve the exact adjoint, and return reference gradients plus diagnostics.

    The returned dict holds the batch, ``params``, the flat transpose gradient
    ``gt``, residual/solver statistics and, if ``compute_bptt``, the flat BPTT
    gradient ``gBv``. Leaves ``blk.allp`` with ``requires_grad=True``.
    """
    print(f"--- exact reference: {label} ---", flush=True)
    xin0 = blk.embed(idx).detach()
    zstar, steps, step_res, force_res = relax_to_fixed_point(blk, xin0, cfg)
    step_abs = fixed_point_step_abs(blk, zstar, xin0, cfg.eps)
    print(
        f"{label}: z* residual step_abs={step_abs:.6e} step_rel={step_res:.6e} "
        f"force_rel={force_res:.6e} relax_steps={steps}",
        flush=True,
    )
    if step_res > cfg.res_est:
        print(f"{label}: WARNING step_res={step_res:.3e} > res_est={cfg.res_est:.3e}", flush=True)
    # Parameters are frozen while taking state-Jacobian products.
    set_param_requires_grad(blk, False)
    op = Operators(blk, zstar, xin0, cfg, mu=0.0)
    ell, ce_loss = ce_state_grad(blk, zstar, y)
    print(f"{label}: CE(z*)={ce_loss:.6f} ||ell||={norm(ell).item():.6e}", flush=True)
    lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg)
    print(
        f"{label}: adjoint residual={gmres_rel:.3e} iters={gmres_iters} info={gmres_info} "
        f"tikhonov_mu={adj_mu:.3e}",
        flush=True,
    )
    params = block_param_list(blk)
    gt = flat_grad_by_param_id(exact_transpose_grad(blk, idx, zstar, xin0, lam, params), params)
    out = {
        "idx": idx,
        "y": y,
        "params": params,
        "gt": gt,
        "z_step_abs": step_abs,
        "z_step_rel": step_res,
        "z_force_rel": force_res,
        "relax_steps": steps,
        "gmres_rel": gmres_rel,
        "gmres_info": gmres_info,
        "gmres_iters": gmres_iters,
        "adj_mu": adj_mu,
        "ce_loss": ce_loss,
    }
    if compute_bptt:
        set_param_requires_grad(blk, True)
        gB = bptt_step(blk, idx, y, cfg.T1, cfg.eps, 0.0)
        out["gBv"] = flat_grad_by_param_id(gB, params)
    set_param_requires_grad(blk, True)
    return out


def draw_seeded_train_batch(cfg, seed):
    """Seed torch with ``seed`` and draw a (B, T) training batch via the trainer's ``get_batch``."""
    torch.manual_seed(int(seed))
    return L.get_batch("train", cfg.B, cfg.T)


def finite_range(vals):
    """Return (mean, min, max) of the finite, non-None values, or None if there are none."""
    ok = [float(v) for v in vals if v is not None and math.isfinite(float(v))]
    if not ok:
        return None
    return float(np.mean(ok)), float(np.min(ok)), float(np.max(ok))


def read_multi_batch(rows):
    """One-line verdict on cos(g_EP, g_transpose) across the successful multi-batch rows."""
    vals = [r.get("cos_ep_t") for r in rows if r.get("ok")]
    stats = finite_range(vals)
    if stats is None:
        return "no successful batches"
    _, mn, mx = stats
    spread = mx - mn
    if mn > 0.95:
        return "consistently aligned across batches"
    if mx < 0.80:
        return "systematically low across batches"
    if spread > 0.20:
        return "batch-variance/outlier behavior is material"
    return "mostly systematic with moderate batch variance"


def read_beta_sweep(rows):
    """One-line verdict on how cos(g_EP, g_transpose) changes along the beta sweep (rows in sweep order)."""
    ok = [(r["beta"], r["cos"]) for r in rows if r.get("ok")]
    if not ok:
        return "no successful beta points"
    first_beta, first_cos = ok[0]
    last_beta, last_cos = ok[-1]
    best_cos = max(c for _, c in ok)
    if last_cos > 0.95 and last_cos - first_cos > 0.10:
        return f"finite-beta bias likely: cos improves from beta={first_beta:g} to beta={last_beta:g}"
    if best_cos < 0.80:
        return "cos stays low as beta shrinks: structural/bug more likely than finite-beta bias"
    if last_cos > first_cos + 0.05:
        return "some finite-beta sensitivity, but not a clean convergence-to-1 result"
    return "no strong beta-to-zero improvement"


def read_ablation(rows):
    """One-line verdict naming the ablation (if any) that beats FULL by more than 0.05 in cosine."""
    ok = [r for r in rows if r.get("ok")]
    if not ok:
        return "no successful ablations"
    full = next((r for r in ok if r["key"] == "full"), None)
    best = max(ok, key=lambda r: r["cos"])
    if full is None:
        return f"best successful config is {best['label']}"
    delta = best["cos"] - full["cos"]
    if delta <= 0.05:
        return "no ablation materially improves over FULL"
    if best["key"] == "track_off":
        return "tracking path is suspect: disabling blk.track improved cos"
    if best["key"] == "plain":
        return "holomorphic/adaptive path is suspect: plain real EP improved cos"
    if best["key"] == "fixed_t2":
        return "adaptive-T2 selection/tracking is suspect: fixed T2 improved cos"
    return f"{best['label']} is the strongest improvement over FULL"


def print_diagnostic_summary(multi_rows, beta_rows, ablation_rows):
    """Print the multi-batch, beta-sweep and ablation tables with their one-line reads."""
    print("", flush=True)
    print("================ DIAGNOSTIC SUMMARY ================", flush=True)
    multi_stats_t = finite_range([r.get("cos_ep_t") for r in multi_rows if r.get("ok")])
    multi_stats_b = finite_range([r.get("cos_ep_b") for r in multi_rows if r.get("ok")])
    if multi_stats_t is None:
        print("Multi-batch: no successful batches", flush=True)
    else:
        mean, mn, mx = multi_stats_t
        print(f"Multi-batch: mean cos(g_EP,g_transpose)={mean:+.6f} range=[{mn:+.6f}, {mx:+.6f}]", flush=True)
    if multi_stats_b is not None:
        mean, mn, mx = multi_stats_b
        print(f"Multi-batch: mean cos(g_EP,g_BPTT)={mean:+.6f} range=[{mn:+.6f}, {mx:+.6f}]", flush=True)
    print(f"Multi-batch read: {read_multi_batch(multi_rows)}", flush=True)

    print("Beta sweep (beta | cos(g_EP,g_transpose)):", flush=True)
    if beta_rows:
        for row in beta_rows:
            if row.get("ok"):
                print(f" {row['beta']:<8g} | {row['cos']:+.6f}", flush=True)
            else:
                print(f" {row.get('beta', 'n/a')!s:<8} | failed: {row.get('error')}", flush=True)
    else:
        print(" none", flush=True)
    print(f"Beta sweep read: {read_beta_sweep(beta_rows)}", flush=True)

    print("Ablation (config | cos(g_EP,g_transpose)):", flush=True)
    if ablation_rows:
        for row in ablation_rows:
            if row.get("ok"):
                print(f" {row['label']} | {row['cos']:+.6f}", flush=True)
            else:
                print(f" {row.get('label', 'unknown')} | failed: {row.get('error')}", flush=True)
    else:
        print(" none", flush=True)
    print(f"Ablation read: {read_ablation(ablation_rows)}", flush=True)
    print("============== END DIAGNOSTIC SUMMARY ==============", flush=True)


def run_diagnostics(blk, cfg, ck):
    """Run the three-part diagnostic suite against the exact-adjoint reference, then print a summary.

    1. EP vs. transpose/BPTT on six seeded batches (seeds 1000..1005).
    2. Beta sweep (hr tied to beta) on the seed-1000 batch.
    3. Component ablation (tracking / holomorphic path / adaptive T2) on the seed-1000 batch.

    Each row is isolated in its own try/except so one failure does not abort the suite.
    """
    print("=== DIAGNOSTIC MODE ===", flush=True)
    print(f"# ckpt step {ck.get('step')} best {ck.get('best')}", flush=True)
    print(
        "ep_step paths: holo=2,t2sel>0,track=True -> holo_a_track; "
        "holo=2,t2sel>0,track=False -> holo_a_select2; holo>0,t2sel=0 -> holo_a; holo=0 -> plain EP",
        flush=True,
    )
    print("gradient comparison scope: blk.block parameters; readout Wh is excluded", flush=True)
    multi_rows = []
    beta_rows = []
    ablation_rows = []
    seed1000_ref = None

    print("=== DIAGNOSTIC 1: MULTI-BATCH ===", flush=True)
    for i in range(6):
        seed = 1000 + i
        label = f"diag1 batch={i} seed={seed}"
        try:
            idx, y = draw_seeded_train_batch(cfg, seed)
            ref = exact_reference_for_batch(blk, idx, y, cfg, label, compute_bptt=True)
            torch.manual_seed(seed)
            gEPv, ep_res = run_ep_step_flat(blk, idx, y, cfg, ref["params"])
            row = {
                "ok": True,
                "batch": i,
                "seed": seed,
                "cos_ep_t": cos(gEPv, ref["gt"]),
                "cos_ep_b": cos(gEPv, ref["gBv"]),
                "cos_t_b": cos(ref["gt"], ref["gBv"]),
                "z_step_abs": ref["z_step_abs"],
                "z_step_rel": ref["z_step_rel"],
                "z_force_rel": ref["z_force_rel"],
                "ep_res": ep_res,
            }
            multi_rows.append(row)
            print(
                f"{label}: cos(g_EP,g_transpose)={row['cos_ep_t']:+.6f} "
                f"cos(g_EP,g_BPTT)={row['cos_ep_b']:+.6f} "
                f"cos(g_transpose,g_BPTT)={row['cos_t_b']:+.6f} "
                f"z_res_abs={row['z_step_abs']:.6e} z_res_rel={row['z_step_rel']:.6e} ep_res={ep_res:.6e}",
                flush=True,
            )
            if seed == 1000:
                seed1000_ref = ref
        except Exception as err:
            row = {"ok": False, "batch": i, "seed": seed, "error": repr(err)}
            multi_rows.append(row)
            print(f"{label} failed: {err!r}", flush=True)

    multi_stats_t = finite_range([r.get("cos_ep_t") for r in multi_rows if r.get("ok")])
    multi_stats_b = finite_range([r.get("cos_ep_b") for r in multi_rows if r.get("ok")])
    if multi_stats_t is not None and multi_stats_b is not None:
        mt, mint, maxt = multi_stats_t
        mb, minb, maxb = multi_stats_b
        print(
            f"DIAG1 aggregate: cos(g_EP,g_transpose) mean={mt:+.6f} min={mint:+.6f} max={maxt:+.6f}; "
            f"cos(g_EP,g_BPTT) mean={mb:+.6f} min={minb:+.6f} max={maxb:+.6f}",
            flush=True,
        )

    if seed1000_ref is None:
        try:
            idx, y = draw_seeded_train_batch(cfg, 1000)
            seed1000_ref = exact_reference_for_batch(blk, idx, y, cfg, "diag seed=1000 fallback", compute_bptt=True)
        except Exception as err:
            print(f"seed=1000 reference failed; beta sweep and ablation cannot run: {err!r}", flush=True)

    print("=== DIAGNOSTIC 2: BETA SWEEP ===", flush=True)
    if seed1000_ref is not None:
        for beta in [0.04, 0.02, 0.01, 0.005, 0.002]:
            try:
                torch.manual_seed(1000)
                gEPv, ep_res = run_ep_step_flat(
                    blk,
                    seed1000_ref["idx"],
                    seed1000_ref["y"],
                    cfg,
                    seed1000_ref["params"],
                    beta=beta,
                    hr=beta,
                )
                row = {"ok": True, "beta": beta, "cos": cos(gEPv, seed1000_ref["gt"]), "ep_res": ep_res}
                beta_rows.append(row)
                print(
                    f"beta={beta:g} hr={beta:g}: cos(g_EP,g_transpose)={row['cos']:+.6f} ep_res={ep_res:.6e}",
                    flush=True,
                )
            except Exception as err:
                beta_rows.append({"ok": False, "beta": beta, "error": repr(err)})
                print(f"beta={beta:g} failed: {err!r}", flush=True)
    else:
        print("DIAG2 skipped: seed=1000 reference unavailable", flush=True)

    print("=== DIAGNOSTIC 3: COMPONENT ABLATION ===", flush=True)
    if seed1000_ref is not None:
        ablations = [
            {
                "key": "full",
                "label": "FULL holo=2 track=True t2sel=40",
                "kwargs": {"holo": 2, "track": True, "t2sel": 40, "hr": cfg.hr, "beta": cfg.beta},
            },
            {
                "key": "track_off",
                "label": "holo=2 track=False t2sel=40",
                "kwargs": {"holo": 2, "track": False, "t2sel": 40, "hr": cfg.hr, "beta": cfg.beta},
            },
            {
                "key": "plain",
                "label": "plain EP holo=0 track ignored t2sel=0",
                "kwargs": {
                    "holo": 0,
                    "track": getattr(blk, "track", False),
                    "t2sel": 0,
                    "hr": cfg.hr,
                    "beta": cfg.beta,
                },
            },
            {
                "key": "fixed_t2",
                "label": f"holo=2 track=True t2sel=0 fixed T2={cfg.T2}",
                "kwargs": {"holo": 2, "track": True, "t2sel": 0, "hr": cfg.hr, "beta": cfg.beta, "T2": cfg.T2},
            },
        ]
        for item in ablations:
            try:
                torch.manual_seed(1000)
                gEPv, ep_res = run_ep_step_flat(
                    blk,
                    seed1000_ref["idx"],
                    seed1000_ref["y"],
                    cfg,
                    seed1000_ref["params"],
                    **item["kwargs"],
                )
                row = {
                    "ok": True,
                    "key": item["key"],
                    "label": item["label"],
                    "cos": cos(gEPv, seed1000_ref["gt"]),
                    "ep_res": ep_res,
                }
                ablation_rows.append(row)
                print(f"{item['label']}: cos(g_EP,g_transpose)={row['cos']:+.6f} ep_res={ep_res:.6e}", flush=True)
            except Exception as err:
                row = {"ok": False, "key": item["key"], "label": item["label"], "error": repr(err)}
                ablation_rows.append(row)
                print(f"config {item['label']} failed: {err!r}", flush=True)
    else:
        print("DIAG3 skipped: seed=1000 reference unavailable", flush=True)

    print_diagnostic_summary(multi_rows, beta_rows, ablation_rows)


def compare_exact_adjoint(blk, idx, y, zstar, xin0, op, cfg):
    """Compare the exact implicit-adjoint gradient with BPTT and EP gradients on one batch.

    EP is run through ``run_ep_step_flat`` so it uses exactly the same
    ``ep_step`` argument list (including ``resreg=0.0``) as the diagnostic suite.
    """
    print("=== exact-adjoint gradient comparison ===", flush=True)
    ell, ce_loss = ce_state_grad(blk, zstar, y)
    print(f"CE(z*)={ce_loss:.6f} ||ell||={norm(ell).item():.6e}", flush=True)
    lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg)
    print(
        f"adjoint solve summary: residual={gmres_rel:.3e} iters={gmres_iters} info={gmres_info} "
        f"tikhonov_mu={adj_mu:.3e}",
        flush=True,
    )
    params = block_param_list(blk)
    gt = flat_grad_by_param_id(exact_transpose_grad(blk, idx, zstar, xin0, lam, params), params)
    print("gradient comparison scope: blk.block parameters; readout Wh is excluded", flush=True)
    gB = bptt_step(blk, idx, y, cfg.T1, cfg.eps, 0.0)
    gEPv, ep_res = run_ep_step_flat(blk, idx, y, cfg, params)
    gBv = flat_grad_by_param_id(gB, params)
    print(f"EP estimator free-phase residual from ep_step={ep_res:.6e}", flush=True)
    print(
        f"||g_transpose||={norm(gt).item():.6e} ||g_BPTT||={norm(gBv).item():.6e} ||g_EP||={norm(gEPv).item():.6e}",
        flush=True,
    )
    c_t_b = cos(gt, gBv)
    d_t_b = rel_diff(gt, gBv)
    c_ep_t = cos(gEPv, gt)
    d_ep_t = rel_diff(gEPv, gt)
    c_ep_b = cos(gEPv, gBv)
    d_ep_b = rel_diff(gEPv, gBv)
    print(f"cos(g_transpose, g_BPTT)={c_t_b:+.6f} ||g_transpose-g_BPTT||/||g_BPTT||={d_t_b:.6e}", flush=True)
    print(f"cos(g_EP, g_transpose)={c_ep_t:+.6f} ||g_EP-g_transpose||/||g_transpose||={d_ep_t:.6e}", flush=True)
    print(f"cos(g_EP, g_BPTT)={c_ep_b:+.6f} ||g_EP-g_BPTT||/||g_BPTT||={d_ep_b:.6e}", flush=True)
    print("interpretation:", flush=True)
    print(
        " cos(g_transpose,g_BPTT)~1 AND cos(g_EP,g_transpose)~1 -> our EP IS the exact adjoint; failure is convergence/contraction",
        flush=True,
    )
    print(
        " cos(g_transpose,g_BPTT)~1 AND cos(g_EP,g_transpose)<1 -> exact adjoint works, our EP falls short -> implement exact/dyadic",
        flush=True,
    )
    print(
        " cos(g_transpose,g_BPTT)<1 -> even exact adjoint != BPTT -> finite-time/convergence, not the adjoint",
        flush=True,
    )


def main():
    """Entry point: build the block, relax one batch to z*, then run the requested probes."""
    cfg = parse_args()
    cfg.ckpt = resolve_ckpt_path(cfg.ckpt)
    require_cuda_if_requested(cfg.device)
    dev = torch.device("cuda:0" if cfg.device == "cuda" else "cpu")
    torch.backends.cuda.matmul.allow_tf32 = bool(cfg.tf32)
    torch.backends.cudnn.allow_tf32 = bool(cfg.tf32)
    print(f"# asym_probe device={dev} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r}", flush=True)
    print(
        f"# ckpt={cfg.ckpt} B={cfg.B} T={cfg.T} C={cfg.C} H={cfg.H} Mm={cfg.Mm} "
        f"attn_mode=thick qknorm=True gelu={cfg.gelu}",
        flush=True,
    )
    blk, ck = build_block(cfg, dev)
    if cfg.diag:
        run_diagnostics(blk, cfg, ck)
        return

    idx, y = L.get_batch("train", cfg.B, cfg.T)
    xin0 = blk.embed(idx).detach()
    zstar, steps, step_res, force_res = relax_to_fixed_point(blk, xin0, cfg)
    print(f"# ckpt step {ck.get('step')} best {ck.get('best')}", flush=True)
    print(f"z* residual: step_res={step_res:.6e} force_res={force_res:.6e} relax_steps={steps}", flush=True)
    if step_res > cfg.res_est:
        print(f"WARNING: fixed-point target not reached: step_res={step_res:.3e} > {cfg.res_est:.3e}", flush=True)
    if step_res > 1e-3 or force_res > 1e-3:
        print(
            "WARNING: relaxed z* residual exceeds 1e-3; do not trust exact-adjoint solves until convergence improves",
            flush=True,
        )

    # Freeze parameters for state Jacobian products. tforce is out-of-place; each
    # VJP re-leafs z* to avoid stale graphs, and xin0 is held detached/fixed.
    set_param_requires_grad(blk, False)
    print(
        "autograd note: using blk.tforce directly; no in-place tforce ops patched; z* is re-leafed per VJP/JVP",
        flush=True,
    )
    op0 = Operators(blk, zstar, xin0, cfg, mu=0.0)
    if cfg.skiprho:
        compare_exact_adjoint(blk, idx, y, zstar, xin0, op0, cfg)
        return

    tr_mean, tr_std = estimate_trace_s(op0, cfg.trace_probes)
    if cfg.mu >= 0:
        mu = float(cfg.mu)
    else:
        mu = cfg.mu_scale * max(abs(tr_mean), 1e-12)
    print(f"trace(S)/n estimate={tr_mean:+.6e} std={tr_std:.3e}", flush=True)
    print(f"mu used={mu:.6e} (mu_scale={cfg.mu_scale:g}, solve operator S+muI)", flush=True)

    op = Operators(blk, zstar, xin0, cfg, mu=mu)
    sensitivity_probe(op, mu)
    t0 = time.time()
    rho_power = power_rho(op, cfg)
    rho_arnoldi = arnoldi_rho(op, cfg.arnoldi_k)
    sigma = power_sigma(op, cfg)
    elapsed = time.time() - t0
    rho = max(rho_power, rho_arnoldi if rho_arnoldi is not None else 0.0)
    arnoldi_shown = rho_arnoldi if rho_arnoldi is not None else float("nan")
    print(op.solve_log.summary(), flush=True)
    print(
        "non-normal note: power iteration reports dominant growth; Rayleigh trend may be small/oscillatory for skew modes",
        flush=True,
    )
    print(f"rho(S^-1 A)={rho:.6e} power={rho_power:.6e} arnoldi={arnoldi_shown:.6e}", flush=True)
    print(f"||S^-1 A||_2={sigma:.6e}", flush=True)
    print(f"elapsed_operator_seconds={elapsed:.1f}", flush=True)
    verdict = "higher-order AEP viable" if rho < 1.0 else "higher-order AEP not viable"
    print(f"VERDICT: rho {'<' if rho < 1.0 else '>='} 1 => {verdict}", flush=True)
    compare_exact_adjoint(blk, idx, y, zstar, xin0, op, cfg)


if __name__ == "__main__":
    main()