# Source: diag/ptrm_color.py (blob b24097f4662ca14610fa92e1d8b8b9e5f8cd67ed)
"""Step-2(a): PTRM test-time noise + lambda-based selection on a trained coloring model
(any backbone feature set via cfg pe). Writes a JSON per ckpt for multi-seed aggregation.

deterministic / pass@K (conflict-min, ground truth) / lambda-select (min lambda1) / random.

Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_rrog_trm_gin_full_...pt
"""
import argparse, json, os
import numpy as np
import torch
from diag.train_color import RecGINColor, make_split, featurize
# scikit-learn is optional: when the import fails, roc_auc_score is None and
# main() reports the AUROC column as NaN instead of crashing.
try:
    from sklearn.metrics import roc_auc_score
except Exception:
    roc_auc_score = None
# Directory that receives the per-checkpoint JSON result file written by main().
OUT = '/home/yurenh2/rrog/runs'


def rollout(model, xin, ei, sigma, n_sup, T, dev, seed):
    """Run one (optionally noisy) recurrent rollout and score the coloring.

    The state is the concatenation ``[y, z]`` along the last dimension, with
    ``y`` initialised to ``model.aggregate(xin, ei)`` and ``z`` to zeros. It
    is advanced ``n_sup * T`` times through ``model.recurse(..., noise=False)``.
    At every step a unit tangent vector is pushed through the step's Jacobian
    (via a JVP) and renormalised; the mean log growth of its norm is returned
    as ``lam`` (the "lambda1" used for selection in the module docstring).

    Args:
        model: object exposing ``aggregate(xin, ei)``,
            ``recurse(y, z, ctx, noise=False)`` and ``head(y)``.
        xin: node input features passed to ``model.aggregate``.
        ei: edge index; ``ei[0]`` and ``ei[1]`` are used to index the
            per-node colors. The conflict count is halved, so each edge is
            presumably stored in both directions - verify against the data.
        sigma: std of the Gaussian noise added to the state after every
            step; no noise is drawn when ``sigma <= 0``.
        n_sup, T: the rollout runs ``n_sup * T`` steps.
        dev: device for the random generator and the sampled tensors.
        seed: seed for the generator (tangent init and state noise).

    Returns:
        ``(conf, lam)``: the number of same-colored edge entries floor-divided
        by two, and the average log tangent growth per step (0.0 when there
        are no steps).
    """
    rng = torch.Generator(device=dev).manual_seed(seed)
    ctx = model.aggregate(xin, ei)
    state = torch.cat([ctx, torch.zeros_like(ctx)], dim=-1)

    # Random unit-norm tangent direction; drawn before any state noise so the
    # generator is consumed in a fixed order.
    tangent = torch.randn(state.shape, generator=rng, device=dev)
    tangent = tangent / (tangent.norm() + 1e-12)

    def advance(packed):
        # One noise-free recurrence step on the packed [y, z] state.
        y_part, z_part = packed.chunk(2, dim=-1)
        y_part, z_part = model.recurse(y_part, z_part, ctx, noise=False)
        return torch.cat([y_part, z_part], dim=-1)

    n_steps = n_sup * T
    log_growth = 0.0
    for _ in range(n_steps):
        next_state, pushed = torch.autograd.functional.jvp(advance, state, tangent)
        growth = pushed.norm()
        log_growth += torch.log(growth + 1e-12).item()
        # Renormalise so the tangent neither overflows nor vanishes.
        tangent = (pushed / (growth + 1e-12)).detach()
        state = next_state.detach()
        if sigma > 0:
            noise = torch.randn(state.shape, generator=rng, device=dev)
            state = state + sigma * noise
    lam = log_growth / max(n_steps, 1)

    # Read the coloring off the y half of the final state.
    y_final = state.chunk(2, dim=-1)[0]
    colors = model.head(y_final).argmax(-1)
    src, dst = ei[0], ei[1]
    conflicts = (colors[src] == colors[dst]).sum().item() // 2
    return conflicts, lam


def main():
    """Evaluate noisy test-time rollouts of one trained coloring checkpoint.

    Parses ``--ckpt``, ``--K``, ``--n_graphs`` and ``--sigmas``, rebuilds the
    model from the checkpoint's ``cfg`` entry and measures the deterministic
    solve rate (sigma = 0, seed 0). Then, for every sigma, it draws K noisy
    rollouts per test graph and reports:

    * ``passk``   - fraction of graphs where any rollout has zero conflicts
    * ``lamsel``  - fraction where the rollout with the smallest lambda has
      zero conflicts
    * ``random``  - fraction where the first rollout (index 0) has zero
      conflicts
    * ``perRoll`` - fraction of all rollouts with zero conflicts
    * ``auroc``   - AUROC of ``-lambda`` as a score for "solved"; NaN when
      scikit-learn is unavailable or only one class is present

    The table is printed and the results are written to
    ``OUT/ptrm_<checkpoint name>.json``.

    Raises:
        SystemExit: via argparse when ``--K`` or ``--n_graphs`` is below 1.
        RuntimeError: when the test split yields no graphs to evaluate.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--K', type=int, default=16)
    ap.add_argument('--n_graphs', type=int, default=150)
    ap.add_argument('--sigmas', type=float, nargs='+', default=[0.05, 0.1, 0.2, 0.4])
    args = ap.parse_args()
    # Fail early with a usage message instead of an argmin-on-empty or
    # division-by-zero error deep inside the evaluation loop.
    if args.K < 1:
        ap.error('--K must be >= 1')
    if args.n_graphs < 1:
        ap.error('--n_graphs must be >= 1')

    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    ck = torch.load(args.ckpt, map_location=dev, weights_only=False)
    c = ck['cfg']
    deg = torch.tensor(c['deg']) if c.get('deg') else None
    model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
                        grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg,
                        agg_layers=c.get('agg_layers', 1),
                        compute_layers=c.get('compute_layers', 2),
                        # The original expression selected 'trm' on both
                        # branches, so the value is spelled out directly.
                        compute='trm',
                        attn_heads=c.get('attn_heads', 4)).to(dev)
    model.load_state_dict(ck['state'])
    model.eval()
    nsup, T = c['n_sup'], c['T']
    te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), c.get('pe', 'none'), c.get('rwse_k', 16))
    te = te[:args.n_graphs]
    n = len(te)
    if n == 0:
        raise RuntimeError('test split is empty; nothing to evaluate')

    # Deterministic baseline: a single noise-free rollout per graph.
    det = sum(rollout(model, r['xin'].to(dev), r['edge_index'].to(dev), 0.0, nsup, T, dev, 0)[0] == 0
              for r in te) / n
    out = {'conv': c.get('conv', 'gin'), 'pe': c.get('pe', 'none'), 'seed': c.get('seed'),
           'arch': c.get('arch', 'legacy'),
           'grad_mode': c['grad_mode'], 'contract': c.get('contract', False), 'det': det, 'sigmas': {}}
    print(f"[pe={out['pe']} s{out['seed']}] deterministic solve_rate = {det:.3f}  (n={n}, K={args.K})")
    print(f"{'sigma':>6} {'pass@K':>8} {'lam-sel':>8} {'random':>8} {'perRoll':>8} {'AUROC(s|-lam)':>14}")
    for sigma in args.sigmas:
        passk = lamsel = rand = 0
        L, S = [], []
        for gi, r in enumerate(te):
            xin = r['xin'].to(dev)
            ei = r['edge_index'].to(dev)
            # Seeds are unique per (graph, rollout) pair as long as K <= 1000.
            res = [rollout(model, xin, ei, sigma, nsup, T, dev, 1000 * gi + j) for j in range(args.K)]
            confs = np.array([c0 for c0, _ in res])
            lams = np.array([l for _, l in res])
            solved = confs == 0
            passk += int(solved.any())
            lamsel += int(solved[lams.argmin()])
            # Rollout 0 serves as the "random pick" baseline.
            rand += int(solved[0])
            L += lams.tolist()
            S += solved.tolist()
        L, S = np.array(L), np.array(S)
        # AUROC needs scikit-learn and both solved and unsolved rollouts.
        auc = (roc_auc_score(S.astype(int), -L) if roc_auc_score and S.any() and (~S).any() else float('nan'))
        out['sigmas'][str(sigma)] = {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n,
                                     'perRoll': float(S.mean()), 'auroc': float(auc)}
        print(f"{sigma:>6} {passk/n:>8.3f} {lamsel/n:>8.3f} {rand/n:>8.3f} {S.mean():>8.3f} {auc:>14.3f}")

    base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
    # Create the output directory on first use so the write cannot fail on a
    # fresh checkout.
    os.makedirs(OUT, exist_ok=True)
    out_path = os.path.join(OUT, f"ptrm_{base}.json")
    with open(out_path, 'w') as f:
        json.dump(out, f, indent=2)
    print("  wrote", out_path)


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()