diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/asym_probe.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/asym_probe.py')
| -rw-r--r-- | ep_run/asym_probe.py | 922 |
1 files changed, 922 insertions, 0 deletions
diff --git a/ep_run/asym_probe.py b/ep_run/asym_probe.py new file mode 100644 index 0000000..1b61354 --- /dev/null +++ b/ep_run/asym_probe.py @@ -0,0 +1,922 @@ +"""Matrix-free asymmetry probe for the equilibrium-transformer block Jacobian. + +The state Jacobian J = dF/dz is never materialized. We estimate the growth of +T = (S + mu I)^-1 A, where S=(J+J^T)/2 and A=(J-J^T)/2, using autograd JVP/VJP +products at the relaxed fixed point. +""" +import argparse +import glob +import math +import os +import pickle +import time +import warnings + +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", message=".*cuBLAS.*", category=UserWarning) +warnings.filterwarnings("ignore", message=".*CUBLAS.*", category=UserWarning) + +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.sparse.linalg import LinearOperator, gmres, minres + +import lt_ep_train as L +from lt_ep_train import EQBlock, bptt_step, ce, ep_step, relax + + +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("--ckpt", default="runs/ep_clean.pt") + ap.add_argument("--data", default="data/tinystories_bpe") + ap.add_argument("--gelu", default="erf") + ap.add_argument("--C", type=int, default=512) + ap.add_argument("--H", type=int, default=16) + ap.add_argument("--Mm", type=int, default=256) + ap.add_argument("--T", type=int, default=256) + ap.add_argument("--B", type=int, default=8) + ap.add_argument("--T1", type=int, default=150) + ap.add_argument("--T2", type=int, default=20) + ap.add_argument("--eps", type=float, default=0.1) + ap.add_argument("--beta", type=float, default=0.02) + ap.add_argument("--t1max", type=int, default=2000) + ap.add_argument("--relax-chunk", type=int, default=50) + ap.add_argument("--res-est", type=float, default=1e-4) + ap.add_argument("--t2sel", type=int, default=40) + ap.add_argument("--holo", type=int, default=2) + ap.add_argument("--hr", type=float, default=0.02) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--trace-probes", type=int, default=4) + ap.add_argument("--mu-scale", type=float, default=1e-3) + ap.add_argument("--mu", type=float, default=-1.0, help="override mu; negative means estimate from trace") + ap.add_argument("--solve-iters", type=int, default=80) + ap.add_argument("--solve-tol", type=float, default=1e-5) + ap.add_argument("--adjoint-iters", type=int, default=200) + ap.add_argument("--adjoint-tol", type=float, default=1e-5) + ap.add_argument("--adjoint-mu", type=float, default=1e-4, help="fallback Tikhonov mu for J^T+muI if GMRES stalls") + ap.add_argument("--rho-iters", type=int, default=20) + ap.add_argument("--rho-restarts", type=int, default=3) + ap.add_argument("--sigma-iters", type=int, default=8) + ap.add_argument("--sigma-restarts", type=int, default=2) + ap.add_argument("--arnoldi-k", type=int, default=12) + ap.add_argument("--skiprho", action=argparse.BooleanOptionalAction, default=True, + help="skip rho/sigma spectral probes and run only exact-adjoint gradient comparison") + ap.add_argument("--diag", action="store_true", help="run EP/exact-adjoint diagnostic suite and exit") + ap.add_argument("--noplot", action="store_true", help=argparse.SUPPRESS) + ap.add_argument("--lr", type=float, default=None, help=argparse.SUPPRESS) + ap.add_argument("--device", default="cuda", choices=["cuda", "cpu"]) + ap.add_argument("--tf32", action="store_true") + return ap.parse_args() + + +def resolve_ckpt_path(path): + p = Path(path) + if p.is_absolute(): + return str(p) + cwd_path = Path.cwd() / p + if cwd_path.exists(): + return str(cwd_path) + return str(Path(__file__).resolve().parent / p) + + +def require_cuda_if_requested(device): + if device != "cuda": + return + visible = os.environ.get("CUDA_VISIBLE_DEVICES") + ok = torch.cuda.is_available() and torch.cuda.device_count() > 0 + if ok: + torch.cuda.set_device(0) + return + print("ERROR: CUDA unavailable; requested GPU0 run cannot start.", flush=True) + print(f"CUDA_VISIBLE_DEVICES={visible!r}", flush=True) + print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda}", flush=True) + print(f"torch.cuda.is_available()={torch.cuda.is_available()} device_count={torch.cuda.device_count()}", flush=True) + nodes = glob.glob("/dev/nvidia*") + required = ["/dev/nvidiactl", "/dev/nvidia-uvm", "/dev/nvidia0"] + missing = [p for p in required if not os.path.exists(p)] + print(f"/dev/nvidia*={' '.join(nodes) if nodes else 'MISSING'}", flush=True) + print(f"missing CUDA device nodes={' '.join(missing) if missing else 'none'}", flush=True) + raise SystemExit(2) + + +def build_block(cfg, dev): + # Same construction and checkpoint-copy path as resreg_probe.py. + L.dev = dev + L.DD = Path(cfg.data) + L.vocab = pickle.load(open(L.DD / "meta.pkl", "rb"))["vocab_size"] + torch.manual_seed(cfg.seed) + blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=1.0, attn_mode="thick") + blk.qknorm = True + blk.fnoise = 0.0 + blk._cstep = None + blk.navg = 1 + blk.li_avg = 0 + blk.track = True + blk.nbrake = 0.0 + blk.gelu = cfg.gelu + ck = torch.load(cfg.ckpt, map_location=dev) + with torch.no_grad(): + for p, w in zip(blk.allp, ck["allp"]): + p.copy_(w.to(dev)) + return blk, ck + + +@torch.no_grad() +def residuals(blk, z, xin, eps): + z1 = relax(blk, z, xin, 1, eps) + zn = z.norm().item() + 1e-12 + step_rel = (z1 - z).norm().item() / zn + force_rel = blk.tforce(z, xin).norm().item() / zn + return step_rel, force_rel + + +def relax_to_fixed_point(blk, xin, cfg): + z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps) + step_rel, force_rel = residuals(blk, z, xin, cfg.eps) + steps = cfg.T1 + while steps < cfg.t1max and step_rel > cfg.res_est: + chunk = min(cfg.relax_chunk, cfg.t1max - steps) + z = relax(blk, z, xin, chunk, cfg.eps) + steps += chunk + step_rel, force_rel = residuals(blk, z, xin, cfg.eps) + print(f"relax steps={steps:4d} step_res={step_rel:.3e} force_res={force_rel:.3e}", flush=True) + return z.detach(), steps, step_rel, force_rel + + +def dot(a, b): + return torch.dot(a.reshape(-1), b.reshape(-1)) + + +def norm(a): + return torch.linalg.vector_norm(a.reshape(-1)) + + +def block_param_list(blk): + if hasattr(blk.block, "parameters"): + return list(blk.block.parameters()) + return list(blk.block) + + +def flat_grad_by_param_id(grads, params): + flat = [] + for p in params: + g = grads.get(id(p)) if grads is not None else None + if g is None: + g = torch.zeros_like(p, device="cpu", dtype=torch.float64) + else: + g = g.detach().to(device="cpu", dtype=torch.float64) + flat.append(g.reshape(-1)) + return torch.cat(flat) + + +def set_param_requires_grad(blk, value): + for p in blk.allp: + p.requires_grad_(value) + + +def cos(a, b): + return (torch.dot(a, b) / (norm(a) * norm(b) + 1e-20)).item() + + +def rel_diff(a, b): + return (norm(a - b) / (norm(b) + 1e-20)).item() + + +def unit_rand(shape, dev, dtype): + v = torch.randn(shape, device=dev, dtype=dtype) + return v / (norm(v) + 1e-30) + + +@dataclass +class SolveLog: + residuals: list = field(default_factory=list) + infos: list = field(default_factory=list) + iters: list = field(default_factory=list) + + def add(self, rel_res, info, nit): + self.residuals.append(float(rel_res)) + self.infos.append(int(info)) + self.iters.append(int(nit)) + + def summary(self): + if not self.residuals: + return "solve residuals: none" + r = np.asarray(self.residuals, dtype=np.float64) + it = np.asarray(self.iters, dtype=np.int64) + bad = sum(1 for x in self.infos if x != 0) + return ( + f"solve residuals: count={len(r)} min={r.min():.3e} " + f"median={np.median(r):.3e} max={r.max():.3e} " + f"iters median={np.median(it):.0f} max={it.max()} nonzero_info={bad}" + ) + + +class Operators: + def __init__(self, blk, zstar, xin, cfg, mu): + self.blk = blk + self.zstar = zstar.detach() + self.xin = xin.detach() + self.shape = tuple(zstar.shape) + self.n = zstar.numel() + self.dev = zstar.device + self.dtype = zstar.dtype + self.cfg = cfg + self.mu = float(mu) + self.solve_log = SolveLog() + + def f(self, z): + return self.blk.tforce(z, self.xin) + + def jv(self, v): + with torch.enable_grad(): + _, out = torch.autograd.functional.jvp( + self.f, self.zstar, v.contiguous(), create_graph=False, strict=False + ) + return out.detach() + + def jtv(self, v): + with torch.enable_grad(): + z = self.zstar.detach().requires_grad_(True) + fz = self.f(z) + (g,) = torch.autograd.grad(fz, z, grad_outputs=v.contiguous(), create_graph=False, retain_graph=False) + return g.detach() + + def sv(self, v): + jv = self.jv(v) + jtv = self.jtv(v) + return 0.5 * (jv + jtv) + + def av(self, v): + jv = self.jv(v) + jtv = self.jtv(v) + return 0.5 * (jv - jtv) + + def smu(self, v, mu=None): + m = self.mu if mu is None else float(mu) + return self.sv(v) + m * v + + def _from_numpy(self, x): + x = np.asarray(x, dtype=np.float32) + return torch.from_numpy(x).to(device=self.dev, dtype=self.dtype).view(self.shape) + + def _to_numpy(self, x): + return x.detach().reshape(-1).float().cpu().numpy() + + def solve_s(self, rhs, mu=None, tag=""): + m = self.mu if mu is None else float(mu) + b = self._to_numpy(rhs) + counter = {"n": 0} + + def matvec(x_np): + x = self._from_numpy(x_np) + y = self.smu(x, m) + return self._to_numpy(y) + + def cb(_x): + counter["n"] += 1 + + Aop = LinearOperator((self.n, self.n), matvec=matvec, dtype=np.dtype("float32")) + x_np, info = minres(Aop, b, rtol=self.cfg.solve_tol, maxiter=self.cfg.solve_iters, callback=cb, check=False) + x = self._from_numpy(x_np).detach() + rel = (norm(self.smu(x, m) - rhs) / (norm(rhs) + 1e-30)).item() + self.solve_log.add(rel, info, counter["n"]) + if tag: + print(f"solve {tag}: mu={m:.3e} rel_res={rel:.3e} iters={counter['n']} info={info}", flush=True) + return x, rel, info, counter["n"] + + def solve_jt_gmres(self, rhs, tol, maxiter, mu=0.0, tag="adjoint"): + m = float(mu) + b = self._to_numpy(rhs) + counter = {"n": 0} + restart = max(1, min(50, int(maxiter))) + + def matvec(x_np): + x = self._from_numpy(x_np) + y = self.jtv(x) + if m != 0.0: + y = y + m * x + return self._to_numpy(y) + + def cb(_arg): + counter["n"] += 1 + + Aop = LinearOperator((self.n, self.n), matvec=matvec, dtype=np.dtype("float32")) + try: + x_np, info = gmres( + Aop, + b, + rtol=tol, + atol=0.0, + restart=restart, + maxiter=int(maxiter), + callback=cb, + callback_type="legacy", + ) + except TypeError: + x_np, info = gmres(Aop, b, tol=tol, restart=restart, maxiter=int(maxiter), callback=cb) + x = self._from_numpy(x_np).detach() + rel = (norm(self.jtv(x) + m * x - rhs) / (norm(rhs) + 1e-30)).item() + print(f"GMRES {tag}: mu={m:.3e} rel_res={rel:.3e} iters={counter['n']} info={info}", flush=True) + return x, rel, info, counter["n"] + + def t(self, v): + rhs = self.av(v) + x, _, _, _ = self.solve_s(rhs) + return x + + def tt(self, u): + y, _, _, _ = self.solve_s(u) + return -self.av(y) + + +def estimate_trace_s(op, probes): + vals = [] + for i in range(probes): + r = torch.randint(0, 2, op.shape, device=op.dev, dtype=torch.int8).to(op.dtype) + r = r.mul_(2).sub_(1) + sr = op.sv(r) + vals.append((dot(r, sr) / op.n).item()) + print(f"trace probe {i}: tr(S)/n={vals[-1]:+.6e}", flush=True) + return float(np.mean(vals)), float(np.std(vals) if len(vals) > 1 else 0.0) + + +def sensitivity_probe(op, mu): + v = unit_rand(op.shape, op.dev, op.dtype) + rhs = op.av(v) + xb, rb, _, _ = op.solve_s(rhs, mu=mu, tag="sensitivity/base") + rows = [] + for scale in (0.1, 10.0): + ms = max(mu * scale, 0.0) + xa, ra, _, _ = op.solve_s(rhs, mu=ms, tag=f"sensitivity/mu_x{scale:g}") + rel_dx = (norm(xa - xb) / (norm(xb) + 1e-30)).item() + rows.append((scale, ms, rel_dx, ra)) + print( + "solve sensitivity: " + + " ".join(f"mu_x{scale:g}: rel_dx={dx:.3e} rel_res={rr:.3e}" for scale, _, dx, rr in rows), + flush=True, + ) + return rb, rows + + +def power_rho(op, cfg): + best = 0.0 + best_hist = None + for r in range(cfg.rho_restarts): + v = unit_rand(op.shape, op.dev, op.dtype) + hist = [] + for i in range(cfg.rho_iters): + w = op.t(v) + growth = norm(w).item() + rq = (dot(v, w) / (dot(v, v) + 1e-30)).item() + hist.append((growth, rq)) + if growth <= 1e-30 or not math.isfinite(growth): + break + v = (w / growth).detach() + print(f"rho restart={r} iter={i + 1:02d} growth={growth:.6e} rayleigh={rq:+.6e}", flush=True) + if hist and hist[-1][0] > best: + best = hist[-1][0] + best_hist = hist + if best_hist: + trend = " ".join(f"{g:.3g}" for g, _ in best_hist[-min(6, len(best_hist)):]) + rtrend = " ".join(f"{rq:+.3g}" for _, rq in best_hist[-min(6, len(best_hist)):]) + print(f"rho power trend last={trend}", flush=True) + print(f"rho Rayleigh trend last={rtrend}", flush=True) + return best + + +def arnoldi_rho(op, k): + if k <= 0: + return None + q = unit_rand(op.shape, op.dev, op.dtype) + Q = [q] + H = np.zeros((k + 1, k), dtype=np.float64) + m = 0 + for j in range(k): + w = op.t(Q[j]) + for i in range(j + 1): + hij = dot(Q[i], w).item() + H[i, j] = hij + w = w - hij * Q[i] + hnext = norm(w).item() + H[j + 1, j] = hnext + m = j + 1 + print(f"arnoldi iter={j + 1:02d} h_next={hnext:.6e}", flush=True) + if hnext < 1e-12: + break + if j + 1 < k: + Q.append((w / hnext).detach()) + eig = np.linalg.eigvals(H[:m, :m]) + rho = float(np.max(np.abs(eig))) if eig.size else float("nan") + print(f"rho Arnoldi(k={m})={rho:.6e}", flush=True) + return rho + + +def power_sigma(op, cfg): + best = 0.0 + for r in range(cfg.sigma_restarts): + v = unit_rand(op.shape, op.dev, op.dtype) + sigma = 0.0 + for i in range(cfg.sigma_iters): + u = op.t(v) + sigma = norm(u).item() + w = op.tt(u) + wn = norm(w).item() + if wn <= 1e-30 or not math.isfinite(wn): + break + v = (w / wn).detach() + print(f"sigma restart={r} iter={i + 1:02d} sigma={sigma:.6e}", flush=True) + best = max(best, sigma) + return best + + +def ce_state_grad(blk, zstar, y): + with torch.enable_grad(): + z = zstar.detach().requires_grad_(True) + loss = ce(blk, z, y) + (ell,) = torch.autograd.grad(loss, z) + return ell.detach(), float(loss.detach()) + + +def solve_exact_adjoint(op, ell, cfg): + rhs = -ell.detach() + lam, rel, info, nit = op.solve_jt_gmres(rhs, cfg.adjoint_tol, cfg.adjoint_iters, mu=0.0, tag="J^T lambda=-ell") + stalled = (info != 0) or (not math.isfinite(rel)) or (rel > max(10.0 * cfg.adjoint_tol, 1e-4)) + mu_used = 0.0 + if stalled: + mu_used = max(float(cfg.adjoint_mu), 1e-8) + print(f"GMRES stalled; retrying exact-adjoint solve with Tikhonov J^T+muI, mu={mu_used:.3e}", flush=True) + lam, rel, info, nit = op.solve_jt_gmres( + rhs, cfg.adjoint_tol, cfg.adjoint_iters, mu=mu_used, tag="(J^T+muI) lambda=-ell" + ) + return lam.detach(), rel, info, nit, mu_used + + +def exact_transpose_grad(blk, idx, zstar, xin0, lam, params): + for p in blk.allp: + p.requires_grad_(True) + with torch.enable_grad(): + # Value stays at the relaxed clamp xin0, while tok/pos receive the same clamp-gradient path as the trainer. + xin = xin0 + (blk.embed(idx) - blk.embed(idx).detach()) + force = blk.tforce(zstar.detach(), xin) + grads = torch.autograd.grad((force * lam.detach()).sum(), params, allow_unused=True) + return {id(p): g for p, g in zip(params, grads)} + + +def run_ep_step_flat(blk, idx, y, cfg, params, *, beta=None, holo=None, hr=None, t2sel=None, track=None, T2=None): + saved_track = getattr(blk, "track", None) + if track is not None: + blk.track = bool(track) + try: + set_param_requires_grad(blk, True) + # Mirrors lt_ep_train.ep_step: + # (blk, idx, y, T1, T2, eps, beta, jacreg, holo, hr, t1max, res_est, t2sel, corr_every, res_gate, resreg). + grads, ep_res = ep_step( + blk, + idx, + y, + cfg.T1, + cfg.T2 if T2 is None else int(T2), + cfg.eps, + cfg.beta if beta is None else float(beta), + 0.0, + cfg.holo if holo is None else int(holo), + cfg.hr if hr is None else float(hr), + cfg.t1max, + cfg.res_est, + cfg.t2sel if t2sel is None else int(t2sel), + 1, + 0.0, + 0.0, + ) + return flat_grad_by_param_id(grads, params), float(ep_res) + finally: + if track is not None and saved_track is not None: + blk.track = saved_track + + +@torch.no_grad() +def fixed_point_step_abs(blk, zstar, xin, eps): + return (relax(blk, zstar, xin, 1, eps) - zstar).norm().item() + + +def exact_reference_for_batch(blk, idx, y, cfg, label, compute_bptt=True): + print(f"--- exact reference: {label} ---", flush=True) + xin0 = blk.embed(idx).detach() + zstar, steps, step_res, force_res = relax_to_fixed_point(blk, xin0, cfg) + step_abs = fixed_point_step_abs(blk, zstar, xin0, cfg.eps) + print( + f"{label}: z* residual step_abs={step_abs:.6e} step_rel={step_res:.6e} " + f"force_rel={force_res:.6e} relax_steps={steps}", + flush=True, + ) + if step_res > cfg.res_est: + print(f"{label}: WARNING step_res={step_res:.3e} > res_est={cfg.res_est:.3e}", flush=True) + + set_param_requires_grad(blk, False) + op = Operators(blk, zstar, xin0, cfg, mu=0.0) + ell, ce_loss = ce_state_grad(blk, zstar, y) + print(f"{label}: CE(z*)={ce_loss:.6f} ||ell||={norm(ell).item():.6e}", flush=True) + lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg) + print( + f"{label}: adjoint residual={gmres_rel:.3e} iters={gmres_iters} info={gmres_info} " + f"tikhonov_mu={adj_mu:.3e}", + flush=True, + ) + + params = block_param_list(blk) + gt = flat_grad_by_param_id(exact_transpose_grad(blk, idx, zstar, xin0, lam, params), params) + out = { + "idx": idx, + "y": y, + "params": params, + "gt": gt, + "z_step_abs": step_abs, + "z_step_rel": step_res, + "z_force_rel": force_res, + "relax_steps": steps, + "gmres_rel": gmres_rel, + "gmres_info": gmres_info, + "gmres_iters": gmres_iters, + "adj_mu": adj_mu, + "ce_loss": ce_loss, + } + if compute_bptt: + set_param_requires_grad(blk, True) + gB = bptt_step(blk, idx, y, cfg.T1, cfg.eps, 0.0) + out["gBv"] = flat_grad_by_param_id(gB, params) + set_param_requires_grad(blk, True) + return out + + +def draw_seeded_train_batch(cfg, seed): + torch.manual_seed(int(seed)) + return L.get_batch("train", cfg.B, cfg.T) + + +def finite_range(vals): + ok = [float(v) for v in vals if v is not None and math.isfinite(float(v))] + if not ok: + return None + return float(np.mean(ok)), float(np.min(ok)), float(np.max(ok)) + + +def read_multi_batch(rows): + vals = [r.get("cos_ep_t") for r in rows if r.get("ok")] + stats = finite_range(vals) + if stats is None: + return "no successful batches" + mean, mn, mx = stats + spread = mx - mn + if mn > 0.95: + return "consistently aligned across batches" + if mx < 0.80: + return "systematically low across batches" + if spread > 0.20: + return "batch-variance/outlier behavior is material" + return "mostly systematic with moderate batch variance" + + +def read_beta_sweep(rows): + ok = [(r["beta"], r["cos"]) for r in rows if r.get("ok")] + if not ok: + return "no successful beta points" + first_beta, first_cos = ok[0] + last_beta, last_cos = ok[-1] + best_cos = max(c for _, c in ok) + if last_cos > 0.95 and last_cos - first_cos > 0.10: + return f"finite-beta bias likely: cos improves from beta={first_beta:g} to beta={last_beta:g}" + if best_cos < 0.80: + return "cos stays low as beta shrinks: structural/bug more likely than finite-beta bias" + if last_cos > first_cos + 0.05: + return "some finite-beta sensitivity, but not a clean convergence-to-1 result" + return "no strong beta-to-zero improvement" + + +def read_ablation(rows): + ok = [r for r in rows if r.get("ok")] + if not ok: + return "no successful ablations" + full = next((r for r in ok if r["key"] == "full"), None) + best = max(ok, key=lambda r: r["cos"]) + if full is None: + return f"best successful config is {best['label']}" + delta = best["cos"] - full["cos"] + if delta <= 0.05: + return "no ablation materially improves over FULL" + if best["key"] == "track_off": + return "tracking path is suspect: disabling blk.track improved cos" + if best["key"] == "plain": + return "holomorphic/adaptive path is suspect: plain real EP improved cos" + if best["key"] == "fixed_t2": + return "adaptive-T2 selection/tracking is suspect: fixed T2 improved cos" + return f"{best['label']} is the strongest improvement over FULL" + + +def print_diagnostic_summary(multi_rows, beta_rows, ablation_rows): + print("", flush=True) + print("================ DIAGNOSTIC SUMMARY ================", flush=True) + + multi_stats_t = finite_range([r.get("cos_ep_t") for r in multi_rows if r.get("ok")]) + multi_stats_b = finite_range([r.get("cos_ep_b") for r in multi_rows if r.get("ok")]) + if multi_stats_t is None: + print("Multi-batch: no successful batches", flush=True) + else: + mean, mn, mx = multi_stats_t + print(f"Multi-batch: mean cos(g_EP,g_transpose)={mean:+.6f} range=[{mn:+.6f}, {mx:+.6f}]", flush=True) + if multi_stats_b is not None: + mean, mn, mx = multi_stats_b + print(f"Multi-batch: mean cos(g_EP,g_BPTT)={mean:+.6f} range=[{mn:+.6f}, {mx:+.6f}]", flush=True) + print(f"Multi-batch read: {read_multi_batch(multi_rows)}", flush=True) + + print("Beta sweep (beta | cos(g_EP,g_transpose)):", flush=True) + if beta_rows: + for row in beta_rows: + if row.get("ok"): + print(f" {row['beta']:<8g} | {row['cos']:+.6f}", flush=True) + else: + print(f" {row.get('beta', 'n/a')!s:<8} | failed: {row.get('error')}", flush=True) + else: + print(" none", flush=True) + print(f"Beta sweep read: {read_beta_sweep(beta_rows)}", flush=True) + + print("Ablation (config | cos(g_EP,g_transpose)):", flush=True) + if ablation_rows: + for row in ablation_rows: + if row.get("ok"): + print(f" {row['label']} | {row['cos']:+.6f}", flush=True) + else: + print(f" {row.get('label', 'unknown')} | failed: {row.get('error')}", flush=True) + else: + print(" none", flush=True) + print(f"Ablation read: {read_ablation(ablation_rows)}", flush=True) + print("============== END DIAGNOSTIC SUMMARY ==============", flush=True) + + +def run_diagnostics(blk, cfg, ck): + print("=== DIAGNOSTIC MODE ===", flush=True) + print(f"# ckpt step {ck.get('step')} best {ck.get('best')}", flush=True) + print( + "ep_step paths: holo=2,t2sel>0,track=True -> holo_a_track; " + "holo=2,t2sel>0,track=False -> holo_a_select2; holo>0,t2sel=0 -> holo_a; holo=0 -> plain EP", + flush=True, + ) + print("gradient comparison scope: blk.block parameters; readout Wh is excluded", flush=True) + + multi_rows = [] + beta_rows = [] + ablation_rows = [] + seed1000_ref = None + + print("=== DIAGNOSTIC 1: MULTI-BATCH ===", flush=True) + for i in range(6): + seed = 1000 + i + label = f"diag1 batch={i} seed={seed}" + try: + idx, y = draw_seeded_train_batch(cfg, seed) + ref = exact_reference_for_batch(blk, idx, y, cfg, label, compute_bptt=True) + torch.manual_seed(seed) + gEPv, ep_res = run_ep_step_flat(blk, idx, y, cfg, ref["params"]) + row = { + "ok": True, + "batch": i, + "seed": seed, + "cos_ep_t": cos(gEPv, ref["gt"]), + "cos_ep_b": cos(gEPv, ref["gBv"]), + "cos_t_b": cos(ref["gt"], ref["gBv"]), + "z_step_abs": ref["z_step_abs"], + "z_step_rel": ref["z_step_rel"], + "z_force_rel": ref["z_force_rel"], + "ep_res": ep_res, + } + multi_rows.append(row) + print( + f"{label}: cos(g_EP,g_transpose)={row['cos_ep_t']:+.6f} " + f"cos(g_EP,g_BPTT)={row['cos_ep_b']:+.6f} " + f"cos(g_transpose,g_BPTT)={row['cos_t_b']:+.6f} " + f"z_res_abs={row['z_step_abs']:.6e} z_res_rel={row['z_step_rel']:.6e} ep_res={ep_res:.6e}", + flush=True, + ) + if seed == 1000: + seed1000_ref = ref + except Exception as err: + row = {"ok": False, "batch": i, "seed": seed, "error": repr(err)} + multi_rows.append(row) + print(f"{label} failed: {err!r}", flush=True) + + multi_stats_t = finite_range([r.get("cos_ep_t") for r in multi_rows if r.get("ok")]) + multi_stats_b = finite_range([r.get("cos_ep_b") for r in multi_rows if r.get("ok")]) + if multi_stats_t is not None and multi_stats_b is not None: + mt, mint, maxt = multi_stats_t + mb, minb, maxb = multi_stats_b + print( + f"DIAG1 aggregate: cos(g_EP,g_transpose) mean={mt:+.6f} min={mint:+.6f} max={maxt:+.6f}; " + f"cos(g_EP,g_BPTT) mean={mb:+.6f} min={minb:+.6f} max={maxb:+.6f}", + flush=True, + ) + + if seed1000_ref is None: + try: + idx, y = draw_seeded_train_batch(cfg, 1000) + seed1000_ref = exact_reference_for_batch(blk, idx, y, cfg, "diag seed=1000 fallback", compute_bptt=True) + except Exception as err: + print(f"seed=1000 reference failed; beta sweep and ablation cannot run: {err!r}", flush=True) + + print("=== DIAGNOSTIC 2: BETA SWEEP ===", flush=True) + if seed1000_ref is not None: + for beta in [0.04, 0.02, 0.01, 0.005, 0.002]: + try: + torch.manual_seed(1000) + gEPv, ep_res = run_ep_step_flat( + blk, + seed1000_ref["idx"], + seed1000_ref["y"], + cfg, + seed1000_ref["params"], + beta=beta, + hr=beta, + ) + row = {"ok": True, "beta": beta, "cos": cos(gEPv, seed1000_ref["gt"]), "ep_res": ep_res} + beta_rows.append(row) + print(f"beta={beta:g} hr={beta:g}: cos(g_EP,g_transpose)={row['cos']:+.6f} ep_res={ep_res:.6e}", flush=True) + except Exception as err: + beta_rows.append({"ok": False, "beta": beta, "error": repr(err)}) + print(f"beta={beta:g} failed: {err!r}", flush=True) + else: + print("DIAG2 skipped: seed=1000 reference unavailable", flush=True) + + print("=== DIAGNOSTIC 3: COMPONENT ABLATION ===", flush=True) + if seed1000_ref is not None: + ablations = [ + { + "key": "full", + "label": "FULL holo=2 track=True t2sel=40", + "kwargs": {"holo": 2, "track": True, "t2sel": 40, "hr": cfg.hr, "beta": cfg.beta}, + }, + { + "key": "track_off", + "label": "holo=2 track=False t2sel=40", + "kwargs": {"holo": 2, "track": False, "t2sel": 40, "hr": cfg.hr, "beta": cfg.beta}, + }, + { + "key": "plain", + "label": "plain EP holo=0 track ignored t2sel=0", + "kwargs": {"holo": 0, "track": getattr(blk, "track", False), "t2sel": 0, "hr": cfg.hr, "beta": cfg.beta}, + }, + { + "key": "fixed_t2", + "label": f"holo=2 track=True t2sel=0 fixed T2={cfg.T2}", + "kwargs": {"holo": 2, "track": True, "t2sel": 0, "hr": cfg.hr, "beta": cfg.beta, "T2": cfg.T2}, + }, + ] + for item in ablations: + try: + torch.manual_seed(1000) + gEPv, ep_res = run_ep_step_flat( + blk, + seed1000_ref["idx"], + seed1000_ref["y"], + cfg, + seed1000_ref["params"], + **item["kwargs"], + ) + row = { + "ok": True, + "key": item["key"], + "label": item["label"], + "cos": cos(gEPv, seed1000_ref["gt"]), + "ep_res": ep_res, + } + ablation_rows.append(row) + print(f"{item['label']}: cos(g_EP,g_transpose)={row['cos']:+.6f} ep_res={ep_res:.6e}", flush=True) + except Exception as err: + row = {"ok": False, "key": item["key"], "label": item["label"], "error": repr(err)} + ablation_rows.append(row) + print(f"config {item['label']} failed: {err!r}", flush=True) + else: + print("DIAG3 skipped: seed=1000 reference unavailable", flush=True) + + print_diagnostic_summary(multi_rows, beta_rows, ablation_rows) + + +def compare_exact_adjoint(blk, idx, y, zstar, xin0, op, cfg): + print("=== exact-adjoint gradient comparison ===", flush=True) + ell, ce_loss = ce_state_grad(blk, zstar, y) + print(f"CE(z*)={ce_loss:.6f} ||ell||={norm(ell).item():.6e}", flush=True) + lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg) + print( + f"adjoint solve summary: residual={gmres_rel:.3e} iters={gmres_iters} info={gmres_info} " + f"tikhonov_mu={adj_mu:.3e}", + flush=True, + ) + + params = block_param_list(blk) + gt = flat_grad_by_param_id(exact_transpose_grad(blk, idx, zstar, xin0, lam, params), params) + print("gradient comparison scope: blk.block parameters; readout Wh is excluded", flush=True) + + gB = bptt_step(blk, idx, y, cfg.T1, cfg.eps, 0.0) + gEP, ep_res = ep_step( + blk, + idx, + y, + cfg.T1, + cfg.T2, + cfg.eps, + cfg.beta, + 0.0, + cfg.holo, + cfg.hr, + cfg.t1max, + cfg.res_est, + cfg.t2sel, + 1, + 0.0, + ) + gBv = flat_grad_by_param_id(gB, params) + gEPv = flat_grad_by_param_id(gEP, params) + + print(f"EP estimator free-phase residual from ep_step={ep_res:.6e}", flush=True) + print(f"||g_transpose||={norm(gt).item():.6e} ||g_BPTT||={norm(gBv).item():.6e} ||g_EP||={norm(gEPv).item():.6e}", flush=True) + c_t_b = cos(gt, gBv) + d_t_b = rel_diff(gt, gBv) + c_ep_t = cos(gEPv, gt) + d_ep_t = rel_diff(gEPv, gt) + c_ep_b = cos(gEPv, gBv) + d_ep_b = rel_diff(gEPv, gBv) + print(f"cos(g_transpose, g_BPTT)={c_t_b:+.6f} ||g_transpose-g_BPTT||/||g_BPTT||={d_t_b:.6e}", flush=True) + print(f"cos(g_EP, g_transpose)={c_ep_t:+.6f} ||g_EP-g_transpose||/||g_transpose||={d_ep_t:.6e}", flush=True) + print(f"cos(g_EP, g_BPTT)={c_ep_b:+.6f} ||g_EP-g_BPTT||/||g_BPTT||={d_ep_b:.6e}", flush=True) + print("interpretation:", flush=True) + print(" cos(g_transpose,g_BPTT)~1 AND cos(g_EP,g_transpose)~1 -> our EP IS the exact adjoint; failure is convergence/contraction", flush=True) + print(" cos(g_transpose,g_BPTT)~1 AND cos(g_EP,g_transpose)<1 -> exact adjoint works, our EP falls short -> implement exact/dyadic", flush=True) + print(" cos(g_transpose,g_BPTT)<1 -> even exact adjoint != BPTT -> finite-time/convergence, not the adjoint", flush=True) + + +def main(): + cfg = parse_args() + cfg.ckpt = resolve_ckpt_path(cfg.ckpt) + require_cuda_if_requested(cfg.device) + dev = torch.device("cuda:0" if cfg.device == "cuda" else "cpu") + torch.backends.cuda.matmul.allow_tf32 = bool(cfg.tf32) + torch.backends.cudnn.allow_tf32 = bool(cfg.tf32) + print(f"# asym_probe device={dev} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r}", flush=True) + print( + f"# ckpt={cfg.ckpt} B={cfg.B} T={cfg.T} C={cfg.C} H={cfg.H} Mm={cfg.Mm} " + f"attn_mode=thick qknorm=True gelu={cfg.gelu}", + flush=True, + ) + blk, ck = build_block(cfg, dev) + if cfg.diag: + run_diagnostics(blk, cfg, ck) + return + + idx, y = L.get_batch("train", cfg.B, cfg.T) + xin0 = blk.embed(idx).detach() + zstar, steps, step_res, force_res = relax_to_fixed_point(blk, xin0, cfg) + print(f"# ckpt step {ck.get('step')} best {ck.get('best')}", flush=True) + print(f"z* residual: step_res={step_res:.6e} force_res={force_res:.6e} relax_steps={steps}", flush=True) + if step_res > cfg.res_est: + print(f"WARNING: fixed-point target not reached: step_res={step_res:.3e} > {cfg.res_est:.3e}", flush=True) + if step_res > 1e-3 or force_res > 1e-3: + print("WARNING: relaxed z* residual exceeds 1e-3; do not trust exact-adjoint solves until convergence improves", flush=True) + + # Freeze parameters for state Jacobian products. tforce is out-of-place; each + # VJP re-leafs z* to avoid stale graphs, and xin0 is held detached/fixed. + set_param_requires_grad(blk, False) + print("autograd note: using blk.tforce directly; no in-place tforce ops patched; z* is re-leafed per VJP/JVP", flush=True) + + op0 = Operators(blk, zstar, xin0, cfg, mu=0.0) + if cfg.skiprho: + compare_exact_adjoint(blk, idx, y, zstar, xin0, op0, cfg) + return + + tr_mean, tr_std = estimate_trace_s(op0, cfg.trace_probes) + if cfg.mu >= 0: + mu = float(cfg.mu) + else: + mu = cfg.mu_scale * max(abs(tr_mean), 1e-12) + print(f"trace(S)/n estimate={tr_mean:+.6e} std={tr_std:.3e}", flush=True) + print(f"mu used={mu:.6e} (mu_scale={cfg.mu_scale:g}, solve operator S+muI)", flush=True) + + op = Operators(blk, zstar, xin0, cfg, mu=mu) + sensitivity_probe(op, mu) + t0 = time.time() + rho_power = power_rho(op, cfg) + rho_arnoldi = arnoldi_rho(op, cfg.arnoldi_k) + sigma = power_sigma(op, cfg) + elapsed = time.time() - t0 + rho = max(rho_power, rho_arnoldi if rho_arnoldi is not None else 0.0) + print(op.solve_log.summary(), flush=True) + print("non-normal note: power iteration reports dominant growth; Rayleigh trend may be small/oscillatory for skew modes", flush=True) + print(f"rho(S^-1 A)={rho:.6e} power={rho_power:.6e} arnoldi={rho_arnoldi if rho_arnoldi is not None else float('nan'):.6e}", flush=True) + print(f"||S^-1 A||_2={sigma:.6e}", flush=True) + print(f"elapsed_operator_seconds={elapsed:.1f}", flush=True) + verdict = "higher-order AEP viable" if rho < 1.0 else "higher-order AEP not viable" + print(f"VERDICT: rho {'<' if rho < 1.0 else '>='} 1 => {verdict}", flush=True) + compare_exact_adjoint(blk, idx, y, zstar, xin0, op, cfg) + + +if __name__ == "__main__": + main() |
