summaryrefslogtreecommitdiff
path: root/diag/lyap.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/lyap.py')
-rw-r--r--diag/lyap.py24
1 files changed, 14 insertions, 10 deletions
diff --git a/diag/lyap.py b/diag/lyap.py
index 93b90bf..e676568 100644
--- a/diag/lyap.py
+++ b/diag/lyap.py
@@ -1,12 +1,12 @@
"""LE diagnostic for the recursive (TRM-ish) GNN — ports the flossing finding to graphs.
-Per-graph top Lyapunov exponent lambda1 of the recursion z <- block(z+h0), via Benettin
+Per-graph top Lyapunov exponent lambda1 of the edge-free recursion z <- block(z, ctx), via Benettin
power-iteration on a single tangent vector (JVP + renormalize, accumulate log-growth) over
the model's n_sup*T recursion steps. Bucket graphs by success/failure (rounded ring counts
exact) and compare lambda1 distributions + AUROC(fail | lambda1) — mirroring
plot_trm_lyap_hist.py. Hypothesis: failed graphs are MORE chaotic (higher lambda1).
-Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_full_..._s0.pt
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_rrog_full_..._s0.pt
"""
import argparse
import numpy as np
@@ -21,18 +21,19 @@ except Exception:
def build(ck, dev):
c = ck['cfg']
- m = RecGIN(c['n_atom'], c['hidden'], c['T'], c['n_sup'], 0.0, grad_mode=c['grad_mode']).to(dev)
+ m = RecGIN(c['n_atom'], c['hidden'], c['T'], c['n_sup'], 0.0, grad_mode=c['grad_mode'],
+ agg_layers=c.get('agg_layers', 1), compute_layers=c.get('compute_layers', 2)).to(dev)
m.load_state_dict(ck['state']); m.eval()
return m, c
def lyap1(model, x, ei, n_steps, dev, seed=0):
g = torch.Generator(device=dev).manual_seed(seed)
- h0 = model.emb(x).detach()
- z = torch.zeros_like(h0)
- v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
+ ctx = model.aggregate(x, ei).detach()
+ z = ctx.detach()
+ v = torch.randn(ctx.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
def step_fn(zz):
- return model.block(zz + h0, ei)
+ return model.block(zz, ctx)
lam = 0.0
for _ in range(n_steps):
z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v)
@@ -72,10 +73,13 @@ def main():
lams, fails = np.array(lams), np.array(fails)
s, f = lams[fails == 0], lams[fails == 1]
auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
+ sm, ss = (s.mean(), s.std()) if len(s) else (float('nan'), float('nan'))
+ fm, fs = (f.mean(), f.std()) if len(f) else (float('nan'), float('nan'))
+ sep = fm - sm if len(s) and len(f) else float('nan')
print(f"[{cfg['grad_mode']}] n={len(lams)} fail_rate={fails.mean():.2f} | "
- f"lambda1 SUCC mean {s.mean():+.4f} std {s.std():.4f} (n={len(s)}) | "
- f"FAIL mean {f.mean():+.4f} std {f.std():.4f} (n={len(f)}) | "
- f"sep(fail-succ)={f.mean()-s.mean() if len(s) and len(f) else float('nan'):+.4f} | "
+ f"lambda1 SUCC mean {sm:+.4f} std {ss:.4f} (n={len(s)}) | "
+ f"FAIL mean {fm:+.4f} std {fs:.4f} (n={len(f)}) | "
+ f"sep(fail-succ)={sep:+.4f} | "
f"AUROC(fail|lambda1)={auc:.3f} | mean_lambda1={lams.mean():+.4f}")