summaryrefslogtreecommitdiff
path: root/diag/eval_rec_ttc.py
blob: e0e789ea93c6e4e395aab6c9156683957cc56687 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Evaluate deterministic test-time compute for RRoG/TRM-on-GNN checkpoints.

Loads a trained ``diag/train_rec.py`` checkpoint, overrides ``model.T`` at eval time,
and reports raw MAE on ZINC cycle-count. This separates training-time T from
test-time recursion depth.

Run:
  PYTHONPATH=/home/yurenh2/rrog python3 diag/eval_rec_ttc.py \
    --ckpt runs/ckpt_rec_rrog_full_sig0.0_K1_bestq_T1_ns3_s0.pt --eval_Ts 0 1 2 3 5 8
"""
import argparse
import json
import os

import torch

from diag.train_rec import RecGIN, evaluate, evaluate_trace, loader
from diag.train_cycle import prepare

# Default directory for the ttc_*.json result files written by main().
# NOTE(review): machine-specific absolute path, independent of the cwd.
OUT = "/home/yurenh2/rrog/runs"


def load_model(ckpt, dev):
    """Rebuild a ``RecGIN`` from a ``diag/train_rec.py`` checkpoint for evaluation.

    Args:
        ckpt: path to the checkpoint file (loaded onto the CPU first).
        dev: device the model is moved to, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        ``(model, ck)``: the model with the saved ``'state'`` weights loaded
        and switched to eval mode, plus the raw checkpoint dict (callers read
        ``'cfg'``, ``'ymu'`` and ``'ysd'`` from it).
    """
    # weights_only=False unpickles arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt, weights_only=False, map_location='cpu')
    config = checkpoint['cfg']

    # Required architecture hyper-parameters, passed positionally.
    positional = (config['n_atom'], config['hidden'], config['T'], config['n_sup'])
    # Optional ones fall back to fixed defaults when absent from cfg.
    optional = {
        'grad_mode': config.get('grad_mode', 'full'),
        'agg_layers': config.get('agg_layers', 5),
        'compute_layers': config.get('compute_layers', 2),
    }

    # sigma is pinned to 0.0 regardless of the training config, since this
    # script targets deterministic evaluation.
    net = RecGIN(*positional, sigma=0.0, **optional)
    net = net.to(dev)
    net.load_state_dict(checkpoint['state'])
    net.eval()
    return net, checkpoint


def main(argv=None):
    """Sweep the eval-time recursion depth ``T`` of one checkpoint.

    Loads ``--ckpt``, standardizes the chosen split's targets with the
    checkpoint's ``ymu``/``ysd``, then for each ``T`` in ``--eval_Ts``
    overrides ``model.T``, evaluates, and prints one table row.

    All rows are written to ``ttc_<stem>_ns<n_sup>_<split>[_adaptive].json``
    in ``--out_dir``. Here ``<stem>`` is the checkpoint file name with a
    leading ``ckpt_`` and a trailing ``.pt``/``.pth`` removed.

    With ``--adaptive``, ``evaluate_trace`` picks the step to report. Per the
    option's help, that is the first trace step with ``q_halt > 0``, else the
    last one. Otherwise ``evaluate`` is used and ``avg_steps`` is reported
    as ``n_sup``.

    Args:
        argv: optional argument list; ``None`` parses ``sys.argv[1:]``.

    Returns:
        The dict that was written to the JSON file.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--eval_Ts', type=int, nargs='+', default=[0, 1, 2, 3, 5, 8])
    ap.add_argument('--eval_n_sup', type=int, default=None)
    ap.add_argument('--adaptive', action='store_true',
                    help='choose the first trace step with q_halt > 0, otherwise the last step')
    ap.add_argument('--bs', type=int, default=256)
    ap.add_argument('--split', choices=['val', 'test'], default='test')
    ap.add_argument('--out_dir', default=OUT,
                    help='directory for the result JSON (default: %(default)s)')
    args = ap.parse_args(argv)

    # Reject nonsensical sweep settings up front instead of failing mid-run.
    if any(t < 0 for t in args.eval_Ts):
        ap.error('--eval_Ts values must be >= 0')
    if args.eval_n_sup is not None and args.eval_n_sup < 1:
        ap.error('--eval_n_sup must be >= 1')
    if args.bs < 1:
        ap.error('--bs must be >= 1')

    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, ck = load_model(args.ckpt, dev)
    cfg = ck['cfg']
    ymu, ysd = ck['ymu'], ck['ysd']

    # Standardize targets with the training statistics stored in the
    # checkpoint. ymu/ysd are also handed to the evaluators, presumably so
    # they can report MAE in raw units.
    recs = prepare(args.split)
    for r in recs:
        r['y'] = (r['y'] - ymu) / ysd
    ld = loader(recs, args.bs, False)

    original_T, original_n_sup = model.T, model.n_sup
    if args.eval_n_sup is not None:
        model.n_sup = args.eval_n_sup
    elif args.adaptive and cfg.get('act'):
        # ACT checkpoints: trace up to the training halting cap (falls back to n_sup).
        model.n_sup = cfg.get('halt_max_steps', cfg['n_sup'])
    eval_n_sup_used = model.n_sup

    rows = []
    print(f"ckpt={os.path.basename(args.ckpt)} train_T={cfg['T']} train_n_sup={cfg['n_sup']} split={args.split} adaptive={args.adaptive}")
    print(f"{'T_eval':>6} {'n_sup':>6} {'mae_c5':>10} {'mae_c6':>10} {'sum':>10} {'avg_steps':>10}")
    try:
        for T in args.eval_Ts:
            model.T = T
            if args.adaptive:
                mae, avg_steps = evaluate_trace(model, ld, dev, ymu, ysd, model.n_sup, adaptive=True)
            else:
                mae, _ = evaluate(model, ld, dev, ymu, ysd, K=1, select='none')
                avg_steps = model.n_sup
            # Coerce to built-in floats so json.dump never meets numpy/torch scalars.
            mae = [float(m) for m in mae]
            avg_steps = float(avg_steps)
            mae_sum = sum(mae)
            rows.append({
                'T_eval': T,
                'n_sup_eval': model.n_sup,
                'adaptive': args.adaptive,
                'avg_steps': avg_steps,
                'mae': mae,
                'mae_sum': mae_sum,
            })
            print(f"{T:6d} {model.n_sup:6d} {mae[0]:10.4f} {mae[1]:10.4f} {mae_sum:10.4f} {avg_steps:10.2f}")
    finally:
        # Leave the model as loaded even if an evaluation raises.
        model.T = original_T
        model.n_sup = original_n_sup

    # Strip only a leading 'ckpt_' and a trailing extension; a plain
    # str.replace would also mangle matches elsewhere (e.g. 'x.pth' -> 'xh').
    base = os.path.basename(args.ckpt)
    if base.startswith('ckpt_'):
        base = base[len('ckpt_'):]
    for ext in ('.pt', '.pth'):
        if base.endswith(ext):
            base = base[:-len(ext)]
            break

    out = {
        'ckpt': args.ckpt,
        'split': args.split,
        'train_T': cfg['T'],
        'train_n_sup': cfg['n_sup'],
        'eval_n_sup': eval_n_sup_used,
        'adaptive': args.adaptive,
        'rows': rows,
    }
    os.makedirs(args.out_dir, exist_ok=True)
    tag = "_adaptive" if args.adaptive else ""
    fp = os.path.join(args.out_dir, f"ttc_{base}_ns{eval_n_sup_used}_{args.split}{tag}.json")
    with open(fp, 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2)
    print("wrote", fp)
    return out


if __name__ == "__main__":
    # Run the sweep only when executed as a script, not on import.
    main()