1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
|
"""Evaluate deterministic test-time compute for RRoG/TRM-on-GNN checkpoints.
Loads a trained ``diag/train_rec.py`` checkpoint, overrides ``model.T`` at eval time,
and reports raw MAE on ZINC cycle-count. This separates training-time T from
test-time recursion depth.
Run:
PYTHONPATH=/home/yurenh2/rrog python3 diag/eval_rec_ttc.py \
--ckpt runs/ckpt_rec_rrog_full_sig0.0_K1_bestq_T1_ns3_s0.pt --eval_Ts 0 1 2 3 5 8
"""
import argparse
import json
import os
import torch
from diag.train_rec import RecGIN, evaluate, evaluate_trace, loader
from diag.train_cycle import prepare
# Directory where main() writes its ttc_*.json result files. Absolute path tied to
# the author's machine (same root as the PYTHONPATH in the module docstring).
OUT = '/home/yurenh2/rrog/runs'
def load_model(ckpt, dev):
    """Rebuild a ``RecGIN`` from a checkpoint file and switch it to eval mode.

    Args:
        ckpt: path to a checkpoint dict containing at least ``'cfg'`` (model
            hyper-parameters) and ``'state'`` (a state dict).
        dev: device the model is moved to, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        ``(model, ck)``: the model with weights loaded and in eval mode, plus the
        raw checkpoint dict (callers read extra entries such as ``'ymu'``).
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
    # objects -- only point this at trusted checkpoint files.
    checkpoint = torch.load(ckpt, weights_only=False, map_location='cpu')
    config = checkpoint['cfg']

    # Required constructor arguments, passed positionally in this exact order.
    positional = (config['n_atom'], config['hidden'], config['T'], config['n_sup'])
    # Options read with .get() so checkpoints whose cfg lacks them still load.
    optional = {
        key: config.get(key, fallback)
        for key, fallback in (
            ('grad_mode', 'full'),
            ('agg_layers', 5),
            ('compute_layers', 2),
        )
    }

    # sigma is fixed to 0.0 regardless of the training config -- presumably a
    # noise scale, disabled for the deterministic evaluation this script does.
    net = RecGIN(*positional, sigma=0.0, **optional)
    net = net.to(dev)
    net.load_state_dict(checkpoint['state'])
    net.eval()
    return net, checkpoint
def main():
ap = argparse.ArgumentParser()
ap.add_argument('--ckpt', required=True)
ap.add_argument('--eval_Ts', type=int, nargs='+', default=[0, 1, 2, 3, 5, 8])
ap.add_argument('--eval_n_sup', type=int, default=None)
ap.add_argument('--adaptive', action='store_true',
help='choose the first trace step with q_halt > 0, otherwise the last step')
ap.add_argument('--bs', type=int, default=256)
ap.add_argument('--split', choices=['val', 'test'], default='test')
args = ap.parse_args()
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
model, ck = load_model(args.ckpt, dev)
cfg = ck['cfg']
ymu, ysd = ck['ymu'], ck['ysd']
recs = prepare(args.split)
for r in recs:
r['y'] = (r['y'] - ymu) / ysd
ld = loader(recs, args.bs, False)
original_T = model.T
original_n_sup = model.n_sup
if args.eval_n_sup is not None:
model.n_sup = args.eval_n_sup
elif args.adaptive and cfg.get('act'):
model.n_sup = cfg.get('halt_max_steps', cfg['n_sup'])
eval_n_sup_used = model.n_sup
rows = []
print(f"ckpt={os.path.basename(args.ckpt)} train_T={cfg['T']} train_n_sup={cfg['n_sup']} split={args.split} adaptive={args.adaptive}")
print(f"{'T_eval':>6} {'n_sup':>6} {'mae_c5':>10} {'mae_c6':>10} {'sum':>10} {'avg_steps':>10}")
for T in args.eval_Ts:
model.T = T
if args.adaptive:
mae, avg_steps = evaluate_trace(model, ld, dev, ymu, ysd, model.n_sup, adaptive=True)
else:
mae, _ = evaluate(model, ld, dev, ymu, ysd, K=1, select='none')
avg_steps = float(model.n_sup)
row = {
'T_eval': T,
'n_sup_eval': model.n_sup,
'adaptive': args.adaptive,
'avg_steps': avg_steps,
'mae': mae,
'mae_sum': float(sum(mae)),
}
rows.append(row)
print(f"{T:6d} {model.n_sup:6d} {mae[0]:10.4f} {mae[1]:10.4f} {sum(mae):10.4f} {avg_steps:10.2f}")
model.T = original_T
model.n_sup = original_n_sup
os.makedirs(OUT, exist_ok=True)
base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
out = {
'ckpt': args.ckpt,
'split': args.split,
'train_T': cfg['T'],
'train_n_sup': cfg['n_sup'],
'eval_n_sup': eval_n_sup_used,
'adaptive': args.adaptive,
'rows': rows,
}
suffix = "_adaptive" if args.adaptive else ""
fp = os.path.join(OUT, f"ttc_{base}_ns{out['eval_n_sup']}_{args.split}{suffix}.json")
with open(fp, 'w') as f:
json.dump(out, f, indent=2)
print("wrote", fp)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
|