diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-06-29 12:04:47 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-06-29 12:04:47 -0500 |
| commit | c54ddb88b532be28ca3096e21de405d90163ecfa (patch) | |
| tree | 3270ec9269dbee14ea915963f0d28e933303d5a7 /diag | |
| parent | d12722525fc010a3910b5152c72654a2ade5eac4 (diff) | |
Package full RRoG GNN project
Diffstat (limited to 'diag')
| -rw-r--r-- | diag/aggregate.py | 5 | ||||
| -rw-r--r-- | diag/cin_color.py | 3 | ||||
| -rw-r--r-- | diag/esan_color.py | 3 | ||||
| -rw-r--r-- | diag/eval_rec_ttc.py | 113 | ||||
| -rw-r--r-- | diag/lyap.py | 24 | ||||
| -rw-r--r-- | diag/ptrm_color.py | 33 | ||||
| -rw-r--r-- | diag/run_archA.sh | 5 | ||||
| -rw-r--r-- | diag/run_archB.sh | 8 | ||||
| -rw-r--r-- | diag/run_color.sh | 13 | ||||
| -rw-r--r-- | diag/run_le.sh | 6 | ||||
| -rw-r--r-- | diag/run_pe.sh | 11 | ||||
| -rw-r--r-- | diag/run_pe2.sh | 6 | ||||
| -rw-r--r-- | diag/run_pe3.sh | 8 | ||||
| -rw-r--r-- | diag/run_pna.sh | 3 | ||||
| -rw-r--r-- | diag/run_rec.sh | 13 | ||||
| -rw-r--r-- | diag/run_seeds.sh | 7 | ||||
| -rw-r--r-- | diag/train_color.py | 142 | ||||
| -rw-r--r-- | diag/train_cycle.py | 11 | ||||
| -rw-r--r-- | diag/train_rec.py | 423 |
19 files changed, 663 insertions, 174 deletions
diff --git a/diag/aggregate.py b/diag/aggregate.py index b0f737a..4ded30f 100644 --- a/diag/aggregate.py +++ b/diag/aggregate.py @@ -1,4 +1,4 @@ -"""Aggregate multi-seed coloring results -> mean+/-std per (grad_mode, pe, contract).""" +"""Aggregate multi-seed coloring results -> mean+/-std per architecture/config.""" import glob, json import numpy as np from collections import defaultdict @@ -22,7 +22,8 @@ def load(pat): def key(d): - return (d.get('conv', 'gin'), d.get('pe'), d.get('grad_mode'), 'ctr' if d.get('contract') else '-') + return (d.get('arch', 'legacy'), d.get('conv', 'gin'), d.get('pe'), + d.get('grad_mode'), 'ctr' if d.get('contract') else '-') solve, le, ml = defaultdict(list), defaultdict(list), defaultdict(list) diff --git a/diag/cin_color.py b/diag/cin_color.py index 324215f..53ed158 100644 --- a/diag/cin_color.py +++ b/diag/cin_color.py @@ -131,7 +131,8 @@ def main(): print(f"[cin/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} " f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)") base = f"cin_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}" - com = {'conv': 'cin', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed} + com = {'conv': 'cin', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, + 'seed': args.seed, 'arch': 'rrog_once_agg_node_compute'} json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w')) json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())}, open(os.path.join(OUT, f"le_color_{base}.json"), 'w')) json.dump({**com, 'det': 1 - float(fails.mean()), diff --git a/diag/esan_color.py b/diag/esan_color.py index 7be050c..b283641 100644 --- a/diag/esan_color.py +++ b/diag/esan_color.py @@ -131,7 +131,8 @@ def main(): print(f"[esan/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} " f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)") base = f"esan_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}" - com = {'conv': 'esan', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed} + com = {'conv': 'esan', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, + 'seed': args.seed, 'arch': 'rrog_once_agg_node_compute'} json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w')) json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())}, open(os.path.join(OUT, f"le_color_{base}.json"), 'w')) diff --git a/diag/eval_rec_ttc.py b/diag/eval_rec_ttc.py new file mode 100644 index 0000000..e0e789e --- /dev/null +++ b/diag/eval_rec_ttc.py @@ -0,0 +1,113 @@ +"""Evaluate deterministic test-time compute for RRoG/TRM-on-GNN checkpoints. + +Loads a trained ``diag/train_rec.py`` checkpoint, overrides ``model.T`` at eval time, +and reports raw MAE on ZINC cycle-count. This separates training-time T from +test-time recursion depth. + +Run: + PYTHONPATH=/home/yurenh2/rrog python3 diag/eval_rec_ttc.py \ + --ckpt runs/ckpt_rec_rrog_full_sig0.0_K1_bestq_T1_ns3_s0.pt --eval_Ts 0 1 2 3 5 8 +""" +import argparse +import json +import os + +import torch + +from diag.train_rec import RecGIN, evaluate, evaluate_trace, loader +from diag.train_cycle import prepare + +OUT = '/home/yurenh2/rrog/runs' + + +def load_model(ckpt, dev): + ck = torch.load(ckpt, weights_only=False, map_location='cpu') + cfg = ck['cfg'] + model = RecGIN( + cfg['n_atom'], + cfg['hidden'], + cfg['T'], + cfg['n_sup'], + sigma=0.0, + grad_mode=cfg.get('grad_mode', 'full'), + agg_layers=cfg.get('agg_layers', 5), + compute_layers=cfg.get('compute_layers', 2), + ).to(dev) + model.load_state_dict(ck['state']) + model.eval() + return model, ck + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--ckpt', required=True) + ap.add_argument('--eval_Ts', type=int, nargs='+', default=[0, 1, 2, 3, 5, 8]) + ap.add_argument('--eval_n_sup', type=int, default=None) + ap.add_argument('--adaptive', action='store_true', + help='choose the first trace step with q_halt > 0, otherwise the last step') + ap.add_argument('--bs', type=int, default=256) + ap.add_argument('--split', choices=['val', 'test'], default='test') + args = ap.parse_args() + + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + model, ck = load_model(args.ckpt, dev) + cfg = ck['cfg'] + ymu, ysd = ck['ymu'], ck['ysd'] + + recs = prepare(args.split) + for r in recs: + r['y'] = (r['y'] - ymu) / ysd + ld = loader(recs, args.bs, False) + + original_T = model.T + original_n_sup = model.n_sup + if args.eval_n_sup is not None: + model.n_sup = args.eval_n_sup + elif args.adaptive and cfg.get('act'): + model.n_sup = cfg.get('halt_max_steps', cfg['n_sup']) + eval_n_sup_used = model.n_sup + + rows = [] + print(f"ckpt={os.path.basename(args.ckpt)} train_T={cfg['T']} train_n_sup={cfg['n_sup']} split={args.split} adaptive={args.adaptive}") + print(f"{'T_eval':>6} {'n_sup':>6} {'mae_c5':>10} {'mae_c6':>10} {'sum':>10} {'avg_steps':>10}") + for T in args.eval_Ts: + model.T = T + if args.adaptive: + mae, avg_steps = evaluate_trace(model, ld, dev, ymu, ysd, model.n_sup, adaptive=True) + else: + mae, _ = evaluate(model, ld, dev, ymu, ysd, K=1, select='none') + avg_steps = float(model.n_sup) + row = { + 'T_eval': T, + 'n_sup_eval': model.n_sup, + 'adaptive': args.adaptive, + 'avg_steps': avg_steps, + 'mae': mae, + 'mae_sum': float(sum(mae)), + } + rows.append(row) + print(f"{T:6d} {model.n_sup:6d} {mae[0]:10.4f} {mae[1]:10.4f} {sum(mae):10.4f} {avg_steps:10.2f}") + + model.T = original_T + model.n_sup = original_n_sup + + os.makedirs(OUT, exist_ok=True) + base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '') + out = { + 'ckpt': args.ckpt, + 'split': args.split, + 'train_T': cfg['T'], + 'train_n_sup': cfg['n_sup'], + 'eval_n_sup': eval_n_sup_used, + 'adaptive': args.adaptive, + 'rows': rows, + } + suffix = "_adaptive" if args.adaptive else "" + fp = os.path.join(OUT, f"ttc_{base}_ns{out['eval_n_sup']}_{args.split}{suffix}.json") + with open(fp, 'w') as f: + json.dump(out, f, indent=2) + print("wrote", fp) + + +if __name__ == '__main__': + main() diff --git a/diag/lyap.py b/diag/lyap.py index 93b90bf..e676568 100644 --- a/diag/lyap.py +++ b/diag/lyap.py @@ -1,12 +1,12 @@ """LE diagnostic for the recursive (TRM-ish) GNN — ports the flossing finding to graphs. -Per-graph top Lyapunov exponent lambda1 of the recursion z <- block(z+h0), via Benettin +Per-graph top Lyapunov exponent lambda1 of the edge-free recursion z <- block(z, ctx), via Benettin power-iteration on a single tangent vector (JVP + renormalize, accumulate log-growth) over the model's n_sup*T recursion steps. Bucket graphs by success/failure (rounded ring counts exact) and compare lambda1 distributions + AUROC(fail | lambda1) — mirroring plot_trm_lyap_hist.py. Hypothesis: failed graphs are MORE chaotic (higher lambda1). -Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_full_..._s0.pt +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_rrog_full_..._s0.pt """ import argparse import numpy as np @@ -21,18 +21,19 @@ except Exception: def build(ck, dev): c = ck['cfg'] - m = RecGIN(c['n_atom'], c['hidden'], c['T'], c['n_sup'], 0.0, grad_mode=c['grad_mode']).to(dev) + m = RecGIN(c['n_atom'], c['hidden'], c['T'], c['n_sup'], 0.0, grad_mode=c['grad_mode'], + agg_layers=c.get('agg_layers', 1), compute_layers=c.get('compute_layers', 2)).to(dev) m.load_state_dict(ck['state']); m.eval() return m, c def lyap1(model, x, ei, n_steps, dev, seed=0): g = torch.Generator(device=dev).manual_seed(seed) - h0 = model.emb(x).detach() - z = torch.zeros_like(h0) - v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12) + ctx = model.aggregate(x, ei).detach() + z = ctx.detach() + v = torch.randn(ctx.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12) def step_fn(zz): - return model.block(zz + h0, ei) + return model.block(zz, ctx) lam = 0.0 for _ in range(n_steps): z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v) @@ -72,10 +73,13 @@ def main(): lams, fails = np.array(lams), np.array(fails) s, f = lams[fails == 0], lams[fails == 1] auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan')) + sm, ss = (s.mean(), s.std()) if len(s) else (float('nan'), float('nan')) + fm, fs = (f.mean(), f.std()) if len(f) else (float('nan'), float('nan')) + sep = fm - sm if len(s) and len(f) else float('nan') print(f"[{cfg['grad_mode']}] n={len(lams)} fail_rate={fails.mean():.2f} | " - f"lambda1 SUCC mean {s.mean():+.4f} std {s.std():.4f} (n={len(s)}) | " - f"FAIL mean {f.mean():+.4f} std {f.std():.4f} (n={len(f)}) | " - f"sep(fail-succ)={f.mean()-s.mean() if len(s) and len(f) else float('nan'):+.4f} | " + f"lambda1 SUCC mean {sm:+.4f} std {ss:.4f} (n={len(s)}) | " + f"FAIL mean {fm:+.4f} std {fs:.4f} (n={len(f)}) | " + f"sep(fail-succ)={sep:+.4f} | " f"AUROC(fail|lambda1)={auc:.3f} | mean_lambda1={lams.mean():+.4f}") diff --git a/diag/ptrm_color.py b/diag/ptrm_color.py index 4004297..b24097f 100644 --- a/diag/ptrm_color.py +++ b/diag/ptrm_color.py @@ -3,7 +3,7 @@ deterministic / pass@K (conflict-min, ground truth) / lambda-select (min lambda1) / random. -Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_...pt +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_rrog_trm_gin_full_...pt """ import argparse, json, os import numpy as np @@ -18,20 +18,24 @@ OUT = '/home/yurenh2/rrog/runs' def rollout(model, xin, ei, sigma, n_sup, T, dev, seed): gen = torch.Generator(device=dev).manual_seed(seed) - h0 = model.lin_in(xin) - z = torch.zeros_like(h0) - v = torch.randn(h0.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12) - def step(zz): - return model.block(zz + h0, ei) + ctx = model.aggregate(xin, ei) + y, z = ctx, torch.zeros_like(ctx) + state = torch.cat([y, z], dim=-1) + v = torch.randn(state.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12) + def step(ss): + yy, zz = ss.chunk(2, dim=-1) + yy, zz = model.recurse(yy, zz, ctx, noise=False) + return torch.cat([yy, zz], dim=-1) lam = 0.0 for _ in range(n_sup * T): - z_det, Jv = torch.autograd.functional.jvp(step, z, v) + state_det, Jv = torch.autograd.functional.jvp(step, state, v) nv = Jv.norm(); lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach() - z = z_det.detach() + state = state_det.detach() if sigma > 0: - z = z + sigma * torch.randn(z.shape, generator=gen, device=dev) - lam /= (n_sup * T) - col = model.head(z).argmax(-1) + state = state + sigma * torch.randn(state.shape, generator=gen, device=dev) + lam /= max(n_sup * T, 1) + y, _ = state.chunk(2, dim=-1) + col = model.head(y).argmax(-1) conf = (col[ei[0]] == col[ei[1]]).sum().item() // 2 return conf, lam @@ -47,7 +51,11 @@ def main(): ck = torch.load(args.ckpt, weights_only=False); c = ck['cfg'] deg = torch.tensor(c['deg']) if c.get('deg') else None model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'], - grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev) + grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg, + agg_layers=c.get('agg_layers', 1), + compute_layers=c.get('compute_layers', 2), + compute=(c.get('compute') if c.get('compute') == 'trm' else 'trm'), + attn_heads=c.get('attn_heads', 4)).to(dev) model.load_state_dict(ck['state']); model.eval() nsup, T = c['n_sup'], c['T'] te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), c.get('pe', 'none'), c.get('rwse_k', 16)) @@ -56,6 +64,7 @@ def main(): det = sum(rollout(model, r['xin'].to(dev), r['edge_index'].to(dev), 0.0, nsup, T, dev, 0)[0] == 0 for r in te) / n out = {'conv': c.get('conv', 'gin'), 'pe': c.get('pe', 'none'), 'seed': c.get('seed'), + 'arch': c.get('arch', 'legacy'), 'grad_mode': c['grad_mode'], 'contract': c.get('contract', False), 'det': det, 'sigmas': {}} print(f"[pe={out['pe']} s{out['seed']}] deterministic solve_rate = {det:.3f} (n={n}, K={args.K})") print(f"{'sigma':>6} {'pass@K':>8} {'lam-sel':>8} {'random':>8} {'perRoll':>8} {'AUROC(s|-lam)':>14}") diff --git a/diag/run_archA.sh b/diag/run_archA.sh index 5385152..6cc6042 100644 --- a/diag/run_archA.sh +++ b/diag/run_archA.sh @@ -1,16 +1,15 @@ #!/usr/bin/env bash -# Conv axis (pe=none): gcn, sage, gat. 5 seeds, train+LE+PTRM. (pin GPU at launch.) +# One-shot aggregation axis (pe=none): gcn, sage, gat. 5 seeds, train+LE. set -uo pipefail cd /home/yurenh2/rrog export PYTHONPATH=/home/yurenh2/rrog echo "A start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)" for s in 0 1 2 3 4; do for conv in gcn sage gat; do - ck=runs/ckpt_color_${conv}_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt + ck=runs/ckpt_color_rrog_trm_${conv}_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt echo "== A s$s conv=$conv ==" python3 diag/train_color.py --mode train --conv "$conv" --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train $conv s$s" python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le $conv s$s" - python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $conv s$s" done done echo "doneA=$(date -Is)" diff --git a/diag/run_archB.sh b/diag/run_archB.sh index 317638f..ddc6069 100644 --- a/diag/run_archB.sh +++ b/diag/run_archB.sh @@ -1,21 +1,19 @@ #!/usr/bin/env bash -# GPS transformer backbone + feature axis (gin + lappe / all). 5 seeds. (pin GPU at launch.) +# One-shot GPS encoder + feature axis (gin + lappe / all). 5 seeds. (pin GPU at launch.) set -uo pipefail cd /home/yurenh2/rrog export PYTHONPATH=/home/yurenh2/rrog echo "B start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)" for s in 0 1 2 3 4; do - ck=runs/ckpt_color_gps_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt + ck=runs/ckpt_color_rrog_trm_gps_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt echo "== B s$s conv=gps ==" python3 diag/train_color.py --mode train --conv gps --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train gps s$s" python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le gps s$s" - python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm gps s$s" for pe in lappe all; do - ck2=runs/ckpt_color_gin_full_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt + ck2=runs/ckpt_color_rrog_trm_gin_full_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt echo "== B s$s gin pe=$pe ==" python3 diag/train_color.py --mode train --conv gin --pe "$pe" --p 0.2 --epochs 150 --seed "$s" || echo "!! train $pe s$s" python3 diag/train_color.py --mode le --ckpt "$ck2" || echo "!! le $pe s$s" - python3 diag/ptrm_color.py --ckpt "$ck2" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $pe s$s" done done echo "doneB=$(date -Is)" diff --git a/diag/run_color.sh b/diag/run_color.sh index ad3c406..82c522a 100644 --- a/diag/run_color.sh +++ b/diag/run_color.sh @@ -1,15 +1,12 @@ #!/usr/bin/env bash -# Step-2 (TRM regime, large output): graph 3-coloring, TRM full vs HRM 1-step + LE diagnostic. +# RRoG/TRM-on-GNN: graph 3-coloring, deterministic T=1 lower bound vs T=3 extra compute. set -uo pipefail cd /home/yurenh2/rrog export PYTHONPATH=/home/yurenh2/rrog echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" -for gm in full 1step; do - echo "===== train $gm =====" - python3 diag/train_color.py --mode train --grad_mode "$gm" --p 0.2 --epochs 150 --seed 0 \ - || echo "!! train $gm failed" +for T in 1 3; do + echo "===== train T=$T =====" + python3 diag/train_color.py --mode train --grad_mode full --T "$T" --p 0.2 --epochs 150 --seed 0 \ + || echo "!! train T=$T failed" done -echo "===== LE diagnostic (lambda1: solved vs unsolved) =====" -python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_full_n50_k3_p0.2_T3_ns3_s0.pt || echo "!! le full failed" -python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_1step_n50_k3_p0.2_T3_ns3_s0.pt || echo "!! le 1step failed" echo "done=$(date -Is)" diff --git a/diag/run_le.sh b/diag/run_le.sh index 8fc31ea..07d6a63 100644 --- a/diag/run_le.sh +++ b/diag/run_le.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# Step-2(iii): TRM-ish GIN full-recursion vs 1-step-gradient, then LE diagnostic on each. +# RRoG/TRM-on-GNN GIN full-recursion vs 1-step-gradient, then LE diagnostic on each. set -uo pipefail cd /home/yurenh2/rrog export PYTHONPATH=/home/yurenh2/rrog @@ -7,6 +7,6 @@ echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1 --epochs 200 --seed 0 || echo "!! train full failed" python3 diag/train_rec.py --grad_mode 1step --sigma 0 --K 1 --epochs 200 --seed 0 || echo "!! train 1step failed" echo "===== LE diagnostic (lambda1: success vs failure) =====" -python3 diag/lyap.py --ckpt runs/ckpt_rec_full_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap full failed" -python3 diag/lyap.py --ckpt runs/ckpt_rec_1step_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap 1step failed" +python3 diag/lyap.py --ckpt runs/ckpt_rec_rrog_full_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap full failed" +python3 diag/lyap.py --ckpt runs/ckpt_rec_rrog_1step_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap 1step failed" echo "done=$(date -Is)" diff --git a/diag/run_pe.sh b/diag/run_pe.sh index 8350160..1e587ce 100644 --- a/diag/run_pe.sh +++ b/diag/run_pe.sh @@ -1,6 +1,5 @@ #!/usr/bin/env bash -# Roadmap #1: RRoG on PE-augmented backbone. GIN vs GIN+RWSE on coloring: -# does RRoG-noise add headroom on top of static structural encoding, or is it redundant? +# RRoG on PE-augmented one-shot GIN encoder. GIN vs GIN+RWSE on coloring. set -uo pipefail cd /home/yurenh2/rrog export PYTHONPATH=/home/yurenh2/rrog @@ -12,13 +11,7 @@ for pe in none rwse; do done echo "===== LE (full, both pe) =====" for pe in none rwse; do - python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \ + python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_rrog_trm_gin_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \ || echo "!! le $pe failed" done -echo "===== PTRM noise + lambda-select (both pe) =====" -for pe in none rwse; do - echo "--- pe=$pe ---" - python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \ - --K 16 --n_graphs 150 --sigmas 0.1 0.2 0.4 || echo "!! ptrm $pe failed" -done echo "done=$(date -Is)" diff --git a/diag/run_pe2.sh b/diag/run_pe2.sh index db7bd8a..dc9f88b 100644 --- a/diag/run_pe2.sh +++ b/diag/run_pe2.sh @@ -1,17 +1,15 @@ #!/usr/bin/env bash -# Roadmap #2: RRoG on a GSN-style motif backbone (per-node K3/wedge substructure counts). -# full-recursion, 5 seeds; train + LE + PTRM(sigma=0.2). Aggregate vs none/rwse. +# RRoG on a GSN-style motif-augmented one-shot encoder. Full recursion, 5 seeds; train + LE. set -uo pipefail cd /home/yurenh2/rrog export PYTHONPATH=/home/yurenh2/rrog echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" for s in 0 1 2 3 4; do - ck=runs/ckpt_color_full_gsn_n50_k3_p0.2_T3_ns3_s${s}.pt + ck=runs/ckpt_color_rrog_trm_gin_full_gsn_n50_k3_p0.2_T3_ns3_s${s}.pt echo "===== seed=$s full gsn =====" python3 diag/train_color.py --mode train --grad_mode full --pe gsn --p 0.2 --epochs 150 --seed "$s" \ || echo "!! train gsn s$s failed" python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le gsn s$s failed" - python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm gsn s$s failed" done echo "===== AGGREGATE (all pe: none/rwse/gsn) =====" python3 diag/aggregate.py diff --git a/diag/run_pe3.sh b/diag/run_pe3.sh index a978393..2a6744b 100644 --- a/diag/run_pe3.sh +++ b/diag/run_pe3.sh @@ -1,22 +1,20 @@ #!/usr/bin/env bash -# Roadmap #3 (subgraph/ego features) + #4 (IGNN-style forced contraction), 5 seeds. +# Subgraph/ego features + forced contraction, 5 seeds. Deterministic train + LE. # #3: full --pe sub. #4: full --pe none --contract (vs free full/none baseline). set -uo pipefail cd /home/yurenh2/rrog export PYTHONPATH=/home/yurenh2/rrog echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" for s in 0 1 2 3 4; do - ck=runs/ckpt_color_full_sub_n50_k3_p0.2_T3_ns3_s${s}.pt + ck=runs/ckpt_color_rrog_trm_gin_full_sub_n50_k3_p0.2_T3_ns3_s${s}.pt echo "===== seed=$s #3 full sub =====" python3 diag/train_color.py --mode train --grad_mode full --pe sub --p 0.2 --epochs 150 --seed "$s" || echo "!! train sub s$s failed" python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le sub s$s failed" - python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm sub s$s failed" - ck2=runs/ckpt_color_full_none_ctr_n50_k3_p0.2_T3_ns3_s${s}.pt + ck2=runs/ckpt_color_rrog_trm_gin_full_none_ctr_n50_k3_p0.2_T3_ns3_s${s}.pt echo "===== seed=$s #4 full none --contract =====" python3 diag/train_color.py --mode train --grad_mode full --pe none --contract --p 0.2 --epochs 150 --seed "$s" || echo "!! train ctr s$s failed" python3 diag/train_color.py --mode le --ckpt "$ck2" || echo "!! le ctr s$s failed" - python3 diag/ptrm_color.py --ckpt "$ck2" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm ctr s$s failed" done echo "===== AGGREGATE =====" python3 diag/aggregate.py diff --git a/diag/run_pna.sh b/diag/run_pna.sh index 897315d..157d8ac 100644 --- a/diag/run_pna.sh +++ b/diag/run_pna.sh @@ -4,10 +4,9 @@ cd /home/yurenh2/rrog export PYTHONPATH=/home/yurenh2/rrog echo "PNA start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)" for s in 0 1 2 3 4; do - ck=runs/ckpt_color_pna_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt + ck=runs/ckpt_color_rrog_trm_pna_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt echo "== pna s$s ==" python3 diag/train_color.py --mode train --conv pna --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train pna s$s" python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le pna s$s" - python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm pna s$s" done echo "donePNA=$(date -Is)" diff --git a/diag/run_rec.sh b/diag/run_rec.sh index 632498d..c436c4e 100644 --- a/diag/run_rec.sh +++ b/diag/run_rec.sh @@ -1,13 +1,12 @@ #!/usr/bin/env bash -# Step-2: recursive GNN + full PTRM on ZINC ring-counting. Does per-step noise + best-Q@K -# selection break the 1-WL counting ceiling that input-RNI couldn't? Averaging vs selection. +# RRoG/TRM-on-GNN on ZINC ring-counting. Deterministic view-only vs recursive compute. set -uo pipefail cd /home/yurenh2/rrog export PYTHONPATH=/home/yurenh2/rrog echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" -run() { echo "===== $* ====="; python3 diag/train_rec.py "$@" --epochs 200 --seed 0 || echo "!! FAILED $*"; } -run --sigma 0 --K 1 --select bestq # deterministic recursive baseline (~1-WL ceiling) -run --sigma 0.1 --K 8 --select none # averaging over rollouts (RNI-style null) -run --sigma 0.1 --K 8 --select bestq # PTRM-proper: per-step noise + best-Q@K selection -run --sigma 0.2 --K 16 --select bestq # scaled noise + rollouts +run() { echo "===== $* ====="; python3 diag/train_rec.py "$@" --epochs 40 --hidden 64 --bs 256 --seed 0 || echo "!! FAILED $*"; } +run --grad_mode full --T 0 --n_sup 3 --sigma 0 --K 1 --select none +run --grad_mode full --T 1 --n_sup 3 --sigma 0 --K 1 --select none +run --grad_mode full --T 3 --n_sup 3 --sigma 0 --K 1 --select none +run --grad_mode full --T 1 --n_sup 3 --act --halt_max_steps 8 --halt_target binary --sigma 0 --K 1 --select none echo "done=$(date -Is)" diff --git a/diag/run_seeds.sh b/diag/run_seeds.sh index 77f22c8..85b67d1 100644 --- a/diag/run_seeds.sh +++ b/diag/run_seeds.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# Harden coloring results with 5 seeds: full-vs-1step solve, LE AUROC, PTRM(sigma=0.2) for none/rwse. +# Harden corrected RRoG coloring results with 5 seeds: full-vs-1step solve + LE AUROC. set -uo pipefail cd /home/yurenh2/rrog export PYTHONPATH=/home/yurenh2/rrog @@ -7,14 +7,11 @@ echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" for s in 0 1 2 3 4; do for cfg in "full none" "full rwse" "1step none"; do set -- $cfg; gm=$1; pe=$2 - ck=runs/ckpt_color_${gm}_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt + ck=runs/ckpt_color_rrog_trm_gin_${gm}_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt echo "===== seed=$s $gm $pe =====" python3 diag/train_color.py --mode train --grad_mode "$gm" --pe "$pe" --p 0.2 --epochs 150 --seed "$s" \ || echo "!! train $gm $pe s$s failed" python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le $gm $pe s$s failed" - if [ "$gm" = "full" ]; then - python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $pe s$s failed" - fi done done echo "===== AGGREGATE =====" diff --git a/diag/train_color.py b/diag/train_color.py index 36f8496..b9d0c23 100644 --- a/diag/train_color.py +++ b/diag/train_color.py @@ -1,7 +1,11 @@ -"""Recursive (TRM-ish) GNN graph 3-coloring with swappable BACKBONE for the RRoG roadmap. +"""RRoG/TRM-on-GNN graph 3-coloring. ---conv gin|gcn|sage|gat|gps : message-passing operator (gps = GraphGPS local MPNN + global - attention = TRM's original transformer backbone, on the graph). +The graph is encoded once into a fixed per-node context. Recursion then refines hidden state +with a shared compute block that never reads edge_index. This is the RRoG split: +the GNN encoder supplies the view/context x, TRM-style recurrence supplies computation. + +--conv gin|gcn|sage|gat|gps : message-passing operator used only by the one-shot encoder + (gps = GraphGPS local MPNN + global attention). --pe none|rwse|gsn|sub|lappe|all : input structural features (random sym-break [+ encoding]). --contract : reverse-flossing lambda-penalty during training (force contraction; roadmap #4). --grad_mode full|1step : TRM full recursion vs HRM 1-step gradient. @@ -142,48 +146,83 @@ def make_conv(conv, hidden, deg=None): class RecGINColor(nn.Module): - def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full', sigma=0.0, conv='gin', deg=None): + def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full', + sigma=0.0, conv='gin', deg=None, agg_layers=4, compute_layers=None, + compute='trm', attn_heads=4): super().__init__() self.conv_type = conv + self.agg_layers = agg_layers + self.compute_layers = compute_layers or inner + self.compute = compute + self.attn_heads = attn_heads self.lin_in = nn.Linear(in_dim, hidden) - self.convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(inner)]) - self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)]) + self.agg_convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(agg_layers)]) + self.agg_bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(agg_layers)]) + if compute not in ('trm',): + raise ValueError(compute) + core = [] + d = hidden + for _ in range(self.compute_layers - 1): + core += [nn.Linear(d, hidden), nn.GELU()] + d = hidden + core.append(nn.Linear(d, hidden)) + self.core_norm = nn.LayerNorm(hidden) + self.core = nn.Sequential(*core) + nn.init.zeros_(self.core[-1].weight) + nn.init.zeros_(self.core[-1].bias) self.head = nn.Linear(hidden, k) self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma - def block(self, z, ei, batch=None): + def aggregate(self, xin, ei, batch=None): if self.conv_type == 'gps' and batch is None: - batch = z.new_zeros(z.size(0), dtype=torch.long) - for conv, bn in zip(self.convs, self.bns): - z = conv(z, ei, batch) if self.conv_type == 'gps' else conv(z, ei) - z = bn(z).relu() - return z - - def _inner(self, z, h0, ei, noise, batch): - z = self.block(z + h0, ei, batch) + batch = xin.new_zeros(xin.size(0), dtype=torch.long) + h = self.lin_in(xin) + for conv, bn in zip(self.agg_convs, self.agg_bns): + h = conv(h, ei, batch) if self.conv_type == 'gps' else conv(h, ei) + h = bn(h).relu() + return h + + def core_step(self, combined, state): + """Shared TRM compute core. Deliberately edge-free.""" + return state + self.core(self.core_norm(combined)) + + def _z_step(self, y, z, ctx, noise): + z = self.core_step(ctx + y + z, z) if noise and self.sigma > 0: z = z + self.sigma * torch.randn_like(z) return z - def recurse(self, z, h0, ei, noise, batch, one_step=False): + def _y_step(self, y, z, noise): + y = self.core_step(y + z, y) + if noise and self.sigma > 0: + y = y + self.sigma * torch.randn_like(y) + return y + + def recurse(self, y, z, ctx, noise, one_step=False): + if self.T == 0: + return y, z if one_step: with torch.no_grad(): for _ in range(self.T - 1): - z = self._inner(z, h0, ei, noise, batch) + z = self._z_step(y, z, ctx, noise) z = z.detach() - return self._inner(z, h0, ei, noise, batch) + z = self._z_step(y, z, ctx, noise) + y = self._y_step(y, z, noise) + return y, z for _ in range(self.T): - z = self._inner(z, h0, ei, noise, batch) - return z + z = self._z_step(y, z, ctx, noise) + y = self._y_step(y, z, noise) + return y, z def forward(self, xin, ei, batch=None, noise=False): - h0 = self.lin_in(xin) - z = torch.zeros_like(h0) + ctx = self.aggregate(xin, ei, batch) + y = ctx + z = torch.zeros_like(ctx) outs = [] for s in range(self.n_sup): - z = self.recurse(z, h0, ei, noise, batch, one_step=(self.grad_mode == '1step')) - outs.append(self.head(z)) - z = z.detach() + y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step')) + outs.append(self.head(y)) + y, z = y.detach(), z.detach() return outs @@ -206,15 +245,17 @@ def solve_stats(model, recs, dev, sample=None): def lyap1(model, xin, ei, n_steps, dev, seed=0): g = torch.Generator(device=dev).manual_seed(seed) - h0 = model.lin_in(xin).detach() - z = torch.zeros_like(h0) - v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12) - def step_fn(zz): - return model.block(zz + h0, ei) + ctx = model.aggregate(xin, ei).detach() + state = torch.cat([ctx, torch.zeros_like(ctx)], dim=-1).detach() + v = torch.randn(state.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12) + def step_fn(ss): + y, z = ss.chunk(2, dim=-1) + y, z = model.recurse(y, z, ctx, noise=False) + return torch.cat([y, z], dim=-1) lam = 0.0 for _ in range(n_steps): - z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v) - z = z_next.detach(); nv = Jv.norm() + state_next, Jv = torch.autograd.functional.jvp(step_fn, state, v) + state = state_next.detach(); nv = Jv.norm() lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach() return lam / n_steps @@ -246,11 +287,16 @@ def run_le(model, recs, dev, n_steps, n_graphs=300): def lyap_penalty(model, x, ei, batch, target=-0.5): - h0 = model.lin_in(x) + ctx = model.aggregate(x, ei, batch) with torch.no_grad(): - zr = model.recurse(torch.zeros_like(h0), h0.detach(), ei, False, batch) - v = torch.randn_like(zr); v = v / (v.norm() + 1e-12) - _, Jv = torch.autograd.functional.jvp(lambda zz: model.block(zz + h0, ei, batch), zr, v, create_graph=True) + yr, zr = model.recurse(ctx.detach(), torch.zeros_like(ctx).detach(), ctx.detach(), False) + state = torch.cat([yr, zr], dim=-1) + v = torch.randn_like(state); v = v / (v.norm() + 1e-12) + def step_fn(ss): + y, z = ss.chunk(2, dim=-1) + y, z = model.recurse(y, z, ctx, noise=False) + return torch.cat([y, z], dim=-1) + _, Jv = torch.autograd.functional.jvp(step_fn, state, v, create_graph=True) return (torch.log(Jv.norm() + 1e-12) - target) ** 2 @@ -267,6 +313,10 @@ def main(): ap.add_argument('--p', type=float, default=0.2); ap.add_argument('--r', type=int, default=8) ap.add_argument('--hidden', type=int, default=128); ap.add_argument('--T', type=int, default=3) ap.add_argument('--n_sup', type=int, default=3); ap.add_argument('--epochs', type=int, default=150) + ap.add_argument('--agg_layers', type=int, default=4) + ap.add_argument('--compute_layers', type=int, default=2) + ap.add_argument('--compute', choices=['trm'], default='trm') + ap.add_argument('--attn_heads', type=int, default=4) ap.add_argument('--lr', type=float, default=1e-3); ap.add_argument('--bs', type=int, default=32) ap.add_argument('--seed', type=int, default=0) args = ap.parse_args() @@ -280,13 +330,18 @@ def main(): c.get('pe', 'none'), c.get('rwse_k', 16)) deg = torch.tensor(c['deg']) if c.get('deg') else None model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'], - grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev) + grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg, + agg_layers=c.get('agg_layers', 1), + compute_layers=c.get('compute_layers', 2), + compute=(c.get('compute') if c.get('compute') == 'trm' else 'trm'), + attn_heads=c.get('attn_heads', 4)).to(dev) model.load_state_dict(ck['state']); model.eval() res = run_le(model, te, dev, c['n_sup'] * c['T']) base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '') with open(os.path.join(OUT, f"le_{base}.json"), 'w') as fjs: json.dump({'conv': c.get('conv', 'gin'), 'grad_mode': c['grad_mode'], 'pe': c.get('pe', 'none'), - 'contract': c.get('contract', False), 'seed': c.get('seed'), **res}, fjs, indent=2) + 'contract': c.get('contract', False), 'seed': c.get('seed'), + 'arch': c.get('arch', 'legacy'), **res}, fjs, indent=2) return te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), args.pe, args.rwse_k) @@ -296,7 +351,9 @@ def main(): trl = DataLoader(data, batch_size=args.bs, shuffle=True, drop_last=True) deg = deg_hist(tr) if args.conv == 'pna' else None model = RecGINColor(in_dim, args.hidden, args.k, args.T, args.n_sup, - grad_mode=args.grad_mode, conv=args.conv, deg=deg).to(dev) + grad_mode=args.grad_mode, conv=args.conv, deg=deg, + agg_layers=args.agg_layers, compute_layers=args.compute_layers, + compute=args.compute, attn_heads=args.attn_heads).to(dev) opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) @@ -329,15 +386,18 @@ def main(): print(f"ep{ep+1} solve_rate={sr:.3f} mean_conflicts={mc:.2f}", flush=True) sfx = ('_ctr' if args.contract else '') - tag = f"color_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}" + tag = f"color_rrog_{args.compute}_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}" rep = {'task': 'graph3coloring', 'tag': tag, **vars(args), 'in_dim': in_dim, - 'sec': round(time.time() - t0, 1), **best} + 'arch': 'rrog_once_agg_hidden_compute', 'sec': round(time.time() - t0, 1), **best} print(f"[{tag}] best solve_rate={best.get('solve_rate')} @ep{best.get('ep')} ({rep['sec']}s)") with open(os.path.join(OUT, f"{tag}.json"), 'w') as f: json.dump(rep, f, indent=2) torch.save({'state': best_state, 'cfg': {'in_dim': in_dim, 'hidden': args.hidden, 'k': args.k, 'T': args.T, 'n_sup': args.n_sup, 'grad_mode': args.grad_mode, 'pe': args.pe, 'rwse_k': args.rwse_k, 'contract': args.contract, 'conv': args.conv, 'seed': args.seed, + 'agg_layers': args.agg_layers, 'compute_layers': args.compute_layers, + 'compute': args.compute, 'attn_heads': args.attn_heads, + 'arch': 'rrog_once_agg_hidden_compute', 'deg': (deg.tolist() if deg is not None else None)}}, os.path.join(OUT, f"ckpt_{tag}.pt")) print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt")) diff --git a/diag/train_cycle.py b/diag/train_cycle.py index d2342f3..598e349 100644 --- a/diag/train_cycle.py +++ b/diag/train_cycle.py @@ -24,9 +24,14 @@ from torch_geometric.loader import DataLoader from torch_geometric.utils import to_networkx from torch_geometric.nn import GINConv, GCNConv, global_add_pool -ROOT = '/home/yurenh2/rrog/data/zinc' -CACHE = '/home/yurenh2/rrog/data/cycle_cache' -OUT = '/home/yurenh2/rrog/runs' +PROJECT_ROOT = os.environ.get( + 'RROG_ROOT', + os.path.abspath(os.path.join(os.path.dirname(__file__), '..')), +) +DATA_ROOT = os.environ.get('RROG_DATA_DIR', os.path.join(PROJECT_ROOT, 'data')) +OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs')) +ROOT = os.path.join(DATA_ROOT, 'zinc') +CACHE = os.path.join(DATA_ROOT, 'cycle_cache') RWSE_K = 16 diff --git a/diag/train_rec.py b/diag/train_rec.py index 9866f28..9db7eb1 100644 --- a/diag/train_rec.py +++ b/diag/train_rec.py @@ -1,11 +1,10 @@ -"""Step-2: Recursive (TRM-ish) GNN on ZINC ring-counting + optional PTRM noise/selection. +"""Step-2: RRoG/TRM-on-GNN for ZINC ring-counting. -Recurrent shared-weight GIN block, deep-supervised over n_sup steps (TRM-style: carry latent -detached between steps). --grad_mode controls the LAST supervision step's recursion: +The graph is encoded once with a GIN encoder. A shared edge-free node-wise compute block then +refines hidden state over n_sup*T recurrent steps (TRM-style: carry latent detached between +deep-supervision steps). --grad_mode controls the LAST supervision step's recursion: full : backprop through all T inner recursions (TRM) 1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient) -Optional per-step Gaussian noise (sigma) + K stochastic rollouts selected by a value head -(best-Q@K) for the PTRM experiments. Saves a checkpoint for the LE diagnostic (diag/lyap.py). Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1 """ @@ -14,66 +13,218 @@ import numpy as np import torch import torch.nn as nn from torch_geometric.loader import DataLoader -from torch_geometric.data import Data -from torch_geometric.nn import GINConv, global_add_pool +from torch_geometric.data import Batch, Data +from torch_geometric.nn import ( + APPNP, + ARMAConv, + ChebConv, + FiLMConv, + GATv2Conv, + GCNConv, + GENConv, + GINEConv, + GINConv, + GraphConv, + MFConv, + PNAConv, + ResGatedGraphConv, + SAGEConv, + SGConv, + TAGConv, + TransformerConv, + global_add_pool, +) +from torch_geometric.utils import degree from diag.train_cycle import prepare -OUT = '/home/yurenh2/rrog/runs' +PROJECT_ROOT = os.environ.get( + 'RROG_ROOT', + os.path.abspath(os.path.join(os.path.dirname(__file__), '..')), +) +OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs')) +SUPPORTED_VIEWS = [ + 'gin', 'gine', 'gcn', 'graphsage', 'gatv2', 'graphconv', 'transformer', 'pna', + 'gen', 'film', 'resgated', 'tag', 'sgc', 'cheb', 'arma', 'mf', 'appnp', +] -def loader(recs, bs, shuffle, drop_last=False): - data = [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2), +def data_list(recs): + return [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2), num_nodes=r['x'].numel()) for r in recs] + + +def loader(recs, bs, shuffle, drop_last=False): + data = recs if recs and isinstance(recs[0], Data) else data_list(recs) return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last) +def degree_histogram(data): + max_degree = 0 + degs = [] + for graph in data: + deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long) + degs.append(deg) + if deg.numel(): + max_degree = max(max_degree, int(deg.max().item())) + hist = torch.zeros(max_degree + 1, dtype=torch.long) + for deg in degs: + hist += torch.bincount(deg, minlength=hist.numel()) + return hist + + +def make_view_layer(view, hidden, deg): + if view == 'gin': + return GINConv(nn.Sequential( + nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True) + if view == 'gine': + return GINEConv(nn.Sequential( + nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), + train_eps=True, edge_dim=hidden) + if view == 'gcn': + return GCNConv(hidden, hidden) + if view == 'graphsage': + return SAGEConv(hidden, hidden) + if view == 'gatv2': + return GATv2Conv(hidden, hidden, heads=4, concat=False) + if view == 'graphconv': + return GraphConv(hidden, hidden) + if view == 'transformer': + return TransformerConv(hidden, hidden, heads=4, concat=False) + if view == 'pna': + if deg is None: + raise ValueError('PNA view requires a training-set degree histogram') + return PNAConv( + hidden, hidden, + aggregators=['mean', 'min', 'max', 'std'], + scalers=['identity', 'amplification', 'attenuation'], + deg=deg, + ) + if view == 'gen': + return GENConv(hidden, hidden) + if view == 'film': + return FiLMConv(hidden, hidden) + if view == 'resgated': + return ResGatedGraphConv(hidden, hidden) + if view == 'tag': + return TAGConv(hidden, hidden, K=3) + if view == 'sgc': + return SGConv(hidden, hidden, K=2, cached=False) + if view == 'cheb': + return ChebConv(hidden, hidden, K=3) + if view == 'arma': + return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2) + if view == 'mf': + return MFConv(hidden, hidden) + if view == 'appnp': + return APPNP(K=5, alpha=0.1) + raise ValueError(f'unsupported view: {view}') + + class RecGIN(nn.Module): - def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, grad_mode='full'): + def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, + grad_mode='full', agg_layers=5, compute_layers=None, view='gin', deg=None): super().__init__() + self.view = view + self.agg_layers = agg_layers + self.compute_layers = compute_layers or inner self.emb = nn.Embedding(n_atom, hidden) - self.convs = nn.ModuleList([GINConv(nn.Sequential( - nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True) - for _ in range(inner)]) - self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)]) + self.edge_emb = nn.Embedding(1, hidden) if view == 'gine' else None + self.agg_convs = nn.ModuleList() + for _ in range(agg_layers): + self.agg_convs.append(make_view_layer(view, hidden, deg)) + self.agg_bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(agg_layers)]) + core = [] + d = hidden + for _ in range(self.compute_layers - 1): + core += [nn.Linear(d, hidden), nn.GELU()] + d = hidden + core.append(nn.Linear(d, hidden)) + self.core_norm = nn.LayerNorm(hidden) + self.core = nn.Sequential(*core) + nn.init.zeros_(self.core[-1].weight) + nn.init.zeros_(self.core[-1].bias) self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2)) self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1)) + with torch.no_grad(): + self.qhead[-1].weight.zero_() + self.qhead[-1].bias.fill_(-5.0) self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode - def block(self, z, ei): - for conv, bn in zip(self.convs, self.bns): - z = bn(conv(z, ei)).relu() - return z + def aggregate(self, x, ei): + h = self.emb(x) + for conv, bn in zip(self.agg_convs, self.agg_bns): + if self.view == 'gine': + edge_attr = self.edge_emb(torch.zeros(ei.size(1), dtype=torch.long, device=ei.device)) + h = bn(conv(h, ei, edge_attr)).relu() + else: + h = bn(conv(h, ei)).relu() + return h + + def core_step(self, combined, state): + """Shared TRM compute core. Deliberately edge-free.""" + return state + self.core(self.core_norm(combined)) - def _inner(self, z, h0, ei, noise): - z = self.block(z + h0, ei) + def _z_step(self, y, z, ctx, noise): + z = self.core_step(ctx + y + z, z) if noise and self.sigma > 0: z = z + self.sigma * torch.randn_like(z) return z - def recurse(self, z, h0, ei, noise, one_step=False): + def _y_step(self, y, z, noise): + y = self.core_step(y + z, y) + if noise and self.sigma > 0: + y = y + self.sigma * torch.randn_like(y) + return y + + def recurse(self, y, z, ctx, noise, one_step=False): + if self.T == 0: + return y, z if one_step: # HRM 1-step gradient with torch.no_grad(): for _ in range(self.T - 1): - z = self._inner(z, h0, ei, noise) + z = self._z_step(y, z, ctx, noise) z = z.detach() - return self._inner(z, h0, ei, noise) # only last inner carries grad + z = self._z_step(y, z, ctx, noise) # only last inner carries grad + y = self._y_step(y, z, noise) + return y, z for _ in range(self.T): # TRM full recursion - z = self._inner(z, h0, ei, noise) - return z + z = self._z_step(y, z, ctx, noise) + y = self._y_step(y, z, noise) + return y, z + + def predict(self, y, batch): + pooled = global_add_pool(y, batch) + return self.head(pooled), self.qhead(pooled).view(-1) + + def forward_trace(self, x, ei, batch, steps, noise=False): + ctx = self.aggregate(x, ei) + y = ctx + z = torch.zeros_like(ctx) + preds, q_logits = [], [] + for s in range(steps): + y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step')) + pred, q = self.predict(y, batch) + preds.append(pred) + q_logits.append(q) + if s < steps - 1: + y, z = y.detach(), z.detach() + return preds, q_logits def forward(self, x, ei, batch, noise=False): - h0 = self.emb(x) - z = torch.zeros_like(h0) + ctx = self.aggregate(x, ei) + y = ctx + z = torch.zeros_like(ctx) preds = [] for s in range(self.n_sup): if s < self.n_sup - 1: with torch.no_grad(): - z = self.recurse(z, h0, ei, noise) - z = z.detach() + y, z = self.recurse(y, z, ctx, noise) + y, z = y.detach(), z.detach() else: - z = self.recurse(z, h0, ei, noise, one_step=(self.grad_mode == '1step')) - preds.append(self.head(global_add_pool(z, batch))) - q = self.qhead(global_add_pool(z, batch)).view(-1) + y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step')) + pred, _ = self.predict(y, batch) + preds.append(pred) + _, q = self.predict(y, batch) return preds, q @@ -102,6 +253,113 @@ def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'): return (ae / n).tolist(), (ae_or / n).tolist() +@torch.no_grad() +def evaluate_trace(model, ld, dev, ymu, ysd, steps, adaptive=False): + model.eval() + ysd_d, ymu_d = ysd.to(dev), ymu.to(dev) + ae = torch.zeros(2) + n = 0 + step_sum = 0.0 + for b in ld: + b = b.to(dev) + preds, q_logits = model.forward_trace(b.x, b.edge_index, b.batch, steps, noise=False) + P = torch.stack(preds, dim=0) + if adaptive: + Q = torch.stack(q_logits, dim=0) + halted = Q > 0 + any_halt = halted.any(dim=0) + first_halt = halted.to(torch.int64).argmax(dim=0) + fallback = torch.full_like(first_halt, steps - 1) + idx = torch.where(any_halt, first_halt, fallback) + chosen = P[idx, torch.arange(P.size(1), device=dev)] + step_sum += (idx.to(torch.float32) + 1).sum().item() + else: + chosen = P[-1] + step_sum += steps * b.num_graphs + ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + n += b.num_graphs + return (ae / n).tolist(), step_sum / max(n, 1) + + +def _split_nodes(t, ptr): + return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)] + + +def act_train_step(model, state, replacement_batch, opt, dev, args): + replacement = replacement_batch.to_data_list() + batch_size = len(replacement) + if state is None: + state = { + 'graphs': [None for _ in range(batch_size)], + 'y': [None for _ in range(batch_size)], + 'z': [None for _ in range(batch_size)], + 'steps': torch.zeros(batch_size, dtype=torch.long, device=dev), + 'halted': torch.ones(batch_size, dtype=torch.bool, device=dev), + } + + halted_cpu = state['halted'].detach().cpu().tolist() + for i, halted in enumerate(halted_cpu): + if halted: + state['graphs'][i] = replacement[i] + + b = Batch.from_data_list(state['graphs']).to(dev) + ctx = model.aggregate(b.x, b.edge_index) + ptr = b.ptr + y_parts, z_parts = [], [] + for i in range(batch_size): + start, end = ptr[i].item(), ptr[i + 1].item() + if halted_cpu[i] or state['y'][i] is None: + y_parts.append(ctx[start:end]) + z_parts.append(torch.zeros_like(ctx[start:end])) + else: + y_parts.append(state['y'][i].to(dev)) + z_parts.append(state['z'][i].to(dev)) + y = torch.cat(y_parts, dim=0) + z = torch.cat(z_parts, dim=0) + + opt.zero_grad() + y, z = model.recurse(y, z, ctx, noise=False, one_step=(model.grad_mode == '1step')) + pred, q = model.predict(y, b.batch) + per_graph_err = (pred - b.y).abs().mean(1) + pred_loss = per_graph_err.mean() + with torch.no_grad(): + if args.halt_target == 'binary': + halt_target = (per_graph_err <= args.halt_norm_threshold).to(q.dtype) + else: + halt_target = torch.sigmoid((args.halt_norm_threshold - per_graph_err) / args.halt_temp) + q_loss = nn.functional.binary_cross_entropy_with_logits(q, halt_target) + loss = pred_loss + 0.5 * args.lam_q * q_loss + y_det, z_det = y.detach(), z.detach() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + + state['y'] = _split_nodes(y_det, ptr) + state['z'] = _split_nodes(z_det, ptr) + with torch.no_grad(): + was_halted = state['halted'] + steps = torch.where(was_halted, torch.zeros_like(state['steps']), state['steps']) + 1 + halted = (steps >= args.halt_max_steps) | (q.detach() > 0) + if args.halt_exploration_prob > 0 and args.halt_max_steps > 1: + explore = torch.rand_like(q) < args.halt_exploration_prob + min_steps = torch.where( + explore, + torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev), + torch.zeros_like(steps), + ) + halted = halted & (steps >= min_steps) + state['steps'] = steps + state['halted'] = halted + + return state, { + 'loss': float(loss.detach().cpu()), + 'pred_loss': float(pred_loss.detach().cpu()), + 'q_loss': float(q_loss.detach().cpu()), + 'halted_frac': float(state['halted'].to(torch.float32).mean().detach().cpu()), + 'steps': float(state['steps'].to(torch.float32).mean().detach().cpu()), + } + + def main(): ap = argparse.ArgumentParser() ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') @@ -111,14 +369,27 @@ def main(): ap.add_argument('--T', type=int, default=3) ap.add_argument('--n_sup', type=int, default=3) ap.add_argument('--hidden', type=int, default=128) + ap.add_argument('--agg_layers', type=int, default=5) + ap.add_argument('--compute_layers', type=int, default=2) + ap.add_argument('--view', choices=SUPPORTED_VIEWS, default='gin') ap.add_argument('--epochs', type=int, default=200) ap.add_argument('--lr', type=float, default=1e-3) ap.add_argument('--bs', type=int, default=128) ap.add_argument('--lam_q', type=float, default=1.0) + ap.add_argument('--act', action='store_true', + help='train all recurrent depths up to halt_max_steps and train qhead as a halt head') + ap.add_argument('--halt_max_steps', type=int, default=8) + ap.add_argument('--halt_norm_threshold', type=float, default=0.30) + ap.add_argument('--halt_temp', type=float, default=0.10) + ap.add_argument('--halt_target', choices=['soft', 'binary'], default='soft') + ap.add_argument('--halt_exploration_prob', type=float, default=0.1) + ap.add_argument('--loss_mode', choices=['last', 'trace'], default='trace') ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--device', default='auto') args = ap.parse_args() torch.manual_seed(args.seed); np.random.seed(args.seed) - dev = 'cuda' if torch.cuda.is_available() else 'cpu' + dev = 'cuda' if args.device == 'auto' and torch.cuda.is_available() else ( + 'cpu' if args.device == 'auto' else args.device) os.makedirs(OUT, exist_ok=True) tr, va, te = prepare('train'), prepare('val'), prepare('test') @@ -127,45 +398,91 @@ def main(): for recs in (tr, va, te): for r in recs: r['y'] = (r['y'] - ymu) / ysd - trl = loader(tr, args.bs, True, drop_last=True) + train_data = data_list(tr) + trl = loader(train_data, args.bs, True, drop_last=True) val, tel = loader(va, 256, False), loader(te, 256, False) - model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode).to(dev) + deg = degree_histogram(train_data) if args.view == 'pna' else None + model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode, + agg_layers=args.agg_layers, compute_layers=args.compute_layers, + view=args.view, deg=deg).to(dev) opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) l1 = nn.L1Loss() + act_steps = max(1, args.halt_max_steps) - t0 = time.time(); best_val = 9e9; best = {}; best_state = None + t0 = time.time(); best_val = 9e9; best = {}; best_state = None; act_state = None for ep in range(args.epochs): model.train() + act_metrics = [] for b in trl: - b = b.to(dev); opt.zero_grad() - preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0) - loss = sum(l1(p, b.y) for p in preds) / len(preds) - with torch.no_grad(): - tq = -(preds[-1] - b.y).abs().mean(1) - loss = loss + args.lam_q * nn.functional.mse_loss(q, tq) - loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + if args.act: + act_state, metrics = act_train_step(model, act_state, b, opt, dev, args) + act_metrics.append(metrics) + else: + b = b.to(dev); opt.zero_grad() + if args.loss_mode == 'trace': + preds, q_logits = model.forward_trace( + b.x, b.edge_index, b.batch, args.n_sup, noise=model.sigma > 0) + q = q_logits[-1] + else: + preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0) + loss = sum(l1(p, b.y) for p in preds) / len(preds) + with torch.no_grad(): + tq = -(preds[-1] - b.y).abs().mean(1) + loss = loss + args.lam_q * nn.functional.mse_loss(q, tq) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() sched.step() if (ep + 1) % 20 == 0 or ep == args.epochs - 1: - vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select) + if args.act: + vm, _ = evaluate_trace(model, val, dev, ymu, ysd, act_steps, adaptive=False) + else: + vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select) if sum(vm) < best_val: best_val = sum(vm) - tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select) - best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo} + if args.act: + tem, fixed_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=False) + tea, adaptive_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=True) + best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, + 'test_mae_adaptive': tea, 'fixed_steps': fixed_steps, + 'adaptive_steps': adaptive_steps} + else: + tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select) + best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo} best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} - print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True) + if args.act and act_metrics: + hm = sum(m['halted_frac'] for m in act_metrics) / len(act_metrics) + sm = sum(m['steps'] for m in act_metrics) / len(act_metrics) + print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]} halt={hm:.2f} train_steps={sm:.2f}", flush=True) + else: + print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True) - tag = f"rec_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}_s{args.seed}" + act_tag = f"_actfull{act_steps}_{args.halt_target}{args.halt_norm_threshold:g}_e{args.epochs}" if args.act else "" + loss_tag = f"_{args.loss_mode}" if (not args.act and args.loss_mode != 'last') else "" + view_tag = f"_{args.view}" if args.view != 'gin' else "" + tag = f"rec_rrog{view_tag}_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}{loss_tag}{act_tag}_s{args.seed}" rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1), - 'dev': dev, 'y_std_raw': ysd.tolist(), **best} - print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} " - f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)") + 'dev': dev, 'arch': 'rrog_once_agg_node_compute', 'y_std_raw': ysd.tolist(), **best} + if args.act: + print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} " + f"adaptive={[round(x,3) for x in best.get('test_mae_adaptive')]} " + f"steps={best.get('adaptive_steps'):.2f}/{best.get('fixed_steps'):.2f} " + f"@ep{best.get('ep')} ({rep['sec']}s)") + else: + print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} " + f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)") with open(os.path.join(OUT, f"{tag}.json"), 'w') as f: json.dump(rep, f, indent=2) torch.save({'state': best_state or model.state_dict(), 'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup, - 'sigma': args.sigma, 'grad_mode': args.grad_mode}, + 'sigma': args.sigma, 'grad_mode': args.grad_mode, + 'agg_layers': args.agg_layers, 'compute_layers': args.compute_layers, + 'view': args.view, + 'loss_mode': args.loss_mode, + 'act': args.act, 'act_impl': 'persistent_recycle' if args.act else 'none', + 'halt_max_steps': act_steps, + 'halt_exploration_prob': args.halt_exploration_prob, + 'arch': 'rrog_once_agg_node_compute'}, 'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt")) print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt")) |
