diff options
Diffstat (limited to 'diag/ptrm_color.py')
| -rw-r--r-- | diag/ptrm_color.py | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/diag/ptrm_color.py b/diag/ptrm_color.py new file mode 100644 index 0000000..4004297 --- /dev/null +++ b/diag/ptrm_color.py @@ -0,0 +1,85 @@ +"""Step-2(a): PTRM test-time noise + lambda-based selection on a trained coloring model +(any backbone feature set via cfg pe). Writes a JSON per ckpt for multi-seed aggregation. + +deterministic / pass@K (conflict-min, ground truth) / lambda-select (min lambda1) / random. + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_...pt +""" +import argparse, json, os +import numpy as np +import torch +from diag.train_color import RecGINColor, make_split, featurize +try: + from sklearn.metrics import roc_auc_score +except Exception: + roc_auc_score = None +OUT = '/home/yurenh2/rrog/runs' + + +def rollout(model, xin, ei, sigma, n_sup, T, dev, seed): + gen = torch.Generator(device=dev).manual_seed(seed) + h0 = model.lin_in(xin) + z = torch.zeros_like(h0) + v = torch.randn(h0.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12) + def step(zz): + return model.block(zz + h0, ei) + lam = 0.0 + for _ in range(n_sup * T): + z_det, Jv = torch.autograd.functional.jvp(step, z, v) + nv = Jv.norm(); lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach() + z = z_det.detach() + if sigma > 0: + z = z + sigma * torch.randn(z.shape, generator=gen, device=dev) + lam /= (n_sup * T) + col = model.head(z).argmax(-1) + conf = (col[ei[0]] == col[ei[1]]).sum().item() // 2 + return conf, lam + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--ckpt', required=True) + ap.add_argument('--K', type=int, default=16) + ap.add_argument('--n_graphs', type=int, default=150) + ap.add_argument('--sigmas', type=float, nargs='+', default=[0.05, 0.1, 0.2, 0.4]) + args = ap.parse_args() + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + ck = torch.load(args.ckpt, weights_only=False); c = ck['cfg'] + deg = torch.tensor(c['deg']) if c.get('deg') else None + model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'], + grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev) + model.load_state_dict(ck['state']); model.eval() + nsup, T = c['n_sup'], c['T'] + te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), c.get('pe', 'none'), c.get('rwse_k', 16)) + te = te[:args.n_graphs]; n = len(te) + + det = sum(rollout(model, r['xin'].to(dev), r['edge_index'].to(dev), 0.0, nsup, T, dev, 0)[0] == 0 + for r in te) / n + out = {'conv': c.get('conv', 'gin'), 'pe': c.get('pe', 'none'), 'seed': c.get('seed'), + 'grad_mode': c['grad_mode'], 'contract': c.get('contract', False), 'det': det, 'sigmas': {}} + print(f"[pe={out['pe']} s{out['seed']}] deterministic solve_rate = {det:.3f} (n={n}, K={args.K})") + print(f"{'sigma':>6} {'pass@K':>8} {'lam-sel':>8} {'random':>8} {'perRoll':>8} {'AUROC(s|-lam)':>14}") + for sigma in args.sigmas: + passk = lamsel = rand = 0 + L, S = [], [] + for gi, r in enumerate(te): + xin = r['xin'].to(dev); ei = r['edge_index'].to(dev) + res = [rollout(model, xin, ei, sigma, nsup, T, dev, 1000 * gi + j) for j in range(args.K)] + confs = np.array([c0 for c0, _ in res]); lams = np.array([l for _, l in res]) + solved = confs == 0 + passk += int(solved.any()); lamsel += int(solved[lams.argmin()]); rand += int(solved[0]) + L += lams.tolist(); S += solved.tolist() + L, S = np.array(L), np.array(S) + auc = (roc_auc_score(S.astype(int), -L) if roc_auc_score and S.any() and (~S).any() else float('nan')) + out['sigmas'][str(sigma)] = {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n, + 'perRoll': float(S.mean()), 'auroc': float(auc)} + print(f"{sigma:>6} {passk/n:>8.3f} {lamsel/n:>8.3f} {rand/n:>8.3f} {S.mean():>8.3f} {auc:>14.3f}") + + base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '') + with open(os.path.join(OUT, f"ptrm_{base}.json"), 'w') as f: + json.dump(out, f, indent=2) + print(" wrote", os.path.join(OUT, f"ptrm_{base}.json")) + + +if __name__ == "__main__": + main() |
