diff options
Diffstat (limited to 'diag/eval_rec_ttc.py')
| -rw-r--r-- | diag/eval_rec_ttc.py | 113 |
1 files changed, 113 insertions, 0 deletions
"""Evaluate deterministic test-time compute for RRoG/TRM-on-GNN checkpoints.

Loads a trained ``diag/train_rec.py`` checkpoint, overrides ``model.T`` at eval time,
and reports raw MAE on ZINC cycle-count. This separates training-time T from
test-time recursion depth.

Run:
    PYTHONPATH=/home/yurenh2/rrog python3 diag/eval_rec_ttc.py \
        --ckpt runs/ckpt_rec_rrog_full_sig0.0_K1_bestq_T1_ns3_s0.pt --eval_Ts 0 1 2 3 5 8
"""
import argparse
import json
import os

import torch

from diag.train_rec import RecGIN, evaluate, evaluate_trace, loader
from diag.train_cycle import prepare

# Default directory for the JSON result files (overridable with --out_dir).
OUT = '/home/yurenh2/rrog/runs'


def load_model(ckpt, dev):
    """Rebuild a ``RecGIN`` from a checkpoint file and put it in eval mode.

    Args:
        ckpt: Path to a checkpoint file. The loaded object is indexed with the
            keys ``'cfg'`` and ``'state'`` here.
        dev: Device the model is moved to (e.g. ``'cuda'`` or ``'cpu'``).

    Returns:
        A ``(model, ck)`` pair: the model with its weights loaded, and the raw
        checkpoint object so callers can read further entries from it.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects.
    # Only point this script at checkpoints from a trusted source.
    ck = torch.load(ckpt, weights_only=False, map_location='cpu')
    cfg = ck['cfg']
    model = RecGIN(
        cfg['n_atom'],
        cfg['hidden'],
        cfg['T'],
        cfg['n_sup'],
        # sigma is pinned to 0.0 regardless of the training config; presumably
        # this disables noise so evaluation is deterministic -- confirm in RecGIN.
        sigma=0.0,
        # Fallbacks for checkpoints whose cfg lacks these keys.
        grad_mode=cfg.get('grad_mode', 'full'),
        agg_layers=cfg.get('agg_layers', 5),
        compute_layers=cfg.get('compute_layers', 2),
    ).to(dev)
    model.load_state_dict(ck['state'])
    model.eval()
    return model, ck


def _result_stem(ckpt_path):
    """Return the checkpoint basename without a leading ``ckpt_`` / trailing ``.pt``."""
    stem = os.path.basename(ckpt_path)
    # Strip only at the ends: a blanket str.replace would also delete these
    # substrings from the middle of the name (e.g. ``foo.pth`` -> ``fooh``).
    if stem.startswith('ckpt_'):
        stem = stem[len('ckpt_'):]
    if stem.endswith('.pt'):
        stem = stem[:-len('.pt')]
    return stem


def main():
    """Sweep the eval-time ``T`` over a checkpoint and write the results as JSON.

    For every value in ``--eval_Ts`` the model's ``T`` attribute is overridden,
    the chosen split is evaluated, and one table row is printed. All rows are
    then written to ``<out_dir>/ttc_<stem>_ns<n_sup>_<split>[_adaptive].json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--eval_Ts', type=int, nargs='+', default=[0, 1, 2, 3, 5, 8])
    ap.add_argument('--eval_n_sup', type=int, default=None)
    ap.add_argument('--adaptive', action='store_true',
                    help='choose the first trace step with q_halt > 0, otherwise the last step')
    ap.add_argument('--bs', type=int, default=256)
    ap.add_argument('--split', choices=['val', 'test'], default='test')
    ap.add_argument('--out_dir', default=OUT,
                    help='directory the JSON result file is written to')
    args = ap.parse_args()

    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, ck = load_model(args.ckpt, dev)
    cfg = ck['cfg']
    ymu, ysd = ck['ymu'], ck['ysd']

    # Standardise the targets with the statistics stored in the checkpoint.
    recs = prepare(args.split)
    for r in recs:
        r['y'] = (r['y'] - ymu) / ysd
    # Third positional argument is presumably ``shuffle`` -- confirm in train_rec.loader.
    ld = loader(recs, args.bs, False)

    original_T = model.T
    original_n_sup = model.n_sup
    rows = []
    try:
        # An explicit --eval_n_sup wins; otherwise adaptive evaluation of a
        # checkpoint whose cfg has a truthy 'act' uses cfg['halt_max_steps']
        # (falling back to the training n_sup).
        if args.eval_n_sup is not None:
            model.n_sup = args.eval_n_sup
        elif args.adaptive and cfg.get('act'):
            model.n_sup = cfg.get('halt_max_steps', cfg['n_sup'])
        eval_n_sup_used = model.n_sup

        print(f"ckpt={os.path.basename(args.ckpt)} train_T={cfg['T']} train_n_sup={cfg['n_sup']} split={args.split} adaptive={args.adaptive}")
        print(f"{'T_eval':>6} {'n_sup':>6} {'mae_c5':>10} {'mae_c6':>10} {'sum':>10} {'avg_steps':>10}")
        for T in args.eval_Ts:
            model.T = T
            if args.adaptive:
                mae, avg_steps = evaluate_trace(model, ld, dev, ymu, ysd, model.n_sup, adaptive=True)
            else:
                mae, _ = evaluate(model, ld, dev, ymu, ysd, K=1, select='none')
                # Without adaptive halting every supervision step is run.
                avg_steps = float(model.n_sup)
            # Coerce to built-in floats: json.dump rejects NumPy float32 and
            # tensor scalars, which would otherwise discard the whole sweep
            # at write time.
            mae = [float(m) for m in mae]
            avg_steps = float(avg_steps)
            row = {
                'T_eval': T,
                'n_sup_eval': model.n_sup,
                'adaptive': args.adaptive,
                'avg_steps': avg_steps,
                'mae': mae,
                'mae_sum': float(sum(mae)),
            }
            rows.append(row)
            print(f"{T:6d} {model.n_sup:6d} {mae[0]:10.4f} {mae[1]:10.4f} {sum(mae):10.4f} {avg_steps:10.2f}")
    finally:
        # Restore the overridden attributes even if an evaluation step raised.
        model.T = original_T
        model.n_sup = original_n_sup

    os.makedirs(args.out_dir, exist_ok=True)
    base = _result_stem(args.ckpt)
    out = {
        'ckpt': args.ckpt,
        'split': args.split,
        'train_T': cfg['T'],
        'train_n_sup': cfg['n_sup'],
        'eval_n_sup': eval_n_sup_used,
        'adaptive': args.adaptive,
        'rows': rows,
    }
    suffix = "_adaptive" if args.adaptive else ""
    fp = os.path.join(args.out_dir, f"ttc_{base}_ns{out['eval_n_sup']}_{args.split}{suffix}.json")
    with open(fp, 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2)
    print("wrote", fp)


if __name__ == '__main__':
    main()
