summaryrefslogtreecommitdiff
path: root/diag
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-29 12:04:47 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-29 12:04:47 -0500
commitc54ddb88b532be28ca3096e21de405d90163ecfa (patch)
tree3270ec9269dbee14ea915963f0d28e933303d5a7 /diag
parentd12722525fc010a3910b5152c72654a2ade5eac4 (diff)
Package full RRoG GNN project
Diffstat (limited to 'diag')
-rw-r--r--diag/aggregate.py5
-rw-r--r--diag/cin_color.py3
-rw-r--r--diag/esan_color.py3
-rw-r--r--diag/eval_rec_ttc.py113
-rw-r--r--diag/lyap.py24
-rw-r--r--diag/ptrm_color.py33
-rw-r--r--diag/run_archA.sh5
-rw-r--r--diag/run_archB.sh8
-rw-r--r--diag/run_color.sh13
-rw-r--r--diag/run_le.sh6
-rw-r--r--diag/run_pe.sh11
-rw-r--r--diag/run_pe2.sh6
-rw-r--r--diag/run_pe3.sh8
-rw-r--r--diag/run_pna.sh3
-rw-r--r--diag/run_rec.sh13
-rw-r--r--diag/run_seeds.sh7
-rw-r--r--diag/train_color.py142
-rw-r--r--diag/train_cycle.py11
-rw-r--r--diag/train_rec.py423
19 files changed, 663 insertions, 174 deletions
diff --git a/diag/aggregate.py b/diag/aggregate.py
index b0f737a..4ded30f 100644
--- a/diag/aggregate.py
+++ b/diag/aggregate.py
@@ -1,4 +1,4 @@
-"""Aggregate multi-seed coloring results -> mean+/-std per (grad_mode, pe, contract)."""
+"""Aggregate multi-seed coloring results -> mean+/-std per architecture/config."""
import glob, json
import numpy as np
from collections import defaultdict
@@ -22,7 +22,8 @@ def load(pat):
def key(d):
- return (d.get('conv', 'gin'), d.get('pe'), d.get('grad_mode'), 'ctr' if d.get('contract') else '-')
+ return (d.get('arch', 'legacy'), d.get('conv', 'gin'), d.get('pe'),
+ d.get('grad_mode'), 'ctr' if d.get('contract') else '-')
solve, le, ml = defaultdict(list), defaultdict(list), defaultdict(list)
diff --git a/diag/cin_color.py b/diag/cin_color.py
index 324215f..53ed158 100644
--- a/diag/cin_color.py
+++ b/diag/cin_color.py
@@ -131,7 +131,8 @@ def main():
print(f"[cin/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} "
f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")
base = f"cin_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
- com = {'conv': 'cin', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}
+ com = {'conv': 'cin', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False,
+ 'seed': args.seed, 'arch': 'rrog_once_agg_node_compute'}
json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w'))
json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())}, open(os.path.join(OUT, f"le_color_{base}.json"), 'w'))
json.dump({**com, 'det': 1 - float(fails.mean()),
diff --git a/diag/esan_color.py b/diag/esan_color.py
index 7be050c..b283641 100644
--- a/diag/esan_color.py
+++ b/diag/esan_color.py
@@ -131,7 +131,8 @@ def main():
print(f"[esan/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} "
f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")
base = f"esan_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
- com = {'conv': 'esan', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}
+ com = {'conv': 'esan', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False,
+ 'seed': args.seed, 'arch': 'rrog_once_agg_node_compute'}
json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w'))
json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())},
open(os.path.join(OUT, f"le_color_{base}.json"), 'w'))
diff --git a/diag/eval_rec_ttc.py b/diag/eval_rec_ttc.py
new file mode 100644
index 0000000..e0e789e
--- /dev/null
+++ b/diag/eval_rec_ttc.py
@@ -0,0 +1,113 @@
+"""Evaluate deterministic test-time compute for RRoG/TRM-on-GNN checkpoints.
+
+Loads a trained ``diag/train_rec.py`` checkpoint, overrides ``model.T`` at eval time,
+and reports raw MAE on ZINC cycle-count. This separates training-time T from
+test-time recursion depth.
+
+Run:
+ PYTHONPATH=/home/yurenh2/rrog python3 diag/eval_rec_ttc.py \
+ --ckpt runs/ckpt_rec_rrog_full_sig0.0_K1_bestq_T1_ns3_s0.pt --eval_Ts 0 1 2 3 5 8
+"""
+import argparse
+import json
+import os
+
+import torch
+
+from diag.train_rec import RecGIN, evaluate, evaluate_trace, loader
+from diag.train_cycle import prepare
+
+OUT = '/home/yurenh2/rrog/runs'
+
+
+def load_model(ckpt, dev):
+ ck = torch.load(ckpt, weights_only=False, map_location='cpu')
+ cfg = ck['cfg']
+ model = RecGIN(
+ cfg['n_atom'],
+ cfg['hidden'],
+ cfg['T'],
+ cfg['n_sup'],
+ sigma=0.0,
+ grad_mode=cfg.get('grad_mode', 'full'),
+ agg_layers=cfg.get('agg_layers', 5),
+ compute_layers=cfg.get('compute_layers', 2),
+ ).to(dev)
+ model.load_state_dict(ck['state'])
+ model.eval()
+ return model, ck
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--ckpt', required=True)
+ ap.add_argument('--eval_Ts', type=int, nargs='+', default=[0, 1, 2, 3, 5, 8])
+ ap.add_argument('--eval_n_sup', type=int, default=None)
+ ap.add_argument('--adaptive', action='store_true',
+ help='choose the first trace step with q_halt > 0, otherwise the last step')
+ ap.add_argument('--bs', type=int, default=256)
+ ap.add_argument('--split', choices=['val', 'test'], default='test')
+ args = ap.parse_args()
+
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model, ck = load_model(args.ckpt, dev)
+ cfg = ck['cfg']
+ ymu, ysd = ck['ymu'], ck['ysd']
+
+ recs = prepare(args.split)
+ for r in recs:
+ r['y'] = (r['y'] - ymu) / ysd
+ ld = loader(recs, args.bs, False)
+
+ original_T = model.T
+ original_n_sup = model.n_sup
+ if args.eval_n_sup is not None:
+ model.n_sup = args.eval_n_sup
+ elif args.adaptive and cfg.get('act'):
+ model.n_sup = cfg.get('halt_max_steps', cfg['n_sup'])
+ eval_n_sup_used = model.n_sup
+
+ rows = []
+ print(f"ckpt={os.path.basename(args.ckpt)} train_T={cfg['T']} train_n_sup={cfg['n_sup']} split={args.split} adaptive={args.adaptive}")
+ print(f"{'T_eval':>6} {'n_sup':>6} {'mae_c5':>10} {'mae_c6':>10} {'sum':>10} {'avg_steps':>10}")
+ for T in args.eval_Ts:
+ model.T = T
+ if args.adaptive:
+ mae, avg_steps = evaluate_trace(model, ld, dev, ymu, ysd, model.n_sup, adaptive=True)
+ else:
+ mae, _ = evaluate(model, ld, dev, ymu, ysd, K=1, select='none')
+ avg_steps = float(model.n_sup)
+ row = {
+ 'T_eval': T,
+ 'n_sup_eval': model.n_sup,
+ 'adaptive': args.adaptive,
+ 'avg_steps': avg_steps,
+ 'mae': mae,
+ 'mae_sum': float(sum(mae)),
+ }
+ rows.append(row)
+ print(f"{T:6d} {model.n_sup:6d} {mae[0]:10.4f} {mae[1]:10.4f} {sum(mae):10.4f} {avg_steps:10.2f}")
+
+ model.T = original_T
+ model.n_sup = original_n_sup
+
+ os.makedirs(OUT, exist_ok=True)
+ base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
+ out = {
+ 'ckpt': args.ckpt,
+ 'split': args.split,
+ 'train_T': cfg['T'],
+ 'train_n_sup': cfg['n_sup'],
+ 'eval_n_sup': eval_n_sup_used,
+ 'adaptive': args.adaptive,
+ 'rows': rows,
+ }
+ suffix = "_adaptive" if args.adaptive else ""
+ fp = os.path.join(OUT, f"ttc_{base}_ns{out['eval_n_sup']}_{args.split}{suffix}.json")
+ with open(fp, 'w') as f:
+ json.dump(out, f, indent=2)
+ print("wrote", fp)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/diag/lyap.py b/diag/lyap.py
index 93b90bf..e676568 100644
--- a/diag/lyap.py
+++ b/diag/lyap.py
@@ -1,12 +1,12 @@
"""LE diagnostic for the recursive (TRM-ish) GNN — ports the flossing finding to graphs.
-Per-graph top Lyapunov exponent lambda1 of the recursion z <- block(z+h0), via Benettin
+Per-graph top Lyapunov exponent lambda1 of the edge-free recursion z <- block(z, ctx), via Benettin
power-iteration on a single tangent vector (JVP + renormalize, accumulate log-growth) over
the model's n_sup*T recursion steps. Bucket graphs by success/failure (rounded ring counts
exact) and compare lambda1 distributions + AUROC(fail | lambda1) — mirroring
plot_trm_lyap_hist.py. Hypothesis: failed graphs are MORE chaotic (higher lambda1).
-Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_full_..._s0.pt
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_rrog_full_..._s0.pt
"""
import argparse
import numpy as np
@@ -21,18 +21,19 @@ except Exception:
def build(ck, dev):
c = ck['cfg']
- m = RecGIN(c['n_atom'], c['hidden'], c['T'], c['n_sup'], 0.0, grad_mode=c['grad_mode']).to(dev)
+ m = RecGIN(c['n_atom'], c['hidden'], c['T'], c['n_sup'], 0.0, grad_mode=c['grad_mode'],
+ agg_layers=c.get('agg_layers', 1), compute_layers=c.get('compute_layers', 2)).to(dev)
m.load_state_dict(ck['state']); m.eval()
return m, c
def lyap1(model, x, ei, n_steps, dev, seed=0):
g = torch.Generator(device=dev).manual_seed(seed)
- h0 = model.emb(x).detach()
- z = torch.zeros_like(h0)
- v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
+ ctx = model.aggregate(x, ei).detach()
+ z = ctx.detach()
+ v = torch.randn(ctx.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
def step_fn(zz):
- return model.block(zz + h0, ei)
+ return model.block(zz, ctx)
lam = 0.0
for _ in range(n_steps):
z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v)
@@ -72,10 +73,13 @@ def main():
lams, fails = np.array(lams), np.array(fails)
s, f = lams[fails == 0], lams[fails == 1]
auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
+ sm, ss = (s.mean(), s.std()) if len(s) else (float('nan'), float('nan'))
+ fm, fs = (f.mean(), f.std()) if len(f) else (float('nan'), float('nan'))
+ sep = fm - sm if len(s) and len(f) else float('nan')
print(f"[{cfg['grad_mode']}] n={len(lams)} fail_rate={fails.mean():.2f} | "
- f"lambda1 SUCC mean {s.mean():+.4f} std {s.std():.4f} (n={len(s)}) | "
- f"FAIL mean {f.mean():+.4f} std {f.std():.4f} (n={len(f)}) | "
- f"sep(fail-succ)={f.mean()-s.mean() if len(s) and len(f) else float('nan'):+.4f} | "
+ f"lambda1 SUCC mean {sm:+.4f} std {ss:.4f} (n={len(s)}) | "
+ f"FAIL mean {fm:+.4f} std {fs:.4f} (n={len(f)}) | "
+ f"sep(fail-succ)={sep:+.4f} | "
f"AUROC(fail|lambda1)={auc:.3f} | mean_lambda1={lams.mean():+.4f}")
diff --git a/diag/ptrm_color.py b/diag/ptrm_color.py
index 4004297..b24097f 100644
--- a/diag/ptrm_color.py
+++ b/diag/ptrm_color.py
@@ -3,7 +3,7 @@
deterministic / pass@K (conflict-min, ground truth) / lambda-select (min lambda1) / random.
-Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_...pt
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_rrog_trm_gin_full_...pt
"""
import argparse, json, os
import numpy as np
@@ -18,20 +18,24 @@ OUT = '/home/yurenh2/rrog/runs'
def rollout(model, xin, ei, sigma, n_sup, T, dev, seed):
gen = torch.Generator(device=dev).manual_seed(seed)
- h0 = model.lin_in(xin)
- z = torch.zeros_like(h0)
- v = torch.randn(h0.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12)
- def step(zz):
- return model.block(zz + h0, ei)
+ ctx = model.aggregate(xin, ei)
+ y, z = ctx, torch.zeros_like(ctx)
+ state = torch.cat([y, z], dim=-1)
+ v = torch.randn(state.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12)
+ def step(ss):
+ yy, zz = ss.chunk(2, dim=-1)
+ yy, zz = model.recurse(yy, zz, ctx, noise=False)
+ return torch.cat([yy, zz], dim=-1)
lam = 0.0
for _ in range(n_sup * T):
- z_det, Jv = torch.autograd.functional.jvp(step, z, v)
+ state_det, Jv = torch.autograd.functional.jvp(step, state, v)
nv = Jv.norm(); lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach()
- z = z_det.detach()
+ state = state_det.detach()
if sigma > 0:
- z = z + sigma * torch.randn(z.shape, generator=gen, device=dev)
- lam /= (n_sup * T)
- col = model.head(z).argmax(-1)
+ state = state + sigma * torch.randn(state.shape, generator=gen, device=dev)
+ lam /= max(n_sup * T, 1)
+ y, _ = state.chunk(2, dim=-1)
+ col = model.head(y).argmax(-1)
conf = (col[ei[0]] == col[ei[1]]).sum().item() // 2
return conf, lam
@@ -47,7 +51,11 @@ def main():
ck = torch.load(args.ckpt, weights_only=False); c = ck['cfg']
deg = torch.tensor(c['deg']) if c.get('deg') else None
model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
- grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev)
+ grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg,
+ agg_layers=c.get('agg_layers', 1),
+ compute_layers=c.get('compute_layers', 2),
+ compute=(c.get('compute') if c.get('compute') == 'trm' else 'trm'),
+ attn_heads=c.get('attn_heads', 4)).to(dev)
model.load_state_dict(ck['state']); model.eval()
nsup, T = c['n_sup'], c['T']
te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), c.get('pe', 'none'), c.get('rwse_k', 16))
@@ -56,6 +64,7 @@ def main():
det = sum(rollout(model, r['xin'].to(dev), r['edge_index'].to(dev), 0.0, nsup, T, dev, 0)[0] == 0
for r in te) / n
out = {'conv': c.get('conv', 'gin'), 'pe': c.get('pe', 'none'), 'seed': c.get('seed'),
+ 'arch': c.get('arch', 'legacy'),
'grad_mode': c['grad_mode'], 'contract': c.get('contract', False), 'det': det, 'sigmas': {}}
print(f"[pe={out['pe']} s{out['seed']}] deterministic solve_rate = {det:.3f} (n={n}, K={args.K})")
print(f"{'sigma':>6} {'pass@K':>8} {'lam-sel':>8} {'random':>8} {'perRoll':>8} {'AUROC(s|-lam)':>14}")
diff --git a/diag/run_archA.sh b/diag/run_archA.sh
index 5385152..6cc6042 100644
--- a/diag/run_archA.sh
+++ b/diag/run_archA.sh
@@ -1,16 +1,15 @@
#!/usr/bin/env bash
-# Conv axis (pe=none): gcn, sage, gat. 5 seeds, train+LE+PTRM. (pin GPU at launch.)
+# One-shot aggregation axis (pe=none): gcn, sage, gat. 5 seeds, train+LE.
set -uo pipefail
cd /home/yurenh2/rrog
export PYTHONPATH=/home/yurenh2/rrog
echo "A start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
for s in 0 1 2 3 4; do
for conv in gcn sage gat; do
- ck=runs/ckpt_color_${conv}_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
+ ck=runs/ckpt_color_rrog_trm_${conv}_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
echo "== A s$s conv=$conv =="
python3 diag/train_color.py --mode train --conv "$conv" --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train $conv s$s"
python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le $conv s$s"
- python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $conv s$s"
done
done
echo "doneA=$(date -Is)"
diff --git a/diag/run_archB.sh b/diag/run_archB.sh
index 317638f..ddc6069 100644
--- a/diag/run_archB.sh
+++ b/diag/run_archB.sh
@@ -1,21 +1,19 @@
#!/usr/bin/env bash
-# GPS transformer backbone + feature axis (gin + lappe / all). 5 seeds. (pin GPU at launch.)
+# One-shot GPS encoder + feature axis (gin + lappe / all). 5 seeds. (pin GPU at launch.)
set -uo pipefail
cd /home/yurenh2/rrog
export PYTHONPATH=/home/yurenh2/rrog
echo "B start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
for s in 0 1 2 3 4; do
- ck=runs/ckpt_color_gps_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
+ ck=runs/ckpt_color_rrog_trm_gps_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
echo "== B s$s conv=gps =="
python3 diag/train_color.py --mode train --conv gps --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train gps s$s"
python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le gps s$s"
- python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm gps s$s"
for pe in lappe all; do
- ck2=runs/ckpt_color_gin_full_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt
+ ck2=runs/ckpt_color_rrog_trm_gin_full_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt
echo "== B s$s gin pe=$pe =="
python3 diag/train_color.py --mode train --conv gin --pe "$pe" --p 0.2 --epochs 150 --seed "$s" || echo "!! train $pe s$s"
python3 diag/train_color.py --mode le --ckpt "$ck2" || echo "!! le $pe s$s"
- python3 diag/ptrm_color.py --ckpt "$ck2" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $pe s$s"
done
done
echo "doneB=$(date -Is)"
diff --git a/diag/run_color.sh b/diag/run_color.sh
index ad3c406..82c522a 100644
--- a/diag/run_color.sh
+++ b/diag/run_color.sh
@@ -1,15 +1,12 @@
#!/usr/bin/env bash
-# Step-2 (TRM regime, large output): graph 3-coloring, TRM full vs HRM 1-step + LE diagnostic.
+# RRoG/TRM-on-GNN: graph 3-coloring, deterministic T=1 lower bound vs T=3 extra compute.
set -uo pipefail
cd /home/yurenh2/rrog
export PYTHONPATH=/home/yurenh2/rrog
echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
-for gm in full 1step; do
- echo "===== train $gm ====="
- python3 diag/train_color.py --mode train --grad_mode "$gm" --p 0.2 --epochs 150 --seed 0 \
- || echo "!! train $gm failed"
+for T in 1 3; do
+ echo "===== train T=$T ====="
+ python3 diag/train_color.py --mode train --grad_mode full --T "$T" --p 0.2 --epochs 150 --seed 0 \
+ || echo "!! train T=$T failed"
done
-echo "===== LE diagnostic (lambda1: solved vs unsolved) ====="
-python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_full_n50_k3_p0.2_T3_ns3_s0.pt || echo "!! le full failed"
-python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_1step_n50_k3_p0.2_T3_ns3_s0.pt || echo "!! le 1step failed"
echo "done=$(date -Is)"
diff --git a/diag/run_le.sh b/diag/run_le.sh
index 8fc31ea..07d6a63 100644
--- a/diag/run_le.sh
+++ b/diag/run_le.sh
@@ -1,5 +1,5 @@
#!/usr/bin/env bash
-# Step-2(iii): TRM-ish GIN full-recursion vs 1-step-gradient, then LE diagnostic on each.
+# RRoG/TRM-on-GNN GIN full-recursion vs 1-step-gradient, then LE diagnostic on each.
set -uo pipefail
cd /home/yurenh2/rrog
export PYTHONPATH=/home/yurenh2/rrog
@@ -7,6 +7,6 @@ echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1 --epochs 200 --seed 0 || echo "!! train full failed"
python3 diag/train_rec.py --grad_mode 1step --sigma 0 --K 1 --epochs 200 --seed 0 || echo "!! train 1step failed"
echo "===== LE diagnostic (lambda1: success vs failure) ====="
-python3 diag/lyap.py --ckpt runs/ckpt_rec_full_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap full failed"
-python3 diag/lyap.py --ckpt runs/ckpt_rec_1step_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap 1step failed"
+python3 diag/lyap.py --ckpt runs/ckpt_rec_rrog_full_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap full failed"
+python3 diag/lyap.py --ckpt runs/ckpt_rec_rrog_1step_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap 1step failed"
echo "done=$(date -Is)"
diff --git a/diag/run_pe.sh b/diag/run_pe.sh
index 8350160..1e587ce 100644
--- a/diag/run_pe.sh
+++ b/diag/run_pe.sh
@@ -1,6 +1,5 @@
#!/usr/bin/env bash
-# Roadmap #1: RRoG on PE-augmented backbone. GIN vs GIN+RWSE on coloring:
-# does RRoG-noise add headroom on top of static structural encoding, or is it redundant?
+# RRoG on PE-augmented one-shot GIN encoder. GIN vs GIN+RWSE on coloring.
set -uo pipefail
cd /home/yurenh2/rrog
export PYTHONPATH=/home/yurenh2/rrog
@@ -12,13 +11,7 @@ for pe in none rwse; do
done
echo "===== LE (full, both pe) ====="
for pe in none rwse; do
- python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \
+ python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_rrog_trm_gin_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \
|| echo "!! le $pe failed"
done
-echo "===== PTRM noise + lambda-select (both pe) ====="
-for pe in none rwse; do
- echo "--- pe=$pe ---"
- python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \
- --K 16 --n_graphs 150 --sigmas 0.1 0.2 0.4 || echo "!! ptrm $pe failed"
-done
echo "done=$(date -Is)"
diff --git a/diag/run_pe2.sh b/diag/run_pe2.sh
index db7bd8a..dc9f88b 100644
--- a/diag/run_pe2.sh
+++ b/diag/run_pe2.sh
@@ -1,17 +1,15 @@
#!/usr/bin/env bash
-# Roadmap #2: RRoG on a GSN-style motif backbone (per-node K3/wedge substructure counts).
-# full-recursion, 5 seeds; train + LE + PTRM(sigma=0.2). Aggregate vs none/rwse.
+# RRoG on a GSN-style motif-augmented one-shot encoder. Full recursion, 5 seeds; train + LE.
set -uo pipefail
cd /home/yurenh2/rrog
export PYTHONPATH=/home/yurenh2/rrog
echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
for s in 0 1 2 3 4; do
- ck=runs/ckpt_color_full_gsn_n50_k3_p0.2_T3_ns3_s${s}.pt
+ ck=runs/ckpt_color_rrog_trm_gin_full_gsn_n50_k3_p0.2_T3_ns3_s${s}.pt
echo "===== seed=$s full gsn ====="
python3 diag/train_color.py --mode train --grad_mode full --pe gsn --p 0.2 --epochs 150 --seed "$s" \
|| echo "!! train gsn s$s failed"
python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le gsn s$s failed"
- python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm gsn s$s failed"
done
echo "===== AGGREGATE (all pe: none/rwse/gsn) ====="
python3 diag/aggregate.py
diff --git a/diag/run_pe3.sh b/diag/run_pe3.sh
index a978393..2a6744b 100644
--- a/diag/run_pe3.sh
+++ b/diag/run_pe3.sh
@@ -1,22 +1,20 @@
#!/usr/bin/env bash
-# Roadmap #3 (subgraph/ego features) + #4 (IGNN-style forced contraction), 5 seeds.
+# Subgraph/ego features + forced contraction, 5 seeds. Deterministic train + LE.
# #3: full --pe sub. #4: full --pe none --contract (vs free full/none baseline).
set -uo pipefail
cd /home/yurenh2/rrog
export PYTHONPATH=/home/yurenh2/rrog
echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
for s in 0 1 2 3 4; do
- ck=runs/ckpt_color_full_sub_n50_k3_p0.2_T3_ns3_s${s}.pt
+ ck=runs/ckpt_color_rrog_trm_gin_full_sub_n50_k3_p0.2_T3_ns3_s${s}.pt
echo "===== seed=$s #3 full sub ====="
python3 diag/train_color.py --mode train --grad_mode full --pe sub --p 0.2 --epochs 150 --seed "$s" || echo "!! train sub s$s failed"
python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le sub s$s failed"
- python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm sub s$s failed"
- ck2=runs/ckpt_color_full_none_ctr_n50_k3_p0.2_T3_ns3_s${s}.pt
+ ck2=runs/ckpt_color_rrog_trm_gin_full_none_ctr_n50_k3_p0.2_T3_ns3_s${s}.pt
echo "===== seed=$s #4 full none --contract ====="
python3 diag/train_color.py --mode train --grad_mode full --pe none --contract --p 0.2 --epochs 150 --seed "$s" || echo "!! train ctr s$s failed"
python3 diag/train_color.py --mode le --ckpt "$ck2" || echo "!! le ctr s$s failed"
- python3 diag/ptrm_color.py --ckpt "$ck2" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm ctr s$s failed"
done
echo "===== AGGREGATE ====="
python3 diag/aggregate.py
diff --git a/diag/run_pna.sh b/diag/run_pna.sh
index 897315d..157d8ac 100644
--- a/diag/run_pna.sh
+++ b/diag/run_pna.sh
@@ -4,10 +4,9 @@ cd /home/yurenh2/rrog
export PYTHONPATH=/home/yurenh2/rrog
echo "PNA start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
for s in 0 1 2 3 4; do
- ck=runs/ckpt_color_pna_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
+ ck=runs/ckpt_color_rrog_trm_pna_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
echo "== pna s$s =="
python3 diag/train_color.py --mode train --conv pna --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train pna s$s"
python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le pna s$s"
- python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm pna s$s"
done
echo "donePNA=$(date -Is)"
diff --git a/diag/run_rec.sh b/diag/run_rec.sh
index 632498d..c436c4e 100644
--- a/diag/run_rec.sh
+++ b/diag/run_rec.sh
@@ -1,13 +1,12 @@
#!/usr/bin/env bash
-# Step-2: recursive GNN + full PTRM on ZINC ring-counting. Does per-step noise + best-Q@K
-# selection break the 1-WL counting ceiling that input-RNI couldn't? Averaging vs selection.
+# RRoG/TRM-on-GNN on ZINC ring-counting. Deterministic view-only vs recursive compute.
set -uo pipefail
cd /home/yurenh2/rrog
export PYTHONPATH=/home/yurenh2/rrog
echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
-run() { echo "===== $* ====="; python3 diag/train_rec.py "$@" --epochs 200 --seed 0 || echo "!! FAILED $*"; }
-run --sigma 0 --K 1 --select bestq # deterministic recursive baseline (~1-WL ceiling)
-run --sigma 0.1 --K 8 --select none # averaging over rollouts (RNI-style null)
-run --sigma 0.1 --K 8 --select bestq # PTRM-proper: per-step noise + best-Q@K selection
-run --sigma 0.2 --K 16 --select bestq # scaled noise + rollouts
+run() { echo "===== $* ====="; python3 diag/train_rec.py "$@" --epochs 40 --hidden 64 --bs 256 --seed 0 || echo "!! FAILED $*"; }
+run --grad_mode full --T 0 --n_sup 3 --sigma 0 --K 1 --select none
+run --grad_mode full --T 1 --n_sup 3 --sigma 0 --K 1 --select none
+run --grad_mode full --T 3 --n_sup 3 --sigma 0 --K 1 --select none
+run --grad_mode full --T 1 --n_sup 3 --act --halt_max_steps 8 --halt_target binary --sigma 0 --K 1 --select none
echo "done=$(date -Is)"
diff --git a/diag/run_seeds.sh b/diag/run_seeds.sh
index 77f22c8..85b67d1 100644
--- a/diag/run_seeds.sh
+++ b/diag/run_seeds.sh
@@ -1,5 +1,5 @@
#!/usr/bin/env bash
-# Harden coloring results with 5 seeds: full-vs-1step solve, LE AUROC, PTRM(sigma=0.2) for none/rwse.
+# Harden corrected RRoG coloring results with 5 seeds: full-vs-1step solve + LE AUROC.
set -uo pipefail
cd /home/yurenh2/rrog
export PYTHONPATH=/home/yurenh2/rrog
@@ -7,14 +7,11 @@ echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
for s in 0 1 2 3 4; do
for cfg in "full none" "full rwse" "1step none"; do
set -- $cfg; gm=$1; pe=$2
- ck=runs/ckpt_color_${gm}_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt
+ ck=runs/ckpt_color_rrog_trm_gin_${gm}_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt
echo "===== seed=$s $gm $pe ====="
python3 diag/train_color.py --mode train --grad_mode "$gm" --pe "$pe" --p 0.2 --epochs 150 --seed "$s" \
|| echo "!! train $gm $pe s$s failed"
python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le $gm $pe s$s failed"
- if [ "$gm" = "full" ]; then
- python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $pe s$s failed"
- fi
done
done
echo "===== AGGREGATE ====="
diff --git a/diag/train_color.py b/diag/train_color.py
index 36f8496..b9d0c23 100644
--- a/diag/train_color.py
+++ b/diag/train_color.py
@@ -1,7 +1,11 @@
-"""Recursive (TRM-ish) GNN graph 3-coloring with swappable BACKBONE for the RRoG roadmap.
+"""RRoG/TRM-on-GNN graph 3-coloring.
---conv gin|gcn|sage|gat|gps : message-passing operator (gps = GraphGPS local MPNN + global
- attention = TRM's original transformer backbone, on the graph).
+The graph is encoded once into a fixed per-node context. Recursion then refines hidden state
+with a shared compute block that never reads edge_index. This is the RRoG split:
+the GNN encoder supplies the view/context x, TRM-style recurrence supplies computation.
+
+--conv gin|gcn|sage|gat|gps : message-passing operator used only by the one-shot encoder
+ (gps = GraphGPS local MPNN + global attention).
--pe none|rwse|gsn|sub|lappe|all : input structural features (random sym-break [+ encoding]).
--contract : reverse-flossing lambda-penalty during training (force contraction; roadmap #4).
--grad_mode full|1step : TRM full recursion vs HRM 1-step gradient.
@@ -142,48 +146,83 @@ def make_conv(conv, hidden, deg=None):
class RecGINColor(nn.Module):
- def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full', sigma=0.0, conv='gin', deg=None):
+ def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full',
+ sigma=0.0, conv='gin', deg=None, agg_layers=4, compute_layers=None,
+ compute='trm', attn_heads=4):
super().__init__()
self.conv_type = conv
+ self.agg_layers = agg_layers
+ self.compute_layers = compute_layers or inner
+ self.compute = compute
+ self.attn_heads = attn_heads
self.lin_in = nn.Linear(in_dim, hidden)
- self.convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(inner)])
- self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)])
+ self.agg_convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(agg_layers)])
+ self.agg_bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(agg_layers)])
+ if compute not in ('trm',):
+ raise ValueError(compute)
+ core = []
+ d = hidden
+ for _ in range(self.compute_layers - 1):
+ core += [nn.Linear(d, hidden), nn.GELU()]
+ d = hidden
+ core.append(nn.Linear(d, hidden))
+ self.core_norm = nn.LayerNorm(hidden)
+ self.core = nn.Sequential(*core)
+ nn.init.zeros_(self.core[-1].weight)
+ nn.init.zeros_(self.core[-1].bias)
self.head = nn.Linear(hidden, k)
self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma
- def block(self, z, ei, batch=None):
+ def aggregate(self, xin, ei, batch=None):
if self.conv_type == 'gps' and batch is None:
- batch = z.new_zeros(z.size(0), dtype=torch.long)
- for conv, bn in zip(self.convs, self.bns):
- z = conv(z, ei, batch) if self.conv_type == 'gps' else conv(z, ei)
- z = bn(z).relu()
- return z
-
- def _inner(self, z, h0, ei, noise, batch):
- z = self.block(z + h0, ei, batch)
+ batch = xin.new_zeros(xin.size(0), dtype=torch.long)
+ h = self.lin_in(xin)
+ for conv, bn in zip(self.agg_convs, self.agg_bns):
+ h = conv(h, ei, batch) if self.conv_type == 'gps' else conv(h, ei)
+ h = bn(h).relu()
+ return h
+
+ def core_step(self, combined, state):
+ """Shared TRM compute core. Deliberately edge-free."""
+ return state + self.core(self.core_norm(combined))
+
+ def _z_step(self, y, z, ctx, noise):
+ z = self.core_step(ctx + y + z, z)
if noise and self.sigma > 0:
z = z + self.sigma * torch.randn_like(z)
return z
- def recurse(self, z, h0, ei, noise, batch, one_step=False):
+ def _y_step(self, y, z, noise):
+ y = self.core_step(y + z, y)
+ if noise and self.sigma > 0:
+ y = y + self.sigma * torch.randn_like(y)
+ return y
+
+ def recurse(self, y, z, ctx, noise, one_step=False):
+ if self.T == 0:
+ return y, z
if one_step:
with torch.no_grad():
for _ in range(self.T - 1):
- z = self._inner(z, h0, ei, noise, batch)
+ z = self._z_step(y, z, ctx, noise)
z = z.detach()
- return self._inner(z, h0, ei, noise, batch)
+ z = self._z_step(y, z, ctx, noise)
+ y = self._y_step(y, z, noise)
+ return y, z
for _ in range(self.T):
- z = self._inner(z, h0, ei, noise, batch)
- return z
+ z = self._z_step(y, z, ctx, noise)
+ y = self._y_step(y, z, noise)
+ return y, z
def forward(self, xin, ei, batch=None, noise=False):
- h0 = self.lin_in(xin)
- z = torch.zeros_like(h0)
+ ctx = self.aggregate(xin, ei, batch)
+ y = ctx
+ z = torch.zeros_like(ctx)
outs = []
for s in range(self.n_sup):
- z = self.recurse(z, h0, ei, noise, batch, one_step=(self.grad_mode == '1step'))
- outs.append(self.head(z))
- z = z.detach()
+ y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step'))
+ outs.append(self.head(y))
+ y, z = y.detach(), z.detach()
return outs
@@ -206,15 +245,17 @@ def solve_stats(model, recs, dev, sample=None):
def lyap1(model, xin, ei, n_steps, dev, seed=0):
g = torch.Generator(device=dev).manual_seed(seed)
- h0 = model.lin_in(xin).detach()
- z = torch.zeros_like(h0)
- v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
- def step_fn(zz):
- return model.block(zz + h0, ei)
+ ctx = model.aggregate(xin, ei).detach()
+ state = torch.cat([ctx, torch.zeros_like(ctx)], dim=-1).detach()
+ v = torch.randn(state.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
+ def step_fn(ss):
+ y, z = ss.chunk(2, dim=-1)
+ y, z = model.recurse(y, z, ctx, noise=False)
+ return torch.cat([y, z], dim=-1)
lam = 0.0
for _ in range(n_steps):
- z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v)
- z = z_next.detach(); nv = Jv.norm()
+ state_next, Jv = torch.autograd.functional.jvp(step_fn, state, v)
+ state = state_next.detach(); nv = Jv.norm()
lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach()
return lam / n_steps
@@ -246,11 +287,16 @@ def run_le(model, recs, dev, n_steps, n_graphs=300):
def lyap_penalty(model, x, ei, batch, target=-0.5):
- h0 = model.lin_in(x)
+ ctx = model.aggregate(x, ei, batch)
with torch.no_grad():
- zr = model.recurse(torch.zeros_like(h0), h0.detach(), ei, False, batch)
- v = torch.randn_like(zr); v = v / (v.norm() + 1e-12)
- _, Jv = torch.autograd.functional.jvp(lambda zz: model.block(zz + h0, ei, batch), zr, v, create_graph=True)
+ yr, zr = model.recurse(ctx.detach(), torch.zeros_like(ctx).detach(), ctx.detach(), False)
+ state = torch.cat([yr, zr], dim=-1)
+ v = torch.randn_like(state); v = v / (v.norm() + 1e-12)
+ def step_fn(ss):
+ y, z = ss.chunk(2, dim=-1)
+ y, z = model.recurse(y, z, ctx, noise=False)
+ return torch.cat([y, z], dim=-1)
+ _, Jv = torch.autograd.functional.jvp(step_fn, state, v, create_graph=True)
return (torch.log(Jv.norm() + 1e-12) - target) ** 2
@@ -267,6 +313,10 @@ def main():
ap.add_argument('--p', type=float, default=0.2); ap.add_argument('--r', type=int, default=8)
ap.add_argument('--hidden', type=int, default=128); ap.add_argument('--T', type=int, default=3)
ap.add_argument('--n_sup', type=int, default=3); ap.add_argument('--epochs', type=int, default=150)
+ ap.add_argument('--agg_layers', type=int, default=4)
+ ap.add_argument('--compute_layers', type=int, default=2)
+ ap.add_argument('--compute', choices=['trm'], default='trm')
+ ap.add_argument('--attn_heads', type=int, default=4)
ap.add_argument('--lr', type=float, default=1e-3); ap.add_argument('--bs', type=int, default=32)
ap.add_argument('--seed', type=int, default=0)
args = ap.parse_args()
@@ -280,13 +330,18 @@ def main():
c.get('pe', 'none'), c.get('rwse_k', 16))
deg = torch.tensor(c['deg']) if c.get('deg') else None
model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
- grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev)
+ grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg,
+ agg_layers=c.get('agg_layers', 1),
+ compute_layers=c.get('compute_layers', 2),
+ compute=(c.get('compute') if c.get('compute') == 'trm' else 'trm'),
+ attn_heads=c.get('attn_heads', 4)).to(dev)
model.load_state_dict(ck['state']); model.eval()
res = run_le(model, te, dev, c['n_sup'] * c['T'])
base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
with open(os.path.join(OUT, f"le_{base}.json"), 'w') as fjs:
json.dump({'conv': c.get('conv', 'gin'), 'grad_mode': c['grad_mode'], 'pe': c.get('pe', 'none'),
- 'contract': c.get('contract', False), 'seed': c.get('seed'), **res}, fjs, indent=2)
+ 'contract': c.get('contract', False), 'seed': c.get('seed'),
+ 'arch': c.get('arch', 'legacy'), **res}, fjs, indent=2)
return
te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), args.pe, args.rwse_k)
@@ -296,7 +351,9 @@ def main():
trl = DataLoader(data, batch_size=args.bs, shuffle=True, drop_last=True)
deg = deg_hist(tr) if args.conv == 'pna' else None
model = RecGINColor(in_dim, args.hidden, args.k, args.T, args.n_sup,
- grad_mode=args.grad_mode, conv=args.conv, deg=deg).to(dev)
+ grad_mode=args.grad_mode, conv=args.conv, deg=deg,
+ agg_layers=args.agg_layers, compute_layers=args.compute_layers,
+ compute=args.compute, attn_heads=args.attn_heads).to(dev)
opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
@@ -329,15 +386,18 @@ def main():
print(f"ep{ep+1} solve_rate={sr:.3f} mean_conflicts={mc:.2f}", flush=True)
sfx = ('_ctr' if args.contract else '')
- tag = f"color_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}"
+ tag = f"color_rrog_{args.compute}_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}"
rep = {'task': 'graph3coloring', 'tag': tag, **vars(args), 'in_dim': in_dim,
- 'sec': round(time.time() - t0, 1), **best}
+ 'arch': 'rrog_once_agg_hidden_compute', 'sec': round(time.time() - t0, 1), **best}
print(f"[{tag}] best solve_rate={best.get('solve_rate')} @ep{best.get('ep')} ({rep['sec']}s)")
with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
json.dump(rep, f, indent=2)
torch.save({'state': best_state, 'cfg': {'in_dim': in_dim, 'hidden': args.hidden, 'k': args.k,
'T': args.T, 'n_sup': args.n_sup, 'grad_mode': args.grad_mode, 'pe': args.pe,
'rwse_k': args.rwse_k, 'contract': args.contract, 'conv': args.conv, 'seed': args.seed,
+ 'agg_layers': args.agg_layers, 'compute_layers': args.compute_layers,
+ 'compute': args.compute, 'attn_heads': args.attn_heads,
+ 'arch': 'rrog_once_agg_hidden_compute',
'deg': (deg.tolist() if deg is not None else None)}},
os.path.join(OUT, f"ckpt_{tag}.pt"))
print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))
diff --git a/diag/train_cycle.py b/diag/train_cycle.py
index d2342f3..598e349 100644
--- a/diag/train_cycle.py
+++ b/diag/train_cycle.py
@@ -24,9 +24,14 @@ from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx
from torch_geometric.nn import GINConv, GCNConv, global_add_pool
-ROOT = '/home/yurenh2/rrog/data/zinc'
-CACHE = '/home/yurenh2/rrog/data/cycle_cache'
-OUT = '/home/yurenh2/rrog/runs'
+PROJECT_ROOT = os.environ.get(
+ 'RROG_ROOT',
+ os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
+)
+DATA_ROOT = os.environ.get('RROG_DATA_DIR', os.path.join(PROJECT_ROOT, 'data'))
+OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))
+ROOT = os.path.join(DATA_ROOT, 'zinc')
+CACHE = os.path.join(DATA_ROOT, 'cycle_cache')
RWSE_K = 16
diff --git a/diag/train_rec.py b/diag/train_rec.py
index 9866f28..9db7eb1 100644
--- a/diag/train_rec.py
+++ b/diag/train_rec.py
@@ -1,11 +1,10 @@
-"""Step-2: Recursive (TRM-ish) GNN on ZINC ring-counting + optional PTRM noise/selection.
+"""Step-2: RRoG/TRM-on-GNN for ZINC ring-counting.
-Recurrent shared-weight GIN block, deep-supervised over n_sup steps (TRM-style: carry latent
-detached between steps). --grad_mode controls the LAST supervision step's recursion:
+The graph is encoded once with a GIN encoder. A shared edge-free node-wise compute block then
+refines hidden state over n_sup*T recurrent steps (TRM-style: carry latent detached between
+deep-supervision steps). --grad_mode controls the LAST supervision step's recursion:
full : backprop through all T inner recursions (TRM)
1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient)
-Optional per-step Gaussian noise (sigma) + K stochastic rollouts selected by a value head
-(best-Q@K) for the PTRM experiments. Saves a checkpoint for the LE diagnostic (diag/lyap.py).
Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1
"""
@@ -14,66 +13,218 @@ import numpy as np
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
-from torch_geometric.data import Data
-from torch_geometric.nn import GINConv, global_add_pool
+from torch_geometric.data import Batch, Data
+from torch_geometric.nn import (
+ APPNP,
+ ARMAConv,
+ ChebConv,
+ FiLMConv,
+ GATv2Conv,
+ GCNConv,
+ GENConv,
+ GINEConv,
+ GINConv,
+ GraphConv,
+ MFConv,
+ PNAConv,
+ ResGatedGraphConv,
+ SAGEConv,
+ SGConv,
+ TAGConv,
+ TransformerConv,
+ global_add_pool,
+)
+from torch_geometric.utils import degree
from diag.train_cycle import prepare
-OUT = '/home/yurenh2/rrog/runs'
+PROJECT_ROOT = os.environ.get(
+ 'RROG_ROOT',
+ os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
+)
+OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))
+SUPPORTED_VIEWS = [
+ 'gin', 'gine', 'gcn', 'graphsage', 'gatv2', 'graphconv', 'transformer', 'pna',
+ 'gen', 'film', 'resgated', 'tag', 'sgc', 'cheb', 'arma', 'mf', 'appnp',
+]
-def loader(recs, bs, shuffle, drop_last=False):
- data = [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2),
+def data_list(recs):
+ return [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2),
num_nodes=r['x'].numel()) for r in recs]
+
+
+def loader(recs, bs, shuffle, drop_last=False):
+ data = recs if recs and isinstance(recs[0], Data) else data_list(recs)
return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+def degree_histogram(data):
+ max_degree = 0
+ degs = []
+ for graph in data:
+ deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long)
+ degs.append(deg)
+ if deg.numel():
+ max_degree = max(max_degree, int(deg.max().item()))
+ hist = torch.zeros(max_degree + 1, dtype=torch.long)
+ for deg in degs:
+ hist += torch.bincount(deg, minlength=hist.numel())
+ return hist
+
+
+def make_view_layer(view, hidden, deg):
+ if view == 'gin':
+ return GINConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True)
+ if view == 'gine':
+ return GINEConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)),
+ train_eps=True, edge_dim=hidden)
+ if view == 'gcn':
+ return GCNConv(hidden, hidden)
+ if view == 'graphsage':
+ return SAGEConv(hidden, hidden)
+ if view == 'gatv2':
+ return GATv2Conv(hidden, hidden, heads=4, concat=False)
+ if view == 'graphconv':
+ return GraphConv(hidden, hidden)
+ if view == 'transformer':
+ return TransformerConv(hidden, hidden, heads=4, concat=False)
+ if view == 'pna':
+ if deg is None:
+ raise ValueError('PNA view requires a training-set degree histogram')
+ return PNAConv(
+ hidden, hidden,
+ aggregators=['mean', 'min', 'max', 'std'],
+ scalers=['identity', 'amplification', 'attenuation'],
+ deg=deg,
+ )
+ if view == 'gen':
+ return GENConv(hidden, hidden)
+ if view == 'film':
+ return FiLMConv(hidden, hidden)
+ if view == 'resgated':
+ return ResGatedGraphConv(hidden, hidden)
+ if view == 'tag':
+ return TAGConv(hidden, hidden, K=3)
+ if view == 'sgc':
+ return SGConv(hidden, hidden, K=2, cached=False)
+ if view == 'cheb':
+ return ChebConv(hidden, hidden, K=3)
+ if view == 'arma':
+ return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2)
+ if view == 'mf':
+ return MFConv(hidden, hidden)
+ if view == 'appnp':
+ return APPNP(K=5, alpha=0.1)
+ raise ValueError(f'unsupported view: {view}')
+
+
class RecGIN(nn.Module):
- def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, grad_mode='full'):
+ def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2,
+ grad_mode='full', agg_layers=5, compute_layers=None, view='gin', deg=None):
super().__init__()
+ self.view = view
+ self.agg_layers = agg_layers
+ self.compute_layers = compute_layers or inner
self.emb = nn.Embedding(n_atom, hidden)
- self.convs = nn.ModuleList([GINConv(nn.Sequential(
- nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True)
- for _ in range(inner)])
- self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)])
+ self.edge_emb = nn.Embedding(1, hidden) if view == 'gine' else None
+ self.agg_convs = nn.ModuleList()
+ for _ in range(agg_layers):
+ self.agg_convs.append(make_view_layer(view, hidden, deg))
+ self.agg_bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(agg_layers)])
+ core = []
+ d = hidden
+ for _ in range(self.compute_layers - 1):
+ core += [nn.Linear(d, hidden), nn.GELU()]
+ d = hidden
+ core.append(nn.Linear(d, hidden))
+ self.core_norm = nn.LayerNorm(hidden)
+ self.core = nn.Sequential(*core)
+ nn.init.zeros_(self.core[-1].weight)
+ nn.init.zeros_(self.core[-1].bias)
self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
+ with torch.no_grad():
+ self.qhead[-1].weight.zero_()
+ self.qhead[-1].bias.fill_(-5.0)
self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode
- def block(self, z, ei):
- for conv, bn in zip(self.convs, self.bns):
- z = bn(conv(z, ei)).relu()
- return z
+ def aggregate(self, x, ei):
+ h = self.emb(x)
+ for conv, bn in zip(self.agg_convs, self.agg_bns):
+ if self.view == 'gine':
+ edge_attr = self.edge_emb(torch.zeros(ei.size(1), dtype=torch.long, device=ei.device))
+ h = bn(conv(h, ei, edge_attr)).relu()
+ else:
+ h = bn(conv(h, ei)).relu()
+ return h
+
+ def core_step(self, combined, state):
+ """Shared TRM compute core. Deliberately edge-free."""
+ return state + self.core(self.core_norm(combined))
- def _inner(self, z, h0, ei, noise):
- z = self.block(z + h0, ei)
+ def _z_step(self, y, z, ctx, noise):
+ z = self.core_step(ctx + y + z, z)
if noise and self.sigma > 0:
z = z + self.sigma * torch.randn_like(z)
return z
- def recurse(self, z, h0, ei, noise, one_step=False):
+ def _y_step(self, y, z, noise):
+ y = self.core_step(y + z, y)
+ if noise and self.sigma > 0:
+ y = y + self.sigma * torch.randn_like(y)
+ return y
+
+ def recurse(self, y, z, ctx, noise, one_step=False):
+ if self.T == 0:
+ return y, z
if one_step: # HRM 1-step gradient
with torch.no_grad():
for _ in range(self.T - 1):
- z = self._inner(z, h0, ei, noise)
+ z = self._z_step(y, z, ctx, noise)
z = z.detach()
- return self._inner(z, h0, ei, noise) # only last inner carries grad
+ z = self._z_step(y, z, ctx, noise) # only last inner carries grad
+ y = self._y_step(y, z, noise)
+ return y, z
for _ in range(self.T): # TRM full recursion
- z = self._inner(z, h0, ei, noise)
- return z
+ z = self._z_step(y, z, ctx, noise)
+ y = self._y_step(y, z, noise)
+ return y, z
+
+ def predict(self, y, batch):
+ pooled = global_add_pool(y, batch)
+ return self.head(pooled), self.qhead(pooled).view(-1)
+
+ def forward_trace(self, x, ei, batch, steps, noise=False):
+ ctx = self.aggregate(x, ei)
+ y = ctx
+ z = torch.zeros_like(ctx)
+ preds, q_logits = [], []
+ for s in range(steps):
+ y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step'))
+ pred, q = self.predict(y, batch)
+ preds.append(pred)
+ q_logits.append(q)
+ if s < steps - 1:
+ y, z = y.detach(), z.detach()
+ return preds, q_logits
def forward(self, x, ei, batch, noise=False):
- h0 = self.emb(x)
- z = torch.zeros_like(h0)
+ ctx = self.aggregate(x, ei)
+ y = ctx
+ z = torch.zeros_like(ctx)
preds = []
for s in range(self.n_sup):
if s < self.n_sup - 1:
with torch.no_grad():
- z = self.recurse(z, h0, ei, noise)
- z = z.detach()
+ y, z = self.recurse(y, z, ctx, noise)
+ y, z = y.detach(), z.detach()
else:
- z = self.recurse(z, h0, ei, noise, one_step=(self.grad_mode == '1step'))
- preds.append(self.head(global_add_pool(z, batch)))
- q = self.qhead(global_add_pool(z, batch)).view(-1)
+ y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step'))
+ pred, _ = self.predict(y, batch)
+ preds.append(pred)
+ _, q = self.predict(y, batch)
return preds, q
@@ -102,6 +253,113 @@ def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'):
return (ae / n).tolist(), (ae_or / n).tolist()
+@torch.no_grad()
+def evaluate_trace(model, ld, dev, ymu, ysd, steps, adaptive=False):
+ model.eval()
+ ysd_d, ymu_d = ysd.to(dev), ymu.to(dev)
+ ae = torch.zeros(2)
+ n = 0
+ step_sum = 0.0
+ for b in ld:
+ b = b.to(dev)
+ preds, q_logits = model.forward_trace(b.x, b.edge_index, b.batch, steps, noise=False)
+ P = torch.stack(preds, dim=0)
+ if adaptive:
+ Q = torch.stack(q_logits, dim=0)
+ halted = Q > 0
+ any_halt = halted.any(dim=0)
+ first_halt = halted.to(torch.int64).argmax(dim=0)
+ fallback = torch.full_like(first_halt, steps - 1)
+ idx = torch.where(any_halt, first_halt, fallback)
+ chosen = P[idx, torch.arange(P.size(1), device=dev)]
+ step_sum += (idx.to(torch.float32) + 1).sum().item()
+ else:
+ chosen = P[-1]
+ step_sum += steps * b.num_graphs
+ ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
+ n += b.num_graphs
+ return (ae / n).tolist(), step_sum / max(n, 1)
+
+
+def _split_nodes(t, ptr):
+ return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)]
+
+
+def act_train_step(model, state, replacement_batch, opt, dev, args):
+ replacement = replacement_batch.to_data_list()
+ batch_size = len(replacement)
+ if state is None:
+ state = {
+ 'graphs': [None for _ in range(batch_size)],
+ 'y': [None for _ in range(batch_size)],
+ 'z': [None for _ in range(batch_size)],
+ 'steps': torch.zeros(batch_size, dtype=torch.long, device=dev),
+ 'halted': torch.ones(batch_size, dtype=torch.bool, device=dev),
+ }
+
+ halted_cpu = state['halted'].detach().cpu().tolist()
+ for i, halted in enumerate(halted_cpu):
+ if halted:
+ state['graphs'][i] = replacement[i]
+
+ b = Batch.from_data_list(state['graphs']).to(dev)
+ ctx = model.aggregate(b.x, b.edge_index)
+ ptr = b.ptr
+ y_parts, z_parts = [], []
+ for i in range(batch_size):
+ start, end = ptr[i].item(), ptr[i + 1].item()
+ if halted_cpu[i] or state['y'][i] is None:
+ y_parts.append(ctx[start:end])
+ z_parts.append(torch.zeros_like(ctx[start:end]))
+ else:
+ y_parts.append(state['y'][i].to(dev))
+ z_parts.append(state['z'][i].to(dev))
+ y = torch.cat(y_parts, dim=0)
+ z = torch.cat(z_parts, dim=0)
+
+ opt.zero_grad()
+ y, z = model.recurse(y, z, ctx, noise=False, one_step=(model.grad_mode == '1step'))
+ pred, q = model.predict(y, b.batch)
+ per_graph_err = (pred - b.y).abs().mean(1)
+ pred_loss = per_graph_err.mean()
+ with torch.no_grad():
+ if args.halt_target == 'binary':
+ halt_target = (per_graph_err <= args.halt_norm_threshold).to(q.dtype)
+ else:
+ halt_target = torch.sigmoid((args.halt_norm_threshold - per_graph_err) / args.halt_temp)
+ q_loss = nn.functional.binary_cross_entropy_with_logits(q, halt_target)
+ loss = pred_loss + 0.5 * args.lam_q * q_loss
+ y_det, z_det = y.detach(), z.detach()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+
+ state['y'] = _split_nodes(y_det, ptr)
+ state['z'] = _split_nodes(z_det, ptr)
+ with torch.no_grad():
+ was_halted = state['halted']
+ steps = torch.where(was_halted, torch.zeros_like(state['steps']), state['steps']) + 1
+ halted = (steps >= args.halt_max_steps) | (q.detach() > 0)
+ if args.halt_exploration_prob > 0 and args.halt_max_steps > 1:
+ explore = torch.rand_like(q) < args.halt_exploration_prob
+ min_steps = torch.where(
+ explore,
+ torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev),
+ torch.zeros_like(steps),
+ )
+ halted = halted & (steps >= min_steps)
+ state['steps'] = steps
+ state['halted'] = halted
+
+ return state, {
+ 'loss': float(loss.detach().cpu()),
+ 'pred_loss': float(pred_loss.detach().cpu()),
+ 'q_loss': float(q_loss.detach().cpu()),
+ 'halted_frac': float(state['halted'].to(torch.float32).mean().detach().cpu()),
+ 'steps': float(state['steps'].to(torch.float32).mean().detach().cpu()),
+ }
+
+
def main():
ap = argparse.ArgumentParser()
ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
@@ -111,14 +369,27 @@ def main():
ap.add_argument('--T', type=int, default=3)
ap.add_argument('--n_sup', type=int, default=3)
ap.add_argument('--hidden', type=int, default=128)
+ ap.add_argument('--agg_layers', type=int, default=5)
+ ap.add_argument('--compute_layers', type=int, default=2)
+ ap.add_argument('--view', choices=SUPPORTED_VIEWS, default='gin')
ap.add_argument('--epochs', type=int, default=200)
ap.add_argument('--lr', type=float, default=1e-3)
ap.add_argument('--bs', type=int, default=128)
ap.add_argument('--lam_q', type=float, default=1.0)
+ ap.add_argument('--act', action='store_true',
+ help='train all recurrent depths up to halt_max_steps and train qhead as a halt head')
+ ap.add_argument('--halt_max_steps', type=int, default=8)
+ ap.add_argument('--halt_norm_threshold', type=float, default=0.30)
+ ap.add_argument('--halt_temp', type=float, default=0.10)
+ ap.add_argument('--halt_target', choices=['soft', 'binary'], default='soft')
+ ap.add_argument('--halt_exploration_prob', type=float, default=0.1)
+ ap.add_argument('--loss_mode', choices=['last', 'trace'], default='trace')
ap.add_argument('--seed', type=int, default=0)
+ ap.add_argument('--device', default='auto')
args = ap.parse_args()
torch.manual_seed(args.seed); np.random.seed(args.seed)
- dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ dev = 'cuda' if args.device == 'auto' and torch.cuda.is_available() else (
+ 'cpu' if args.device == 'auto' else args.device)
os.makedirs(OUT, exist_ok=True)
tr, va, te = prepare('train'), prepare('val'), prepare('test')
@@ -127,45 +398,91 @@ def main():
for recs in (tr, va, te):
for r in recs:
r['y'] = (r['y'] - ymu) / ysd
- trl = loader(tr, args.bs, True, drop_last=True)
+ train_data = data_list(tr)
+ trl = loader(train_data, args.bs, True, drop_last=True)
val, tel = loader(va, 256, False), loader(te, 256, False)
- model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode).to(dev)
+ deg = degree_histogram(train_data) if args.view == 'pna' else None
+ model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode,
+ agg_layers=args.agg_layers, compute_layers=args.compute_layers,
+ view=args.view, deg=deg).to(dev)
opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
l1 = nn.L1Loss()
+ act_steps = max(1, args.halt_max_steps)
- t0 = time.time(); best_val = 9e9; best = {}; best_state = None
+ t0 = time.time(); best_val = 9e9; best = {}; best_state = None; act_state = None
for ep in range(args.epochs):
model.train()
+ act_metrics = []
for b in trl:
- b = b.to(dev); opt.zero_grad()
- preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
- loss = sum(l1(p, b.y) for p in preds) / len(preds)
- with torch.no_grad():
- tq = -(preds[-1] - b.y).abs().mean(1)
- loss = loss + args.lam_q * nn.functional.mse_loss(q, tq)
- loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ if args.act:
+ act_state, metrics = act_train_step(model, act_state, b, opt, dev, args)
+ act_metrics.append(metrics)
+ else:
+ b = b.to(dev); opt.zero_grad()
+ if args.loss_mode == 'trace':
+ preds, q_logits = model.forward_trace(
+ b.x, b.edge_index, b.batch, args.n_sup, noise=model.sigma > 0)
+ q = q_logits[-1]
+ else:
+ preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
+ loss = sum(l1(p, b.y) for p in preds) / len(preds)
+ with torch.no_grad():
+ tq = -(preds[-1] - b.y).abs().mean(1)
+ loss = loss + args.lam_q * nn.functional.mse_loss(q, tq)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
sched.step()
if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
- vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select)
+ if args.act:
+ vm, _ = evaluate_trace(model, val, dev, ymu, ysd, act_steps, adaptive=False)
+ else:
+ vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select)
if sum(vm) < best_val:
best_val = sum(vm)
- tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select)
- best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo}
+ if args.act:
+ tem, fixed_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=False)
+ tea, adaptive_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=True)
+ best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem,
+ 'test_mae_adaptive': tea, 'fixed_steps': fixed_steps,
+ 'adaptive_steps': adaptive_steps}
+ else:
+ tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select)
+ best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo}
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
- print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True)
+ if args.act and act_metrics:
+ hm = sum(m['halted_frac'] for m in act_metrics) / len(act_metrics)
+ sm = sum(m['steps'] for m in act_metrics) / len(act_metrics)
+ print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]} halt={hm:.2f} train_steps={sm:.2f}", flush=True)
+ else:
+ print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True)
- tag = f"rec_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}_s{args.seed}"
+ act_tag = f"_actfull{act_steps}_{args.halt_target}{args.halt_norm_threshold:g}_e{args.epochs}" if args.act else ""
+ loss_tag = f"_{args.loss_mode}" if (not args.act and args.loss_mode != 'last') else ""
+ view_tag = f"_{args.view}" if args.view != 'gin' else ""
+ tag = f"rec_rrog{view_tag}_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}{loss_tag}{act_tag}_s{args.seed}"
rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1),
- 'dev': dev, 'y_std_raw': ysd.tolist(), **best}
- print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} "
- f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)")
+ 'dev': dev, 'arch': 'rrog_once_agg_node_compute', 'y_std_raw': ysd.tolist(), **best}
+ if args.act:
+ print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} "
+ f"adaptive={[round(x,3) for x in best.get('test_mae_adaptive')]} "
+ f"steps={best.get('adaptive_steps'):.2f}/{best.get('fixed_steps'):.2f} "
+ f"@ep{best.get('ep')} ({rep['sec']}s)")
+ else:
+ print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} "
+ f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)")
with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
json.dump(rep, f, indent=2)
torch.save({'state': best_state or model.state_dict(),
'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup,
- 'sigma': args.sigma, 'grad_mode': args.grad_mode},
+ 'sigma': args.sigma, 'grad_mode': args.grad_mode,
+ 'agg_layers': args.agg_layers, 'compute_layers': args.compute_layers,
+ 'view': args.view,
+ 'loss_mode': args.loss_mode,
+ 'act': args.act, 'act_impl': 'persistent_recycle' if args.act else 'none',
+ 'halt_max_steps': act_steps,
+ 'halt_exploration_prob': args.halt_exploration_prob,
+ 'arch': 'rrog_once_agg_node_compute'},
'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt"))
print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))