summaryrefslogtreecommitdiff
path: root/diag/eval_rec_ttc.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/eval_rec_ttc.py')
-rw-r--r--diag/eval_rec_ttc.py113
1 files changed, 113 insertions, 0 deletions
diff --git a/diag/eval_rec_ttc.py b/diag/eval_rec_ttc.py
new file mode 100644
index 0000000..e0e789e
--- /dev/null
+++ b/diag/eval_rec_ttc.py
@@ -0,0 +1,113 @@
+"""Evaluate deterministic test-time compute for RRoG/TRM-on-GNN checkpoints.
+
+Loads a trained ``diag/train_rec.py`` checkpoint, overrides ``model.T`` at eval time,
+and reports raw MAE on ZINC cycle-count. This separates training-time T from
+test-time recursion depth.
+
+Run:
+ PYTHONPATH=/home/yurenh2/rrog python3 diag/eval_rec_ttc.py \
+ --ckpt runs/ckpt_rec_rrog_full_sig0.0_K1_bestq_T1_ns3_s0.pt --eval_Ts 0 1 2 3 5 8
+"""
+import argparse
+import json
+import os
+
+import torch
+
+from diag.train_rec import RecGIN, evaluate, evaluate_trace, loader
+from diag.train_cycle import prepare
+
+OUT = '/home/yurenh2/rrog/runs'
+
+
+def load_model(ckpt, dev):
+ ck = torch.load(ckpt, weights_only=False, map_location='cpu')
+ cfg = ck['cfg']
+ model = RecGIN(
+ cfg['n_atom'],
+ cfg['hidden'],
+ cfg['T'],
+ cfg['n_sup'],
+ sigma=0.0,
+ grad_mode=cfg.get('grad_mode', 'full'),
+ agg_layers=cfg.get('agg_layers', 5),
+ compute_layers=cfg.get('compute_layers', 2),
+ ).to(dev)
+ model.load_state_dict(ck['state'])
+ model.eval()
+ return model, ck
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--ckpt', required=True)
+ ap.add_argument('--eval_Ts', type=int, nargs='+', default=[0, 1, 2, 3, 5, 8])
+ ap.add_argument('--eval_n_sup', type=int, default=None)
+ ap.add_argument('--adaptive', action='store_true',
+ help='choose the first trace step with q_halt > 0, otherwise the last step')
+ ap.add_argument('--bs', type=int, default=256)
+ ap.add_argument('--split', choices=['val', 'test'], default='test')
+ args = ap.parse_args()
+
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model, ck = load_model(args.ckpt, dev)
+ cfg = ck['cfg']
+ ymu, ysd = ck['ymu'], ck['ysd']
+
+ recs = prepare(args.split)
+ for r in recs:
+ r['y'] = (r['y'] - ymu) / ysd
+ ld = loader(recs, args.bs, False)
+
+ original_T = model.T
+ original_n_sup = model.n_sup
+ if args.eval_n_sup is not None:
+ model.n_sup = args.eval_n_sup
+ elif args.adaptive and cfg.get('act'):
+ model.n_sup = cfg.get('halt_max_steps', cfg['n_sup'])
+ eval_n_sup_used = model.n_sup
+
+ rows = []
+ print(f"ckpt={os.path.basename(args.ckpt)} train_T={cfg['T']} train_n_sup={cfg['n_sup']} split={args.split} adaptive={args.adaptive}")
+ print(f"{'T_eval':>6} {'n_sup':>6} {'mae_c5':>10} {'mae_c6':>10} {'sum':>10} {'avg_steps':>10}")
+ for T in args.eval_Ts:
+ model.T = T
+ if args.adaptive:
+ mae, avg_steps = evaluate_trace(model, ld, dev, ymu, ysd, model.n_sup, adaptive=True)
+ else:
+ mae, _ = evaluate(model, ld, dev, ymu, ysd, K=1, select='none')
+ avg_steps = float(model.n_sup)
+ row = {
+ 'T_eval': T,
+ 'n_sup_eval': model.n_sup,
+ 'adaptive': args.adaptive,
+ 'avg_steps': avg_steps,
+ 'mae': mae,
+ 'mae_sum': float(sum(mae)),
+ }
+ rows.append(row)
+ print(f"{T:6d} {model.n_sup:6d} {mae[0]:10.4f} {mae[1]:10.4f} {sum(mae):10.4f} {avg_steps:10.2f}")
+
+ model.T = original_T
+ model.n_sup = original_n_sup
+
+ os.makedirs(OUT, exist_ok=True)
+ base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
+ out = {
+ 'ckpt': args.ckpt,
+ 'split': args.split,
+ 'train_T': cfg['T'],
+ 'train_n_sup': cfg['n_sup'],
+ 'eval_n_sup': eval_n_sup_used,
+ 'adaptive': args.adaptive,
+ 'rows': rows,
+ }
+ suffix = "_adaptive" if args.adaptive else ""
+ fp = os.path.join(OUT, f"ttc_{base}_ns{out['eval_n_sup']}_{args.split}{suffix}.json")
+ with open(fp, 'w') as f:
+ json.dump(out, f, indent=2)
+ print("wrote", fp)
+
+
+if __name__ == '__main__':
+ main()