"""Evaluate deterministic test-time compute for RRoG/TRM-on-GNN checkpoints. Loads a trained ``diag/train_rec.py`` checkpoint, overrides ``model.T`` at eval time, and reports raw MAE on ZINC cycle-count. This separates training-time T from test-time recursion depth. Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/eval_rec_ttc.py \ --ckpt runs/ckpt_rec_rrog_full_sig0.0_K1_bestq_T1_ns3_s0.pt --eval_Ts 0 1 2 3 5 8 """ import argparse import json import os import torch from diag.train_rec import RecGIN, evaluate, evaluate_trace, loader from diag.train_cycle import prepare OUT = '/home/yurenh2/rrog/runs' def load_model(ckpt, dev): ck = torch.load(ckpt, weights_only=False, map_location='cpu') cfg = ck['cfg'] model = RecGIN( cfg['n_atom'], cfg['hidden'], cfg['T'], cfg['n_sup'], sigma=0.0, grad_mode=cfg.get('grad_mode', 'full'), agg_layers=cfg.get('agg_layers', 5), compute_layers=cfg.get('compute_layers', 2), ).to(dev) model.load_state_dict(ck['state']) model.eval() return model, ck def main(): ap = argparse.ArgumentParser() ap.add_argument('--ckpt', required=True) ap.add_argument('--eval_Ts', type=int, nargs='+', default=[0, 1, 2, 3, 5, 8]) ap.add_argument('--eval_n_sup', type=int, default=None) ap.add_argument('--adaptive', action='store_true', help='choose the first trace step with q_halt > 0, otherwise the last step') ap.add_argument('--bs', type=int, default=256) ap.add_argument('--split', choices=['val', 'test'], default='test') args = ap.parse_args() dev = 'cuda' if torch.cuda.is_available() else 'cpu' model, ck = load_model(args.ckpt, dev) cfg = ck['cfg'] ymu, ysd = ck['ymu'], ck['ysd'] recs = prepare(args.split) for r in recs: r['y'] = (r['y'] - ymu) / ysd ld = loader(recs, args.bs, False) original_T = model.T original_n_sup = model.n_sup if args.eval_n_sup is not None: model.n_sup = args.eval_n_sup elif args.adaptive and cfg.get('act'): model.n_sup = cfg.get('halt_max_steps', cfg['n_sup']) eval_n_sup_used = model.n_sup rows = [] print(f"ckpt={os.path.basename(args.ckpt)} train_T={cfg['T']} train_n_sup={cfg['n_sup']} split={args.split} adaptive={args.adaptive}") print(f"{'T_eval':>6} {'n_sup':>6} {'mae_c5':>10} {'mae_c6':>10} {'sum':>10} {'avg_steps':>10}") for T in args.eval_Ts: model.T = T if args.adaptive: mae, avg_steps = evaluate_trace(model, ld, dev, ymu, ysd, model.n_sup, adaptive=True) else: mae, _ = evaluate(model, ld, dev, ymu, ysd, K=1, select='none') avg_steps = float(model.n_sup) row = { 'T_eval': T, 'n_sup_eval': model.n_sup, 'adaptive': args.adaptive, 'avg_steps': avg_steps, 'mae': mae, 'mae_sum': float(sum(mae)), } rows.append(row) print(f"{T:6d} {model.n_sup:6d} {mae[0]:10.4f} {mae[1]:10.4f} {sum(mae):10.4f} {avg_steps:10.2f}") model.T = original_T model.n_sup = original_n_sup os.makedirs(OUT, exist_ok=True) base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '') out = { 'ckpt': args.ckpt, 'split': args.split, 'train_T': cfg['T'], 'train_n_sup': cfg['n_sup'], 'eval_n_sup': eval_n_sup_used, 'adaptive': args.adaptive, 'rows': rows, } suffix = "_adaptive" if args.adaptive else "" fp = os.path.join(OUT, f"ttc_{base}_ns{out['eval_n_sup']}_{args.split}{suffix}.json") with open(fp, 'w') as f: json.dump(out, f, indent=2) print("wrote", fp) if __name__ == '__main__': main()