summaryrefslogtreecommitdiff
path: root/diag/lyap.py
blob: e67656838e7e8a48896deeedfb7d332c5c312cae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""LE diagnostic for the recursive (TRM-ish) GNN — ports the flossing finding to graphs.

Per-graph top Lyapunov exponent lambda1 of the edge-free recursion z <- block(z, ctx), via Benettin
power-iteration on a single tangent vector (JVP + renormalize, accumulate log-growth) over
the model's n_sup*T recursion steps. Bucket graphs by success/failure (rounded ring counts
exact) and compare lambda1 distributions + AUROC(fail | lambda1) — mirroring
plot_trm_lyap_hist.py. Hypothesis: failed graphs are MORE chaotic (higher lambda1).

Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_rrog_full_..._s0.pt
"""
import argparse
import numpy as np
import torch
from diag.train_rec import RecGIN
from diag.train_cycle import prepare
try:
    from sklearn.metrics import roc_auc_score
except Exception:
    roc_auc_score = None


def build(ck, dev):
    """Rebuild the RecGIN described by checkpoint ``ck`` on ``dev`` in eval mode.

    Hyperparameters are read from ``ck['cfg']``; ``agg_layers`` falls back to 1
    and ``compute_layers`` to 2 when the config does not carry them. The fifth
    positional constructor argument is fixed at 0.0 (presumably a dropout
    rate -- confirm against RecGIN). Weights are loaded from ``ck['state']``.

    Returns:
        ``(model, cfg)``: the loaded model and the config dict from the checkpoint.
    """
    cfg = ck['cfg']
    layer_kwargs = {
        'agg_layers': cfg.get('agg_layers', 1),
        'compute_layers': cfg.get('compute_layers', 2),
    }
    model = RecGIN(
        cfg['n_atom'], cfg['hidden'], cfg['T'], cfg['n_sup'], 0.0,
        grad_mode=cfg['grad_mode'], **layer_kwargs,
    )
    model = model.to(dev)
    model.load_state_dict(ck['state'])
    model.eval()
    return model, cfg


def lyap1(model, x, ei, n_steps, dev, seed=0):
    """Estimate the top Lyapunov exponent of ``z <- model.block(z, ctx)``.

    ``ctx = model.aggregate(x, ei)`` is computed once, detached, and used both
    as the fixed context and as the initial state ``z``. A random unit tangent
    vector is pushed through the Jacobian of one recursion step with
    ``torch.autograd.functional.jvp``; the log of its norm is accumulated and
    the vector is renormalised before the next step (Benettin-style power
    iteration on a single tangent vector).

    Args:
        model: object exposing ``aggregate(x, ei)`` and ``block(z, ctx)``.
        x: forwarded unchanged to ``model.aggregate``.
        ei: forwarded unchanged to ``model.aggregate``.
        n_steps: number of recursion steps to average over; must be >= 1.
        dev: device on which the random generator and tangent vector live.
        seed: seed of the generator used to draw the initial tangent vector.

    Returns:
        The mean per-step log growth of the tangent vector, as a Python float.

    Raises:
        ValueError: if ``n_steps`` is smaller than 1. Without this check a
            zero step count divides by zero and a negative one skips the loop
            and returns ``-0.0`` as though it were a measured exponent.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    g = torch.Generator(device=dev).manual_seed(seed)
    ctx = model.aggregate(x, ei).detach()
    z = ctx.detach()
    v = torch.randn(ctx.shape, generator=g, device=dev)
    v = v / (v.norm() + 1e-12)  # start from a unit-norm tangent vector

    def step_fn(zz):
        # One recursion step with the context held fixed.
        return model.block(zz, ctx)

    lam = 0.0
    for _ in range(n_steps):
        z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v)
        z = z_next.detach()
        nv = Jv.norm()
        # The epsilon keeps log() finite when the tangent vector collapses to zero.
        lam += torch.log(nv + 1e-12).item()
        # Renormalise so only the direction is carried to the next step.
        v = (Jv / (nv + 1e-12)).detach()
    return lam / n_steps


@torch.no_grad()
def predict(model, x, ei, dev):
    """Run one gradient-free forward pass on a single graph and flatten the last prediction.

    Every one of the ``x.size(0)`` rows is assigned to graph 0 through an
    all-zero long ``batch`` vector created on ``dev``, and the model is called
    with ``noise=False``. The model must return a pair; the last entry of its
    first element is returned as a 1-D view.
    """
    n_rows = x.size(0)
    graph_ids = torch.zeros(n_rows, dtype=torch.long, device=dev)
    step_preds, _unused = model(x, ei, graph_ids, noise=False)
    final_pred = step_preds[-1]
    return final_pred.view(-1)


def main():
    """CLI entry point: compare lambda1 between succeeded and failed test graphs.

    Loads the checkpoint given by ``--ckpt``, rebuilds the model, and walks the
    first ``--n_graphs`` records of ``prepare('test')``. For each graph the
    prediction is de-normalised with the checkpoint's ``ysd``/``ymu``; the graph
    counts as a failure when any rounded prediction differs from the rounded
    target. ``lyap1`` is evaluated over ``n_sup * T`` steps with the graph
    index as seed. A single summary line is printed with per-bucket mean/std,
    their separation, and AUROC(fail | lambda1) when scikit-learn is importable
    and both buckets are non-empty (``nan`` otherwise).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--n_graphs', type=int, default=300)
    args = ap.parse_args()
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only point --ckpt at checkpoints from a trusted source.
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    ck = torch.load(args.ckpt, map_location=dev, weights_only=False)
    model, cfg = build(ck, dev)
    ymu, ysd = ck['ymu'].to(dev), ck['ysd'].to(dev)
    te = prepare('test')
    n_steps = cfg['n_sup'] * cfg['T']

    lams, fails = [], []
    for i, r in enumerate(te[:args.n_graphs]):
        x = r['x'].to(dev)
        ei = r['edge_index'].to(dev)
        p = predict(model, x, ei, dev) * ysd + ymu        # raw [2]
        y = r['y'].to(dev)                                # raw [2]
        # Failure = at least one rounded target is not matched exactly.
        fails.append(int(not torch.all(p.round() == y.round()).item()))
        lams.append(lyap1(model, x, ei, n_steps, dev, seed=i))
    lams, fails = np.array(lams), np.array(fails)
    s, f = lams[fails == 0], lams[fails == 1]
    # AUROC needs scikit-learn and both classes present.
    auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
    sm, ss = (s.mean(), s.std()) if len(s) else (float('nan'), float('nan'))
    fm, fs = (f.mean(), f.std()) if len(f) else (float('nan'), float('nan'))
    sep = fm - sm if len(s) and len(f) else float('nan')
    print(f"[{cfg['grad_mode']}] n={len(lams)} fail_rate={fails.mean():.2f} | "
          f"lambda1 SUCC mean {sm:+.4f} std {ss:.4f} (n={len(s)}) | "
          f"FAIL mean {fm:+.4f} std {fs:.4f} (n={len(f)}) | "
          f"sep(fail-succ)={sep:+.4f} | "
          f"AUROC(fail|lambda1)={auc:.3f} | mean_lambda1={lams.mean():+.4f}")


# Run the diagnostic only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()