summaryrefslogtreecommitdiff
path: root/diag/lyap.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/lyap.py')
-rw-r--r--diag/lyap.py83
1 files changed, 83 insertions, 0 deletions
diff --git a/diag/lyap.py b/diag/lyap.py
new file mode 100644
index 0000000..93b90bf
--- /dev/null
+++ b/diag/lyap.py
@@ -0,0 +1,83 @@
+"""LE diagnostic for the recursive (TRM-ish) GNN — ports the flossing finding to graphs.
+
+Per-graph top Lyapunov exponent lambda1 of the recursion z <- block(z+h0), via Benettin
+power-iteration on a single tangent vector (JVP + renormalize, accumulate log-growth) over
+the model's n_sup*T recursion steps. Bucket graphs by success/failure (rounded ring counts
+exact) and compare lambda1 distributions + AUROC(fail | lambda1) — mirroring
+plot_trm_lyap_hist.py. Hypothesis: failed graphs are MORE chaotic (higher lambda1).
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_full_..._s0.pt
+"""
+import argparse
+import numpy as np
+import torch
+from diag.train_rec import RecGIN
+from diag.train_cycle import prepare
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+
+def build(ck, dev):
+ c = ck['cfg']
+ m = RecGIN(c['n_atom'], c['hidden'], c['T'], c['n_sup'], 0.0, grad_mode=c['grad_mode']).to(dev)
+ m.load_state_dict(ck['state']); m.eval()
+ return m, c
+
+
+def lyap1(model, x, ei, n_steps, dev, seed=0):
+ g = torch.Generator(device=dev).manual_seed(seed)
+ h0 = model.emb(x).detach()
+ z = torch.zeros_like(h0)
+ v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
+ def step_fn(zz):
+ return model.block(zz + h0, ei)
+ lam = 0.0
+ for _ in range(n_steps):
+ z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v)
+ z = z_next.detach()
+ nv = Jv.norm()
+ lam += torch.log(nv + 1e-12).item()
+ v = (Jv / (nv + 1e-12)).detach()
+ return lam / n_steps
+
+
+@torch.no_grad()
+def predict(model, x, ei, dev):
+ batch = torch.zeros(x.size(0), dtype=torch.long, device=dev)
+ preds, _ = model(x, ei, batch, noise=False)
+ return preds[-1].view(-1)
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--ckpt', required=True)
+ ap.add_argument('--n_graphs', type=int, default=300)
+ args = ap.parse_args()
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ ck = torch.load(args.ckpt, weights_only=False)
+ model, cfg = build(ck, dev)
+ ymu, ysd = ck['ymu'].to(dev), ck['ysd'].to(dev)
+ te = prepare('test')
+ n_steps = cfg['n_sup'] * cfg['T']
+
+ lams, fails = [], []
+ for i, r in enumerate(te[:args.n_graphs]):
+ x = r['x'].to(dev); ei = r['edge_index'].to(dev)
+ p = predict(model, x, ei, dev) * ysd + ymu # raw [2]
+ y = r['y'].to(dev) # raw [2]
+ fails.append(int(not torch.all(p.round() == y.round()).item()))
+ lams.append(lyap1(model, x, ei, n_steps, dev, seed=i))
+ lams, fails = np.array(lams), np.array(fails)
+ s, f = lams[fails == 0], lams[fails == 1]
+ auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
+ print(f"[{cfg['grad_mode']}] n={len(lams)} fail_rate={fails.mean():.2f} | "
+ f"lambda1 SUCC mean {s.mean():+.4f} std {s.std():.4f} (n={len(s)}) | "
+ f"FAIL mean {f.mean():+.4f} std {f.std():.4f} (n={len(f)}) | "
+ f"sep(fail-succ)={f.mean()-s.mean() if len(s) and len(f) else float('nan'):+.4f} | "
+ f"AUROC(fail|lambda1)={auc:.3f} | mean_lambda1={lams.mean():+.4f}")
+
+
+if __name__ == "__main__":
+ main()