diff options
41 files changed, 3874 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1a9b622 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__/ +*.py[cod] + +.venv/ +venv/ + +data/ +runs/ + +.DS_Store diff --git a/diag/__init__.py b/diag/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/diag/__init__.py diff --git a/diag/aggregate.py b/diag/aggregate.py new file mode 100644 index 0000000..b0f737a --- /dev/null +++ b/diag/aggregate.py @@ -0,0 +1,54 @@ +"""Aggregate multi-seed coloring results -> mean+/-std per (grad_mode, pe, contract).""" +import glob, json +import numpy as np +from collections import defaultdict + +R = '/home/yurenh2/rrog/runs' + + +def ms(xs): + a = np.array([x for x in xs if x is not None], dtype=float) + return f"{a.mean():.3f}±{a.std():.3f} (n={len(a)})" if len(a) else "—" + + +def load(pat): + out = [] + for f in glob.glob(pat): + try: + out.append(json.load(open(f))) + except Exception: + pass + return out + + +def key(d): + return (d.get('conv', 'gin'), d.get('pe'), d.get('grad_mode'), 'ctr' if d.get('contract') else '-') + + +solve, le, ml = defaultdict(list), defaultdict(list), defaultdict(list) +pk, ls, au, det = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list) + +for d in load(f"{R}/color_*.json"): + if 'solve_rate' in d and d.get('pe') is not None: + solve[key(d)].append(d['solve_rate']) +for d in load(f"{R}/le_color_*.json"): + le[key(d)].append(d.get('auroc')) + ml[key(d)].append(d.get('mean_lam')) +for d in load(f"{R}/ptrm_color_*.json"): + k = key(d); det[k].append(d.get('det')) + s2 = d.get('sigmas', {}).get('0.2') + if s2: + pk[k].append(s2.get('passk')); ls[k].append(s2.get('lamsel')); au[k].append(s2.get('auroc')) + +print("=== best solve_rate (deterministic, EMA) ===") +for k in sorted(solve, key=str): + print(f" {k}: {ms(solve[k])}") +print("=== LE AUROC(fail|lambda1) ===") +for k in sorted(le, key=str): + print(f" {k}: {ms(le[k])}") +print("=== LE mean_lambda1 
(forced-contraction dose) ===") +for k in sorted(ml, key=str): + print(f" {k}: {ms(ml[k])}") +print("=== PTRM sigma=0.2 ===") +for k in sorted(pk, key=str): + print(f" {k}: det {ms(det[k])} | pass@K {ms(pk[k])} | lambda-sel {ms(ls[k])} | AUROC {ms(au[k])}") diff --git a/diag/cin_color.py b/diag/cin_color.py new file mode 100644 index 0000000..324215f --- /dev/null +++ b/diag/cin_color.py @@ -0,0 +1,144 @@ +"""H4: Recursive CIN-lite (topological / cell-complex) backbone on graph 3-coloring. + +Augment each graph with ring 2-cells from the cycle basis: add one hypernode per basis cycle, +connected to its member nodes. Messages flow node->ring->node (topological message passing over +rings). Run the shared recursive GIN on the augmented (nodes + ring-cells) graph; decode colors +on the ORIGINAL nodes only. Self-contained: train (EMA, best solve) + LE + PTRM; writes +color_/le_/ptrm_ JSON (conv='cin') for diag/aggregate.py. + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/cin_color.py --grad_mode full --seed 0 +""" +import argparse, json, os, time +import numpy as np +import networkx as nx +import torch +import torch.nn.functional as F +from torch_geometric.data import Data, Batch +from diag.train_color import make_split, featurize, RecGINColor, lyap1, OUT +try: + from sklearn.metrics import roc_auc_score +except Exception: + roc_auc_score = None + +N = 50 + + +def dense_A(edge_index, n): + A = torch.zeros(n, n) + if edge_index.shape[1]: + A[edge_index[0], edge_index[1]] = 1.0 + return torch.maximum(A, A.t()) + + +def augment(g): + n = g['n']; ei = g['edge_index'] + G = nx.Graph(); G.add_nodes_from(range(n)) + if ei.shape[1]: + G.add_edges_from(ei.t().tolist()) + rings = nx.cycle_basis(G) + R = len(rings) + src, dst = ei[0].tolist(), ei[1].tolist() + for r, cyc in enumerate(rings): + rn = n + r + for node in cyc: + src += [node, rn]; dst += [rn, node] + aug_ei = torch.tensor([src, dst], dtype=torch.long) if src else torch.zeros((2, 0), dtype=torch.long) + d = 
g['rfeat'].shape[1] + nf = torch.cat([g['rfeat'], torch.tensor([[1.0, 0.0]]).repeat(n, 1)], dim=1) + rf = torch.cat([torch.zeros(R, d), torch.tensor([[0.0, 1.0]]).repeat(R, 1)], dim=1) if R else torch.zeros(0, d + 2) + return Data(x=torch.cat([nf, rf], dim=0), edge_index=aug_ei, num_nodes=n + R) + + +def conf_of(logits, A): + col = logits.argmax(-1) + return int(((col.unsqueeze(0) == col.unsqueeze(1)) & (A > 0)).sum().item() // 2) + + +@torch.no_grad() +def solve_rate(model, graphs, dev, sample=300): + model.eval(); solved = 0 + for g in graphs[:sample]: + d = g['aug'].to(dev) + lg = model(d.x, d.edge_index)[-1][:N] + solved += int(conf_of(lg, g['A'].to(dev)) == 0) + return solved / len(graphs[:sample]) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') + ap.add_argument('--epochs', type=int, default=150); ap.add_argument('--M', type=int, default=16) + ap.add_argument('--seed', type=int, default=0); ap.add_argument('--sigma', type=float, default=0.2) + ap.add_argument('--K', type=int, default=16); ap.add_argument('--n_graphs', type=int, default=150) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + rng = np.random.default_rng(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + + tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16) + te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16) + for g in tr + te: + g['A'] = dense_A(g['edge_index'], g['n']); g['aug'] = augment(g) + in_dim = tr[0]['rfeat'].shape[1] + 2 + model = RecGINColor(in_dim, 128, 3, grad_mode=args.grad_mode, conv='gin').to(dev) + opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + + t0 = time.time(); best = -1; best_state = None + for ep in range(args.epochs): + model.train(); 
order = rng.permutation(len(tr)) + for s0 in range(0, len(order) - args.M, args.M): + sel = order[s0:s0 + args.M] + b = Batch.from_data_list([tr[i]['aug'] for i in sel]).to(dev) + opt.zero_grad() + logits = model(b.x, b.edge_index, b.batch)[-1] + orig = torch.stack([logits[b.ptr[gi]:b.ptr[gi] + N] for gi in range(len(sel))]) # [M,N,k] + A = torch.stack([tr[i]['A'] for i in sel]).to(dev) + p = F.softmax(orig, -1) + loss = (torch.einsum('bik,bjk->bij', p, p) * A).sum() / (A.sum() + 1e-9) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + with torch.no_grad(): + for kk, v in model.state_dict().items(): + ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v) + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + model.load_state_dict(ema); sr = solve_rate(model, te, dev) + if sr > best: + best = sr; best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema} + model.load_state_dict(bk) + print(f"ep{ep+1} solve={sr:.3f}", flush=True) + model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state}); model.eval() + + nstep = model.n_sup * model.T + lams, fails = [], []; passk = lamsel = rand = 0; Lr, Sr = [], [] + for gi, g in enumerate(te[:args.n_graphs]): + d = g['aug'].to(dev); A = g['A'].to(dev) + lam0 = lyap1(model, d.x, d.edge_index, nstep, dev, seed=gi) + c0 = conf_of(model(d.x, d.edge_index)[-1][:N], A) + lams.append(lam0); fails.append(int(c0 > 0)) + confs, rl = [], [] + for j in range(args.K): + confs.append(conf_of(model(d.x, d.edge_index, noise=True)[-1][:N], A)) + rl.append(lyap1(model, d.x, d.edge_index, nstep, dev, seed=1000 * gi + j)) + confs, rl = np.array(confs), np.array(rl); sv = confs == 0 + passk += int(sv.any()); lamsel += int(sv[rl.argmin()]); rand += int(sv[0]) + Lr += rl.tolist(); Sr += sv.tolist() + lams, fails = np.array(lams), np.array(fails); s, f = lams[fails 
== 0], lams[fails == 1] + auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan')) + n = len(te[:args.n_graphs]) + print(f"[cin/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} " + f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)") + base = f"cin_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}" + com = {'conv': 'cin', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed} + json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w')) + json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())}, open(os.path.join(OUT, f"le_color_{base}.json"), 'w')) + json.dump({**com, 'det': 1 - float(fails.mean()), + 'sigmas': {'0.2': {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n}}}, + open(os.path.join(OUT, f"ptrm_color_{base}.json"), 'w')) + print(" wrote", base) + + +if __name__ == "__main__": + main() diff --git a/diag/datasets.py b/diag/datasets.py new file mode 100644 index 0000000..cbd236e --- /dev/null +++ b/diag/datasets.py @@ -0,0 +1,70 @@ +"""Synthetic graph datasets for the 1-WL diagnosis.""" +import numpy as np +import networkx as nx + + +def _nx_to_edge_index(G): + G = nx.convert_node_labels_to_integers(G) + n = G.number_of_nodes() + if G.number_of_edges() == 0: + return n, np.zeros((2, 0), dtype=np.int64) + e = np.array(list(G.edges()), dtype=np.int64).T + ei = np.concatenate([e, e[::-1]], axis=1) # undirected -> both directions + return n, ei + + +def circulant(N, offsets): + G = nx.Graph() + G.add_nodes_from(range(N)) + for i in range(N): + for s in offsets: + G.add_edge(i, (i + s) % N) + return G + + +CSL_SKIPS = [2, 3, 4, 5, 6, 9, 11, 12, 13, 16] # 10 classes, N=41 (Murphy et al. 
2019) + + +def build_csl(n_per_class=15, N=41, seed=0): + """Circular Skip Links: all graphs 4-regular -> 1-WL collapses to one color (pure H2 anchor).""" + rng = np.random.default_rng(seed) + data = [] + for cls, s in enumerate(CSL_SKIPS): + for _ in range(n_per_class): + G = circulant(N, [1, s]) + perm = rng.permutation(N) + G = nx.relabel_nodes(G, {i: int(perm[i]) for i in range(N)}) + n, ei = _nx_to_edge_index(G) + data.append({'n': n, 'edge_index': ei, 'y': cls}) + return data + + +def build_triangle_count(n_graphs=600, n_nodes=20, kind='regular', deg=3, p=0.2, seed=0): + """Graph-level triangle-count regression. 1-WL cannot count triangles -> measurable H2 floor.""" + rng = np.random.default_rng(seed) + data, tries = [], 0 + while len(data) < n_graphs and tries < n_graphs * 30: + tries += 1 + sd = int(rng.integers(1 << 30)) + try: + G = (nx.random_regular_graph(deg, n_nodes, seed=sd) if kind == 'regular' + else nx.gnp_random_graph(n_nodes, p, seed=sd)) + except Exception: + continue + tri = sum(nx.triangles(G).values()) // 3 + n, ei = _nx_to_edge_index(G) + data.append({'n': n, 'edge_index': ei, 'y': float(tri)}) + return data + + +def canonical_pairs(): + """Graphs for instrument self-test (known 1-WL outcomes).""" + pairs = [('C6', nx.cycle_graph(6)), + ('2C3', nx.disjoint_union(nx.cycle_graph(3), nx.cycle_graph(3))), + ('P4', nx.path_graph(4)), + ('K1,3', nx.star_graph(3))] + out = {} + for name, G in pairs: + n, ei = _nx_to_edge_index(G) + out[name] = {'n': n, 'edge_index': ei, 'tri': sum(nx.triangles(G).values()) // 3} + return out diff --git a/diag/esan_color.py b/diag/esan_color.py new file mode 100644 index 0000000..7be050c --- /dev/null +++ b/diag/esan_color.py @@ -0,0 +1,145 @@ +"""H3: Recursive ESAN (subgraph GNN, DS-GNN node-marking bag) on graph 3-coloring. + +Per graph, pick S anchor nodes; each view = graph + a 1-hot mark on the anchor. Run the SHARED +recursive GIN on all views, average node-logits over views (DeepSets). 
Marking breaks node +symmetry -> >1-WL. Self-contained: train (EMA, best solve) + LE (lambda on a marked view, +bucket by aggregate solve) + PTRM (K noisy aggregate forwards); writes color_/le_/ptrm_ JSON +(conv='esan') for diag/aggregate.py. + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/esan_color.py --grad_mode full --seed 0 +""" +import argparse, json, os, time +import numpy as np +import torch +import torch.nn.functional as F +from torch_geometric.data import Data, Batch +from diag.train_color import make_split, featurize, RecGINColor, lyap1, OUT +try: + from sklearn.metrics import roc_auc_score +except Exception: + roc_auc_score = None + +S = 4 + + +def dense_A(edge_index, n): + A = torch.zeros(n, n) + if edge_index.shape[1]: + A[edge_index[0], edge_index[1]] = 1.0 + return torch.maximum(A, A.t()) + + +def views_of(g, anchors): + out = [] + for a in anchors: + mark = torch.zeros(g['n'], 1); mark[a] = 1.0 + out.append(Data(x=torch.cat([g['rfeat'], mark], dim=1), edge_index=g['edge_index'], num_nodes=g['n'])) + return out + + +def anchors_for(g, rng): + return rng.choice(g['n'], size=min(S, g['n']), replace=False) + + +def esan_logits(model, g, dev, anchors, noise=False): + b = Batch.from_data_list(views_of(g, anchors)).to(dev) + out = model(b.x, b.edge_index, b.batch, noise=noise)[-1] # [S*n, k] + return out.view(len(anchors), g['n'], -1).mean(0) # [n, k] + + +def conf_of(logits, A): + col = logits.argmax(-1) + return int(((col.unsqueeze(0) == col.unsqueeze(1)) & (A > 0)).sum().item() // 2) + + +@torch.no_grad() +def solve_rate(model, graphs, dev, rng, sample=300): + model.eval(); solved = 0 + for g in graphs[:sample]: + solved += int(conf_of(esan_logits(model, g, dev, anchors_for(g, rng)), g['A'].to(dev)) == 0) + return solved / len(graphs[:sample]) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') + ap.add_argument('--epochs', type=int, default=150); ap.add_argument('--M', 
type=int, default=16) + ap.add_argument('--seed', type=int, default=0); ap.add_argument('--sigma', type=float, default=0.2) + ap.add_argument('--K', type=int, default=16); ap.add_argument('--n_graphs', type=int, default=150) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + rng = np.random.default_rng(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + + tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16) + te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16) + for g in tr + te: + g['A'] = dense_A(g['edge_index'], g['n']) + in_dim = tr[0]['rfeat'].shape[1] + 1 + model = RecGINColor(in_dim, 128, 3, grad_mode=args.grad_mode, conv='gin').to(dev) + opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + + t0 = time.time(); best = -1; best_state = None + for ep in range(args.epochs): + model.train(); order = rng.permutation(len(tr)) + for s0 in range(0, len(order) - args.M, args.M): + sel = order[s0:s0 + args.M] + views, As = [], [] + for i in sel: + g = tr[i]; views += views_of(g, anchors_for(g, rng)); As.append(g['A']) + b = Batch.from_data_list(views).to(dev) + opt.zero_grad() + logits = model(b.x, b.edge_index, b.batch, noise=False)[-1].view(args.M, S, 50, 3).mean(1) + Ab = torch.stack(As).to(dev) + p = F.softmax(logits, -1) + loss = (torch.einsum('bik,bjk->bij', p, p) * Ab).sum() / (Ab.sum() + 1e-9) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + with torch.no_grad(): + for kk, v in model.state_dict().items(): + ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v) + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + model.load_state_dict(ema); 
sr = solve_rate(model, te, dev, rng) + if sr > best: + best = sr; best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema} + model.load_state_dict(bk) + print(f"ep{ep+1} solve={sr:.3f}", flush=True) + model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state}); model.eval() + + lams, fails = [], []; passk = lamsel = rand = 0; Lr, Sr = [], [] + nstep = model.n_sup * model.T + for gi, g in enumerate(te[:args.n_graphs]): + anc = anchors_for(g, rng) + mark = torch.zeros(g['n'], 1); mark[anc[0]] = 1.0 + xin = torch.cat([g['rfeat'], mark], dim=1).to(dev); ei = g['edge_index'].to(dev) + lam0 = lyap1(model, xin, ei, nstep, dev, seed=gi) + c0 = conf_of(esan_logits(model, g, dev, anc), g['A'].to(dev)) + lams.append(lam0); fails.append(int(c0 > 0)) + confs, rl = [], [] + for j in range(args.K): + confs.append(conf_of(esan_logits(model, g, dev, anc, noise=True), g['A'].to(dev))) + rl.append(lyap1(model, xin, ei, nstep, dev, seed=1000 * gi + j)) + confs, rl = np.array(confs), np.array(rl); sv = confs == 0 + passk += int(sv.any()); lamsel += int(sv[rl.argmin()]); rand += int(sv[0]) + Lr += rl.tolist(); Sr += sv.tolist() + lams, fails = np.array(lams), np.array(fails); s, f = lams[fails == 0], lams[fails == 1] + auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan')) + n = len(te[:args.n_graphs]) + print(f"[esan/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} " + f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)") + base = f"esan_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}" + com = {'conv': 'esan', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed} + json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w')) + json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())}, + open(os.path.join(OUT, f"le_color_{base}.json"), 'w')) + json.dump({**com, 'det': 1 - float(fails.mean()), + 'sigmas': 
{'0.2': {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n}}}, + open(os.path.join(OUT, f"ptrm_color_{base}.json"), 'w')) + print(" wrote", base) + + +if __name__ == "__main__": + main() diff --git a/diag/ff_color.py b/diag/ff_color.py new file mode 100644 index 0000000..29c3b45 --- /dev/null +++ b/diag/ff_color.py @@ -0,0 +1,75 @@ +"""Control: plain FEEDFORWARD (independent-weight, NO recursion, NO deep-supervision) GNN on +graph 3-coloring, to test whether the recursive gcn/gat collapse (.003/.04) is caused by the +RECURSION (shared-weight repeated -> oversmoothing) or is intrinsic to gcn/gat on coloring. + +Sweep conv in {gcn,gat} x depth L in {4,8,16}. L=4 ~ normal usage; L=16 tests whether deep +feedforward also oversmooths. If shallow FF colors well but recursive collapses -> recursion's +fault. If FF is also ~0 at all L -> intrinsic. + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ff_color.py --conv gcn --L 4 --seed 0 +""" +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.data import Data +from torch_geometric.loader import DataLoader +from diag.train_color import make_split, featurize, make_conv, conflict_loss, solve_stats + + +class FF(nn.Module): + def __init__(self, in_dim, hidden, k, L, conv): + super().__init__() + self.conv_type = conv + self.lin_in = nn.Linear(in_dim, hidden) + self.convs = nn.ModuleList([make_conv(conv, hidden) for _ in range(L)]) + self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(L)]) + self.head = nn.Linear(hidden, k) + + def forward(self, xin, ei, batch=None, noise=False): + h = self.lin_in(xin) + for conv, bn in zip(self.convs, self.bns): + h = bn(conv(h, ei)).relu() + return [self.head(h)] + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--conv', choices=['gcn', 'gat', 'gin', 'sage'], default='gcn') + ap.add_argument('--L', type=int, default=4) + ap.add_argument('--epochs', type=int, default=150) + 
ap.add_argument('--seed', type=int, default=0) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + + tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16) + te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16) + in_dim = tr[0]['xin'].shape[1] + trl = DataLoader([Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr], + batch_size=32, shuffle=True, drop_last=True) + model = FF(in_dim, 128, 3, args.L, args.conv).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + best = -1 + for ep in range(args.epochs): + model.train() + for b in trl: + b = b.to(dev); opt.zero_grad() + loss = conflict_loss(model(b.x, b.edge_index)[-1], b.edge_index) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + with torch.no_grad(): + for kk, v in model.state_dict().items(): + ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v) + sched.step() + if (ep + 1) % 30 == 0 or ep == args.epochs - 1: + bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + model.load_state_dict(ema); sr, _ = solve_stats(model, te, dev, sample=300) + best = max(best, sr); model.load_state_dict(bk) + print(f"[ff/{args.conv}/L{args.L}/s{args.seed}] solve={best:.3f}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/diag/lyap.py b/diag/lyap.py new file mode 100644 index 0000000..93b90bf --- /dev/null +++ b/diag/lyap.py @@ -0,0 +1,83 @@ +"""LE diagnostic for the recursive (TRM-ish) GNN — ports the flossing finding to graphs. 
+ +Per-graph top Lyapunov exponent lambda1 of the recursion z <- block(z+h0), via Benettin +power-iteration on a single tangent vector (JVP + renormalize, accumulate log-growth) over +the model's n_sup*T recursion steps. Bucket graphs by success/failure (rounded ring counts +exact) and compare lambda1 distributions + AUROC(fail | lambda1) — mirroring +plot_trm_lyap_hist.py. Hypothesis: failed graphs are MORE chaotic (higher lambda1). + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_full_..._s0.pt +""" +import argparse +import numpy as np +import torch +from diag.train_rec import RecGIN +from diag.train_cycle import prepare +try: + from sklearn.metrics import roc_auc_score +except Exception: + roc_auc_score = None + + +def build(ck, dev): + c = ck['cfg'] + m = RecGIN(c['n_atom'], c['hidden'], c['T'], c['n_sup'], 0.0, grad_mode=c['grad_mode']).to(dev) + m.load_state_dict(ck['state']); m.eval() + return m, c + + +def lyap1(model, x, ei, n_steps, dev, seed=0): + g = torch.Generator(device=dev).manual_seed(seed) + h0 = model.emb(x).detach() + z = torch.zeros_like(h0) + v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12) + def step_fn(zz): + return model.block(zz + h0, ei) + lam = 0.0 + for _ in range(n_steps): + z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v) + z = z_next.detach() + nv = Jv.norm() + lam += torch.log(nv + 1e-12).item() + v = (Jv / (nv + 1e-12)).detach() + return lam / n_steps + + +@torch.no_grad() +def predict(model, x, ei, dev): + batch = torch.zeros(x.size(0), dtype=torch.long, device=dev) + preds, _ = model(x, ei, batch, noise=False) + return preds[-1].view(-1) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--ckpt', required=True) + ap.add_argument('--n_graphs', type=int, default=300) + args = ap.parse_args() + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + ck = torch.load(args.ckpt, weights_only=False) + model, cfg = build(ck, dev) + ymu, ysd = 
ck['ymu'].to(dev), ck['ysd'].to(dev) + te = prepare('test') + n_steps = cfg['n_sup'] * cfg['T'] + + lams, fails = [], [] + for i, r in enumerate(te[:args.n_graphs]): + x = r['x'].to(dev); ei = r['edge_index'].to(dev) + p = predict(model, x, ei, dev) * ysd + ymu # raw [2] + y = r['y'].to(dev) # raw [2] + fails.append(int(not torch.all(p.round() == y.round()).item())) + lams.append(lyap1(model, x, ei, n_steps, dev, seed=i)) + lams, fails = np.array(lams), np.array(fails) + s, f = lams[fails == 0], lams[fails == 1] + auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan')) + print(f"[{cfg['grad_mode']}] n={len(lams)} fail_rate={fails.mean():.2f} | " + f"lambda1 SUCC mean {s.mean():+.4f} std {s.std():.4f} (n={len(s)}) | " + f"FAIL mean {f.mean():+.4f} std {f.std():.4f} (n={len(f)}) | " + f"sep(fail-succ)={f.mean()-s.mean() if len(s) and len(f) else float('nan'):+.4f} | " + f"AUROC(fail|lambda1)={auc:.3f} | mean_lambda1={lams.mean():+.4f}") + + +if __name__ == "__main__": + main() diff --git a/diag/models.py b/diag/models.py new file mode 100644 index 0000000..0aa3106 --- /dev/null +++ b/diag/models.py @@ -0,0 +1,42 @@ +"""GIN (1-WL-tight) and GCN (<1-WL) backbones for the diagnosis.""" +import torch.nn as nn +from torch_geometric.nn import GINConv, GCNConv, global_add_pool, global_mean_pool + + +def _mlp(d_in, d_hid, d_out): + return nn.Sequential(nn.Linear(d_in, d_hid), nn.BatchNorm1d(d_hid), nn.ReLU(), + nn.Linear(d_hid, d_out)) + + +class GIN(nn.Module): + """Sum aggregation + MLP update -> injective on multisets -> matches 1-WL.""" + def __init__(self, in_dim, hidden=64, layers=4, out_dim=10): + super().__init__() + self.convs = nn.ModuleList() + d = in_dim + for _ in range(layers): + self.convs.append(GINConv(_mlp(d, hidden, hidden), train_eps=True)) + d = hidden + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x 
= conv(x, edge_index).relu() + return self.head(global_add_pool(x, batch)) + + +class GCN(nn.Module): + """Mean (normalized) aggregation -> non-injective -> strictly below 1-WL (reference baseline).""" + def __init__(self, in_dim, hidden=64, layers=4, out_dim=10): + super().__init__() + self.convs = nn.ModuleList() + d = in_dim + for _ in range(layers): + self.convs.append(GCNConv(d, hidden)) + d = hidden + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index).relu() + return self.head(global_mean_pool(x, batch)) diff --git a/diag/peptides_depth.py b/diag/peptides_depth.py new file mode 100644 index 0000000..d751b31 --- /dev/null +++ b/diag/peptides_depth.py @@ -0,0 +1,66 @@ +"""Depth-resolution analysis on LRGB Peptides-struct (long-range, real, large graphs). + +Best-achievable MSE for a depth-L MPNN == within-(L-round-WL-color) variance of the target. +The curve over L localizes failure cause WITHOUT training: + floor(L=small) high -> under-reached signal still hidden + floor(L) - floor(converged) -> H1a: depth-recoverable (more iteration/depth helps) + floor(converged) -> H2 : 1-WL ceiling (irreducible by any MPNN) +Targets are z-scored per dim, so floor is the fraction of variance unexplained (var=1 baseline). 
+""" +import numpy as np +from collections import defaultdict +from torch_geometric.datasets import LRGBDataset +from diag import wl + +S = 4000 # graph subsample (for tractable pure-python WL) +MAX_ROUNDS = 40 # cap (>> typical GIN depth; chains converge near their diameter) + + +def floor_at(ghist, Y): + groups = defaultdict(list) + for i, c in enumerate(ghist): + groups[c].append(i) + sse = 0.0 + for idxs in groups.values(): + yy = Y[idxs] + sse += ((yy - yy.mean(0)) ** 2).sum() + return sse / (len(ghist) * Y.shape[1]), len(groups) + + +def main(): + ds = LRGBDataset(root='/home/yurenh2/rrog/data/lrgb', name='Peptides-struct', split='train') + graphs = [ds[i] for i in range(min(S, len(ds)))] + fmap = {} + def fid(row): + t = tuple(row) + if t not in fmap: + fmap[t] = len(fmap) + return fmap[t] + adjs, inits, Y = [], [], [] + for g in graphs: + adjs.append(wl.edges_to_adj(g.num_nodes, g.edge_index.numpy())) + inits.append(np.array([fid(r) for r in g.x.tolist()], dtype=np.int64)) + Y.append(g.y.numpy().reshape(-1)) + Y = np.stack(Y).astype(np.float64) + Y = (Y - Y.mean(0)) / (Y.std(0) + 1e-8) # z-score per target + print(f"subsample={len(graphs)} graphs, {len(fmap)} distinct node-feature ids, targets={Y.shape[1]}") + + import time; t0 = time.time() + node_rounds, ghist_rounds, conv = wl.wl_refine(adjs, inits=inits, max_rounds=MAX_ROUNDS) + print(f"WL refined to round {conv} (cap {MAX_ROUNDS}) in {time.time()-t0:.1f}s") + + print(f"{'L':>4} {'floor_MSE(std)':>14} {'%var_unexpl':>12} {'#graph_colors':>14}") + floors = {} + for L in [0, 1, 2, 3, 4, 5, 8, 16, 32, conv]: + r = min(L, conv) + f, nc = floor_at(ghist_rounds[r], Y) + floors[L] = f + print(f"{L:>4} {f:>14.4f} {100*f:>11.1f}% {nc:>14}") + h2 = floors[conv] + for Lg in [4, 5]: + print(f"\nAt GIN depth L={Lg}: H2 ceiling={h2:.3f} | depth-recoverable H1a (floor[{Lg}]-H2)" + f"={floors[Lg]-h2:.3f} | already-reachable={1-floors[Lg]:.3f} of var") + + +if __name__ == "__main__": + main() diff --git 
a/diag/ppgn_color.py b/diag/ppgn_color.py new file mode 100644 index 0000000..25db01b --- /dev/null +++ b/diag/ppgn_color.py @@ -0,0 +1,209 @@ +"""H2: Recursive PPGN (higher-order, 3-WL) backbone on graph 3-coloring. + +State is a pair-tensor X[i,j,:] (not node features). The powerful block multiplies two +channel-wise n x n matrices (the 3-WL operation): P = M1 @ M2; out = m3([X, P]). Recurse the +pair-tensor (TRM-style deep supervision), pool to nodes (mean over j), decode colors. +Self-contained: trains (EMA, best solve), then LE diagnostic + PTRM noise/lambda-select; +writes color_/le_/ptrm_ JSONs (conv='ppgn') so diag/aggregate.py folds it into the big table. + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ppgn_color.py --grad_mode full --seed 0 +""" +import argparse, json, os, time +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diag.train_color import make_split, featurize, OUT +try: + from sklearn.metrics import roc_auc_score +except Exception: + roc_auc_score = None + + +def dense_A(edge_index, n): + A = torch.zeros(n, n) + if edge_index.shape[1]: + A[edge_index[0], edge_index[1]] = 1.0 + return torch.maximum(A, A.t()) + + +def mlp(di, dh, do): + return nn.Sequential(nn.Linear(di, dh), nn.ReLU(), nn.Linear(dh, do)) + + +class RecPPGN(nn.Module): + def __init__(self, in_dim, hidden=64, k=3, T=3, n_sup=3, grad_mode='full', sigma=0.0): + super().__init__() + self.node_in = nn.Linear(in_dim, hidden) + self.adj_emb = nn.Embedding(2, hidden) + self.m1 = mlp(hidden, hidden, hidden); self.m2 = mlp(hidden, hidden, hidden) + self.m3 = mlp(2 * hidden, hidden, hidden) + self.ln = nn.LayerNorm(hidden) + self.head = nn.Linear(hidden, k) + self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma + + def init_pair(self, xin, A): # xin [B,n,in], A [B,n,n] + h = self.node_in(xin) + X = h.unsqueeze(2) + h.unsqueeze(1) # [B,n,n,hidden] + return X + self.adj_emb(A.long()) + + def block(self, X): # X 
[B,n,n,h] + M1, M2 = self.m1(X), self.m2(X) + P = torch.einsum('bikc,bkjc->bijc', M1, M2) / X.shape[1] + return self.ln(self.m3(torch.cat([X, P], dim=-1))) + + def _inner(self, X, X0, noise): + X = self.block(X + X0) + if noise and self.sigma > 0: + X = X + self.sigma * torch.randn_like(X) + return X + + def recurse(self, X, X0, noise, one_step=False): + if one_step: + with torch.no_grad(): + for _ in range(self.T - 1): + X = self._inner(X, X0, noise) + X = X.detach() + return self._inner(X, X0, noise) + for _ in range(self.T): + X = self._inner(X, X0, noise) + return X + + def forward(self, xin, A, noise=False): + X0 = self.init_pair(xin, A) + X = torch.zeros_like(X0) + outs = [] + for s in range(self.n_sup): + X = self.recurse(X, X0, noise, one_step=(self.grad_mode == '1step')) + outs.append(self.head(X.mean(2))) # pool over j -> [B,n,k] + X = X.detach() + return outs + + +def conflict_loss(logits, A): # logits [B,n,k], A [B,n,n] + p = F.softmax(logits, dim=-1) + return (torch.einsum('bik,bjk->bij', p, p) * A).sum() / (A.sum() + 1e-9) + + +def batches(graphs, bs, shuffle, dev): + idx = np.arange(len(graphs)) + if shuffle: + np.random.shuffle(idx) + for s in range(0, len(idx) - (len(idx) % bs if shuffle else 0), bs): + sel = idx[s:s + bs] + if len(sel) == 0: + continue + X = torch.stack([graphs[i]['xin'] for i in sel]).to(dev) + A = torch.stack([graphs[i]['A'] for i in sel]).to(dev) + yield X, A + + +@torch.no_grad() +def solve_stats(model, graphs, dev, sample=300): + model.eval(); solved = 0; tot = 0 + for g in graphs[:sample]: + A = g['A'].unsqueeze(0).to(dev) + col = model(g['xin'].unsqueeze(0).to(dev), A)[-1][0].argmax(-1) + conf = ((col.unsqueeze(0) == col.unsqueeze(1)) & (A[0] > 0)).sum().item() // 2 + solved += int(conf == 0); tot += 1 + return solved / tot + + +def lyap1_and_solved(model, g, dev, seed, sigma=0.0): + gen = torch.Generator(device=dev).manual_seed(seed) + xin = g['xin'].unsqueeze(0).to(dev); A = g['A'].unsqueeze(0).to(dev) + X0 = 
model.init_pair(xin, A) + X = torch.zeros_like(X0) + v = torch.randn(X.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12) + lam = 0.0 + for _ in range(model.n_sup * model.T): + Xd, Jv = torch.autograd.functional.jvp(lambda XX: model.block(XX + X0), X, v) + nv = Jv.norm(); lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach() + X = Xd.detach() + if sigma > 0: + X = X + sigma * torch.randn(X.shape, generator=gen, device=dev) + col = model.head(X.mean(2))[0].argmax(-1) + conf = ((col.unsqueeze(0) == col.unsqueeze(1)) & (A[0] > 0)).sum().item() // 2 + return lam / (model.n_sup * model.T), conf + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') + ap.add_argument('--epochs', type=int, default=150); ap.add_argument('--bs', type=int, default=16) + ap.add_argument('--hidden', type=int, default=64); ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--sigma', type=float, default=0.2); ap.add_argument('--K', type=int, default=16) + ap.add_argument('--n_graphs', type=int, default=150) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + + tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16) + te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16) + for g in tr + te: + g['A'] = dense_A(g['edge_index'], g['n']) + in_dim = tr[0]['xin'].shape[1] + + model = RecPPGN(in_dim, args.hidden, 3, grad_mode=args.grad_mode).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + t0 = time.time(); best = -1; best_state = None + for ep in range(args.epochs): + model.train() + for X, A in batches(tr, args.bs, True, dev): + opt.zero_grad() + outs = model(X, A, noise=False) + loss = 
sum(conflict_loss(o, A) for o in outs) / len(outs) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + with torch.no_grad(): + for kk, v in model.state_dict().items(): + ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v) + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + model.load_state_dict(ema); sr = solve_stats(model, te, dev) + if sr > best: + best = sr; best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema} + model.load_state_dict(bk) + print(f"ep{ep+1} solve={sr:.3f}", flush=True) + model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state}); model.eval() + + # LE + PTRM (noise + lambda-select) on test + lams, fails = [], [] + passk = lamsel = rand = 0 + Lr, Sr = [], [] + for gi, g in enumerate(te[:args.n_graphs]): + lam0, c0 = lyap1_and_solved(model, g, dev, seed=gi, sigma=0.0) + lams.append(lam0); fails.append(int(c0 > 0)) + res = [lyap1_and_solved(model, g, dev, seed=1000 * gi + j, sigma=args.sigma) for j in range(args.K)] + confs = np.array([c for _, c in res]); rl = np.array([l for l, _ in res]) + solved = confs == 0 + passk += int(solved.any()); lamsel += int(solved[rl.argmin()]); rand += int(solved[0]) + Lr += rl.tolist(); Sr += solved.tolist() + lams, fails = np.array(lams), np.array(fails) + s, f = lams[fails == 0], lams[fails == 1] + auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan')) + Lr, Sr = np.array(Lr), np.array(Sr) + pauc = (roc_auc_score(Sr.astype(int), -Lr) if roc_auc_score and Sr.any() and (~Sr).any() else float('nan')) + n = len(te[:args.n_graphs]) + print(f"[ppgn/{args.grad_mode}] solve={best:.3f} | LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} | " + f"PTRM det={1 - fails.mean():.3f} passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)") + + base = 
f"ppgn_full_none_n50_k3_p0.2_T3_ns3_s{args.seed}" if args.grad_mode == 'full' \ + else f"ppgn_1step_none_n50_k3_p0.2_T3_ns3_s{args.seed}" + com = {'conv': 'ppgn', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed} + json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w'), indent=2) + json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean()), + 'lam_solved': (float(s.mean()) if len(s) else None), + 'lam_unsolved': (float(f.mean()) if len(f) else None)}, + open(os.path.join(OUT, f"le_color_{base}.json"), 'w'), indent=2) + json.dump({**com, 'det': 1 - float(fails.mean()), + 'sigmas': {'0.2': {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n, + 'perRoll': float(Sr.mean()), 'auroc': float(pauc)}}}, + open(os.path.join(OUT, f"ptrm_color_{base}.json"), 'w'), indent=2) + print(" wrote color_/le_color_/ptrm_color_", base) + + +if __name__ == "__main__": + main() diff --git a/diag/ptrm_color.py b/diag/ptrm_color.py new file mode 100644 index 0000000..4004297 --- /dev/null +++ b/diag/ptrm_color.py @@ -0,0 +1,85 @@ +"""Step-2(a): PTRM test-time noise + lambda-based selection on a trained coloring model +(any backbone feature set via cfg pe). Writes a JSON per ckpt for multi-seed aggregation. + +deterministic / pass@K (conflict-min, ground truth) / lambda-select (min lambda1) / random. 
+ +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_...pt +""" +import argparse, json, os +import numpy as np +import torch +from diag.train_color import RecGINColor, make_split, featurize +try: + from sklearn.metrics import roc_auc_score +except Exception: + roc_auc_score = None +OUT = '/home/yurenh2/rrog/runs' + + +def rollout(model, xin, ei, sigma, n_sup, T, dev, seed): + gen = torch.Generator(device=dev).manual_seed(seed) + h0 = model.lin_in(xin) + z = torch.zeros_like(h0) + v = torch.randn(h0.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12) + def step(zz): + return model.block(zz + h0, ei) + lam = 0.0 + for _ in range(n_sup * T): + z_det, Jv = torch.autograd.functional.jvp(step, z, v) + nv = Jv.norm(); lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach() + z = z_det.detach() + if sigma > 0: + z = z + sigma * torch.randn(z.shape, generator=gen, device=dev) + lam /= (n_sup * T) + col = model.head(z).argmax(-1) + conf = (col[ei[0]] == col[ei[1]]).sum().item() // 2 + return conf, lam + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--ckpt', required=True) + ap.add_argument('--K', type=int, default=16) + ap.add_argument('--n_graphs', type=int, default=150) + ap.add_argument('--sigmas', type=float, nargs='+', default=[0.05, 0.1, 0.2, 0.4]) + args = ap.parse_args() + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + ck = torch.load(args.ckpt, weights_only=False); c = ck['cfg'] + deg = torch.tensor(c['deg']) if c.get('deg') else None + model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'], + grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev) + model.load_state_dict(ck['state']); model.eval() + nsup, T = c['n_sup'], c['T'] + te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), c.get('pe', 'none'), c.get('rwse_k', 16)) + te = te[:args.n_graphs]; n = len(te) + + det = sum(rollout(model, r['xin'].to(dev), r['edge_index'].to(dev), 
0.0, nsup, T, dev, 0)[0] == 0 + for r in te) / n + out = {'conv': c.get('conv', 'gin'), 'pe': c.get('pe', 'none'), 'seed': c.get('seed'), + 'grad_mode': c['grad_mode'], 'contract': c.get('contract', False), 'det': det, 'sigmas': {}} + print(f"[pe={out['pe']} s{out['seed']}] deterministic solve_rate = {det:.3f} (n={n}, K={args.K})") + print(f"{'sigma':>6} {'pass@K':>8} {'lam-sel':>8} {'random':>8} {'perRoll':>8} {'AUROC(s|-lam)':>14}") + for sigma in args.sigmas: + passk = lamsel = rand = 0 + L, S = [], [] + for gi, r in enumerate(te): + xin = r['xin'].to(dev); ei = r['edge_index'].to(dev) + res = [rollout(model, xin, ei, sigma, nsup, T, dev, 1000 * gi + j) for j in range(args.K)] + confs = np.array([c0 for c0, _ in res]); lams = np.array([l for _, l in res]) + solved = confs == 0 + passk += int(solved.any()); lamsel += int(solved[lams.argmin()]); rand += int(solved[0]) + L += lams.tolist(); S += solved.tolist() + L, S = np.array(L), np.array(S) + auc = (roc_auc_score(S.astype(int), -L) if roc_auc_score and S.any() and (~S).any() else float('nan')) + out['sigmas'][str(sigma)] = {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n, + 'perRoll': float(S.mean()), 'auroc': float(auc)} + print(f"{sigma:>6} {passk/n:>8.3f} {lamsel/n:>8.3f} {rand/n:>8.3f} {S.mean():>8.3f} {auc:>14.3f}") + + base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '') + with open(os.path.join(OUT, f"ptrm_{base}.json"), 'w') as f: + json.dump(out, f, indent=2) + print(" wrote", os.path.join(OUT, f"ptrm_{base}.json")) + + +if __name__ == "__main__": + main() diff --git a/diag/run_archA.sh b/diag/run_archA.sh new file mode 100644 index 0000000..5385152 --- /dev/null +++ b/diag/run_archA.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +# Conv axis (pe=none): gcn, sage, gat. 5 seeds, train+LE+PTRM. (pin GPU at launch.) 
+set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "A start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)" +for s in 0 1 2 3 4; do + for conv in gcn sage gat; do + ck=runs/ckpt_color_${conv}_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt + echo "== A s$s conv=$conv ==" + python3 diag/train_color.py --mode train --conv "$conv" --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train $conv s$s" + python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le $conv s$s" + python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $conv s$s" + done +done +echo "doneA=$(date -Is)" diff --git a/diag/run_archB.sh b/diag/run_archB.sh new file mode 100644 index 0000000..317638f --- /dev/null +++ b/diag/run_archB.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +# GPS transformer backbone + feature axis (gin + lappe / all). 5 seeds. (pin GPU at launch.) +set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "B start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)" +for s in 0 1 2 3 4; do + ck=runs/ckpt_color_gps_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt + echo "== B s$s conv=gps ==" + python3 diag/train_color.py --mode train --conv gps --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train gps s$s" + python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le gps s$s" + python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm gps s$s" + for pe in lappe all; do + ck2=runs/ckpt_color_gin_full_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt + echo "== B s$s gin pe=$pe ==" + python3 diag/train_color.py --mode train --conv gin --pe "$pe" --p 0.2 --epochs 150 --seed "$s" || echo "!! train $pe s$s" + python3 diag/train_color.py --mode le --ckpt "$ck2" || echo "!! le $pe s$s" + python3 diag/ptrm_color.py --ckpt "$ck2" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! 
ptrm $pe s$s" + done +done +echo "doneB=$(date -Is)" diff --git a/diag/run_cin.sh b/diag/run_cin.sh new file mode 100644 index 0000000..6be362d --- /dev/null +++ b/diag/run_cin.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -uo pipefail +cd /home/yurenh2/rrog; export PYTHONPATH=/home/yurenh2/rrog +echo "CIN start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)" +for s in 0 1 2 3 4; do + echo "== cin s$s =="; python3 diag/cin_color.py --grad_mode full --seed "$s" || echo "!! cin s$s" +done +echo "doneCIN=$(date -Is)" diff --git a/diag/run_color.sh b/diag/run_color.sh new file mode 100644 index 0000000..ad3c406 --- /dev/null +++ b/diag/run_color.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash +# Step-2 (TRM regime, large output): graph 3-coloring, TRM full vs HRM 1-step + LE diagnostic. +set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" +for gm in full 1step; do + echo "===== train $gm =====" + python3 diag/train_color.py --mode train --grad_mode "$gm" --p 0.2 --epochs 150 --seed 0 \ + || echo "!! train $gm failed" +done +echo "===== LE diagnostic (lambda1: solved vs unsolved) =====" +python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_full_n50_k3_p0.2_T3_ns3_s0.pt || echo "!! le full failed" +python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_1step_n50_k3_p0.2_T3_ns3_s0.pt || echo "!! le 1step failed" +echo "done=$(date -Is)" diff --git a/diag/run_cycle.sh b/diag/run_cycle.sh new file mode 100644 index 0000000..f85a6fd --- /dev/null +++ b/diag/run_cycle.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +# Ring-counting (ZINC, [#5-cycles,#6-cycles]) diagnosis: does a real 1-WL ceiling exist, +# and does noise (RNI) vs structured >1-WL (RWSE) break it? 
+set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" +run() { echo "===== $* ====="; python3 diag/train_cycle.py "$@" --layers 5 --epochs 200 --seed 0 || echo "!! FAILED $*"; } +run --conv gin --feat none # 1-WL baseline (should fail to count) +run --conv gcn --feat none # sub-1-WL reference +run --conv gin --feat rni # noise / PTRM-style crude symmetry break +run --conv gin --feat rwse # structured >1-WL positive control +echo "done=$(date -Is)" diff --git a/diag/run_diag.sh b/diag/run_diag.sh new file mode 100644 index 0000000..4a6f31a --- /dev/null +++ b/diag/run_diag.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +# Step-1 diagnosis sweep. CSL = pure-H2 anchor (classification). +# Triangle counting on ER (graphs 1-WL-distinguishable -> H2~0, the H1 end) and +# regular (graphs collapse to one WL color -> H2~var, the ceiling end). +set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" +run() { echo "===== $* ====="; python3 diag/train_diag.py "$@" --layers 4 --epochs 300 --seed 0 \ + || echo "!! FAILED $*"; } +for model in gin gcn; do + run --task csl --model "$model" + run --task tri --model "$model" --kind er + run --task tri --model "$model" --kind regular +done +echo "done=$(date -Is)" diff --git a/diag/run_esan.sh b/diag/run_esan.sh new file mode 100644 index 0000000..f87b379 --- /dev/null +++ b/diag/run_esan.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -uo pipefail +cd /home/yurenh2/rrog; export PYTHONPATH=/home/yurenh2/rrog +echo "ESAN start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)" +for s in 0 1 2 3 4; do + echo "== esan s$s =="; python3 diag/esan_color.py --grad_mode full --seed "$s" || echo "!! 
esan s$s" +done +echo "doneESAN=$(date -Is)" diff --git a/diag/run_ff.sh b/diag/run_ff.sh new file mode 100644 index 0000000..e2b7b0a --- /dev/null +++ b/diag/run_ff.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -uo pipefail +cd /home/yurenh2/rrog; export PYTHONPATH=/home/yurenh2/rrog +echo "FF start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)" +for conv in gcn gat; do for L in 4 8 16; do for s in 0 1 2; do + python3 diag/ff_color.py --conv "$conv" --L "$L" --seed "$s" || echo "!! ff $conv L$L s$s" +done; done; done +echo "doneFF=$(date -Is)" diff --git a/diag/run_le.sh b/diag/run_le.sh new file mode 100644 index 0000000..8fc31ea --- /dev/null +++ b/diag/run_le.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# Step-2(iii): TRM-ish GIN full-recursion vs 1-step-gradient, then LE diagnostic on each. +set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" +python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1 --epochs 200 --seed 0 || echo "!! train full failed" +python3 diag/train_rec.py --grad_mode 1step --sigma 0 --K 1 --epochs 200 --seed 0 || echo "!! train 1step failed" +echo "===== LE diagnostic (lambda1: success vs failure) =====" +python3 diag/lyap.py --ckpt runs/ckpt_rec_full_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap full failed" +python3 diag/lyap.py --ckpt runs/ckpt_rec_1step_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap 1step failed" +echo "done=$(date -Is)" diff --git a/diag/run_pe.sh b/diag/run_pe.sh new file mode 100644 index 0000000..8350160 --- /dev/null +++ b/diag/run_pe.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# Roadmap #1: RRoG on PE-augmented backbone. GIN vs GIN+RWSE on coloring: +# does RRoG-noise add headroom on top of static structural encoding, or is it redundant? 
+set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" +for pe in none rwse; do + echo "===== train full pe=$pe =====" + python3 diag/train_color.py --mode train --grad_mode full --pe "$pe" --p 0.2 --epochs 150 --seed 0 \ + || echo "!! train $pe failed" +done +echo "===== LE (full, both pe) =====" +for pe in none rwse; do + python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \ + || echo "!! le $pe failed" +done +echo "===== PTRM noise + lambda-select (both pe) =====" +for pe in none rwse; do + echo "--- pe=$pe ---" + python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \ + --K 16 --n_graphs 150 --sigmas 0.1 0.2 0.4 || echo "!! ptrm $pe failed" +done +echo "done=$(date -Is)" diff --git a/diag/run_pe2.sh b/diag/run_pe2.sh new file mode 100644 index 0000000..db7bd8a --- /dev/null +++ b/diag/run_pe2.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +# Roadmap #2: RRoG on a GSN-style motif backbone (per-node K3/wedge substructure counts). +# full-recursion, 5 seeds; train + LE + PTRM(sigma=0.2). Aggregate vs none/rwse. +set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" +for s in 0 1 2 3 4; do + ck=runs/ckpt_color_full_gsn_n50_k3_p0.2_T3_ns3_s${s}.pt + echo "===== seed=$s full gsn =====" + python3 diag/train_color.py --mode train --grad_mode full --pe gsn --p 0.2 --epochs 150 --seed "$s" \ + || echo "!! train gsn s$s failed" + python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le gsn s$s failed" + python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! 
ptrm gsn s$s failed" +done +echo "===== AGGREGATE (all pe: none/rwse/gsn) =====" +python3 diag/aggregate.py +echo "done=$(date -Is)" diff --git a/diag/run_pe3.sh b/diag/run_pe3.sh new file mode 100644 index 0000000..a978393 --- /dev/null +++ b/diag/run_pe3.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# Roadmap #3 (subgraph/ego features) + #4 (IGNN-style forced contraction), 5 seeds. +# #3: full --pe sub. #4: full --pe none --contract (vs free full/none baseline). +set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" +for s in 0 1 2 3 4; do + ck=runs/ckpt_color_full_sub_n50_k3_p0.2_T3_ns3_s${s}.pt + echo "===== seed=$s #3 full sub =====" + python3 diag/train_color.py --mode train --grad_mode full --pe sub --p 0.2 --epochs 150 --seed "$s" || echo "!! train sub s$s failed" + python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le sub s$s failed" + python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm sub s$s failed" + + ck2=runs/ckpt_color_full_none_ctr_n50_k3_p0.2_T3_ns3_s${s}.pt + echo "===== seed=$s #4 full none --contract =====" + python3 diag/train_color.py --mode train --grad_mode full --pe none --contract --p 0.2 --epochs 150 --seed "$s" || echo "!! train ctr s$s failed" + python3 diag/train_color.py --mode le --ckpt "$ck2" || echo "!! le ctr s$s failed" + python3 diag/ptrm_color.py --ckpt "$ck2" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! 
ptrm ctr s$s failed" +done +echo "===== AGGREGATE =====" +python3 diag/aggregate.py +echo "done=$(date -Is)" diff --git a/diag/run_pna.sh b/diag/run_pna.sh new file mode 100644 index 0000000..897315d --- /dev/null +++ b/diag/run_pna.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "PNA start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)" +for s in 0 1 2 3 4; do + ck=runs/ckpt_color_pna_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt + echo "== pna s$s ==" + python3 diag/train_color.py --mode train --conv pna --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train pna s$s" + python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le pna s$s" + python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm pna s$s" +done +echo "donePNA=$(date -Is)" diff --git a/diag/run_ppgn.sh b/diag/run_ppgn.sh new file mode 100644 index 0000000..2cc27e9 --- /dev/null +++ b/diag/run_ppgn.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -uo pipefail +cd /home/yurenh2/rrog; export PYTHONPATH=/home/yurenh2/rrog +echo "PPGN start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)" +for s in 0 1 2 3 4; do + echo "== ppgn s$s =="; python3 diag/ppgn_color.py --grad_mode full --seed "$s" || echo "!! ppgn s$s" +done +echo "donePPGN=$(date -Is)" diff --git a/diag/run_real.sh b/diag/run_real.sh new file mode 100644 index 0000000..c29426a --- /dev/null +++ b/diag/run_real.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# Peptides-struct training diagnosis: depth sweep + RNI(noise/>1-WL) + GCN(<1-WL) reference. +set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" +run() { echo "===== $* ====="; python3 diag/train_real.py "$@" --epochs 150 --seed 0 || echo "!! 
FAILED $*"; } +run --conv gin --layers 5 --rni 0 # baseline 1-WL +run --conv gin --layers 10 --rni 0 # depth / long-range +run --conv gin --layers 5 --rni 16 # noise / beyond-1-WL ceiling test +run --conv gcn --layers 5 --rni 0 # sub-1-WL reference +echo "done=$(date -Is)" diff --git a/diag/run_rec.sh b/diag/run_rec.sh new file mode 100644 index 0000000..632498d --- /dev/null +++ b/diag/run_rec.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +# Step-2: recursive GNN + full PTRM on ZINC ring-counting. Does per-step noise + best-Q@K +# selection break the 1-WL counting ceiling that input-RNI couldn't? Averaging vs selection. +set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" +run() { echo "===== $* ====="; python3 diag/train_rec.py "$@" --epochs 200 --seed 0 || echo "!! FAILED $*"; } +run --sigma 0 --K 1 --select bestq # deterministic recursive baseline (~1-WL ceiling) +run --sigma 0.1 --K 8 --select none # averaging over rollouts (RNI-style null) +run --sigma 0.1 --K 8 --select bestq # PTRM-proper: per-step noise + best-Q@K selection +run --sigma 0.2 --K 16 --select bestq # scaled noise + rollouts +echo "done=$(date -Is)" diff --git a/diag/run_seeds.sh b/diag/run_seeds.sh new file mode 100644 index 0000000..77f22c8 --- /dev/null +++ b/diag/run_seeds.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Harden coloring results with 5 seeds: full-vs-1step solve, LE AUROC, PTRM(sigma=0.2) for none/rwse. +set -uo pipefail +cd /home/yurenh2/rrog +export PYTHONPATH=/home/yurenh2/rrog +echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)" +for s in 0 1 2 3 4; do + for cfg in "full none" "full rwse" "1step none"; do + set -- $cfg; gm=$1; pe=$2 + ck=runs/ckpt_color_${gm}_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt + echo "===== seed=$s $gm $pe =====" + python3 diag/train_color.py --mode train --grad_mode "$gm" --pe "$pe" --p 0.2 --epochs 150 --seed "$s" \ + || echo "!! 
train $gm $pe s$s failed" + python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le $gm $pe s$s failed" + if [ "$gm" = "full" ]; then + python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $pe s$s failed" + fi + done +done +echo "===== AGGREGATE =====" +python3 diag/aggregate.py +echo "done=$(date -Is)" diff --git a/diag/selftest_wl.py b/diag/selftest_wl.py new file mode 100644 index 0000000..6c08310 --- /dev/null +++ b/diag/selftest_wl.py @@ -0,0 +1,53 @@ +"""Validate the 1-WL instrument on canonical graphs BEFORE trusting any decomposition. +Run: PYTHONPATH=/home/yurenh2/rrog python3 /home/yurenh2/rrog/diag/selftest_wl.py +""" +import numpy as np +from diag import wl, datasets + + +def run(): + P = datasets.canonical_pairs() + names = list(P.keys()) + adjs = [wl.edges_to_adj(P[k]['n'], P[k]['edge_index']) for k in names] + node_rounds, ghist_rounds, conv = wl.wl_refine(adjs) + color = {names[i]: ghist_rounds[conv][i] for i in range(len(names))} + print("converged round:", conv) + for k in names: + print(f" {k:5s} tri={P[k]['tri']} wlcolor={color[k]}") + + # (1) C6 == 2C3 under 1-WL (both 2-regular) yet differ in triangles -> counting is H2 + assert color['C6'] == color['2C3'], "C6 vs 2C3 should be 1-WL-equal" + assert P['C6']['tri'] != P['2C3']['tri'] + # (2) P4 != K1,3 (different degree multiset) + assert color['P4'] != color['K1,3'], "P4 vs K1,3 should be 1-WL-distinct" + print("OK canonical: C6==2C3 (WL blind to triangles), P4!=K1,3") + + # (3) regression H2 floor on {C6,2C3} == target variance + sub = [wl.edges_to_adj(P['C6']['n'], P['C6']['edge_index']), + wl.edges_to_adj(P['2C3']['n'], P['2C3']['edge_index'])] + y = [P['C6']['tri'], P['2C3']['tri']] + _, gh, cv = wl.wl_refine(sub) + dec = wl.decompose_regression(gh, cv, L=10, y=y, train_idx=[0, 1], eval_idx=[0, 1]) + print(f" triangle-count H2 floor MSE on {{C6,2C3}} = {dec['mse_floor_oracle_H2']:.4f} " + f"(target var = {np.var(y):.4f})") + assert 
abs(dec['mse_floor_oracle_H2'] - np.var(y)) < 1e-9 + + # (4) CSL: 4-regular -> 1 node color, 1 graph color, WL-optimal acc = chance (0.1) -> 100% H2 + csl = datasets.build_csl(n_per_class=15, seed=0) + adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in csl] + nr, gh, cv = wl.wl_refine(adjs) + n_node_colors = len(set(nr[cv][0].tolist())) + n_graph_colors = len(set(gh[cv])) + y = [d['y'] for d in csl] + idx = list(range(len(csl))) + att = wl.attribute_classification(gh, cv, L=4, y=y, train_idx=idx, eval_idx=idx) + print(f" CSL: node-colors={n_node_colors}, distinct graph-colors={n_graph_colors}, " + f"WL-optimal acc={att['wl_optimal_acc_converged']:.3f} (chance 0.1), buckets={att['counts']}") + assert n_node_colors == 1 and n_graph_colors == 1 + assert abs(att['wl_optimal_acc_converged'] - 0.1) < 1e-6 + assert att['counts'].get('H2', 0) == len(csl) + print("OK CSL: fully 1-WL-collapsed -> 100% of failures are H2. Instrument VALIDATED.") + + +if __name__ == "__main__": + run() diff --git a/diag/train_color.py b/diag/train_color.py new file mode 100644 index 0000000..36f8496 --- /dev/null +++ b/diag/train_color.py @@ -0,0 +1,347 @@ +"""Recursive (TRM-ish) GNN graph 3-coloring with swappable BACKBONE for the RRoG roadmap. + +--conv gin|gcn|sage|gat|gps : message-passing operator (gps = GraphGPS local MPNN + global + attention = TRM's original transformer backbone, on the graph). +--pe none|rwse|gsn|sub|lappe|all : input structural features (random sym-break [+ encoding]). +--contract : reverse-flossing lambda-penalty during training (force contraction; roadmap #4). +--grad_mode full|1step : TRM full recursion vs HRM 1-step gradient. +Self-supervised conflict/Potts loss; success = zero-conflict; EMA; deep supervision. +Modes: --mode train (saves ckpt + JSON) / --mode le. 
+""" +import argparse, json, os, time +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.data import Data +from torch_geometric.loader import DataLoader +from torch_geometric.nn import GINConv, GCNConv, SAGEConv, GATConv, GPSConv, PNAConv +from torch_geometric.utils import degree as _pyg_degree + +# GPS uses scaled_dot_product_attention; force the MATH kernel so torch.autograd.functional.jvp +# (LE diagnostic / PTRM rollouts) has a double-backward-able implementation. +for _f in ('enable_flash_sdp', 'enable_mem_efficient_sdp', 'enable_math_sdp'): + try: + getattr(torch.backends.cuda, _f)(_f == 'enable_math_sdp') + except Exception: + pass + +OUT = '/home/yurenh2/rrog/runs' +CACHE = '/home/yurenh2/rrog/data/color_cache' + + +def gen(n, k, p, r, seed): + rng = np.random.default_rng(seed) + part = rng.integers(0, k, n) + src, dst = [], [] + for i in range(n): + for j in range(i + 1, n): + if part[i] != part[j] and rng.random() < p: + src += [i, j]; dst += [j, i] + ei = torch.tensor([src, dst], dtype=torch.long) if src else torch.zeros((2, 0), dtype=torch.long) + rf = torch.tensor(rng.standard_normal((n, r)), dtype=torch.float) + return {'n': n, 'edge_index': ei, 'rfeat': rf} + + +def make_split(split, n, k, p, r, count, seed0): + os.makedirs(CACHE, exist_ok=True) + fp = os.path.join(CACHE, f"{split}_n{n}_k{k}_p{p}_r{r}.pt") + if os.path.exists(fp): + return torch.load(fp, weights_only=False) + data = [gen(n, k, p, r, seed0 + i) for i in range(count)] + torch.save(data, fp) + return data + + +def _adj(edge_index, n): + A = np.zeros((n, n), dtype=np.float64) + ei = edge_index.numpy() + if ei.shape[1]: + A[ei[0], ei[1]] = 1.0 + return np.maximum(A, A.T) + + +def rwse(edge_index, n, K): + A = _adj(edge_index, n); deg = A.sum(1) + P = A / np.where(deg > 0, deg, 1.0)[:, None] + out = np.zeros((n, K), dtype=np.float32); M = np.eye(n) + for j in range(K): + M = M @ P; out[:, j] = np.diag(M) + return 
torch.from_numpy(out) + + +def gsn_feats(edge_index, n): + A = _adj(edge_index, n); deg = A.sum(1) + tri = (A @ A @ A).diagonal() / 2.0 + wedge = deg * (deg - 1) / 2.0 + return torch.tensor(np.stack([np.log1p(tri), np.log1p(wedge)], axis=1), dtype=torch.float) + + +def sub_feats(edge_index, n): + A = _adj(edge_index, n); deg = A.sum(1); A2 = A @ A + M = (A2 > 0).astype(np.float64) * (A == 0).astype(np.float64); np.fill_diagonal(M, 0.0) + d2 = M.sum(1) + tri = (A @ A2).diagonal() / 2.0 + clus = np.where(deg > 1, tri / (deg * (deg - 1) / 2.0), 0.0) + return torch.tensor(np.stack([np.log1p(deg), np.log1p(d2), clus], axis=1), dtype=torch.float) + + +def lappe_feats(edge_index, n, kpe=8): + A = _adj(edge_index, n); deg = A.sum(1) + di = np.where(deg > 0, 1.0 / np.sqrt(deg), 0.0) + L = np.eye(n) - di[:, None] * A * di[None, :] + _, V = np.linalg.eigh(L) + pe = V[:, 1:kpe + 1] + if pe.shape[1] < kpe: + pe = np.pad(pe, ((0, 0), (0, kpe - pe.shape[1]))) + return torch.tensor(pe, dtype=torch.float) + + +def featurize(graphs, pe, rwse_k): + def feat(g): + ei, n = g['edge_index'], g['n'] + if pe == 'rwse': return rwse(ei, n, rwse_k) + if pe == 'gsn': return gsn_feats(ei, n) + if pe == 'sub': return sub_feats(ei, n) + if pe == 'lappe': return lappe_feats(ei, n) + if pe == 'all': return torch.cat([rwse(ei, n, rwse_k), gsn_feats(ei, n), sub_feats(ei, n)], dim=1) + return None + for g in graphs: + e = feat(g) + g['xin'] = torch.cat([g['rfeat'], e], dim=1) if e is not None else g['rfeat'] + return graphs + + +def deg_hist(graphs): + md = 0; ds = [] + for g in graphs: + d = _pyg_degree(g['edge_index'][1], g['n'], dtype=torch.long) + md = max(md, int(d.max()) if d.numel() else 0); ds.append(d) + h = torch.zeros(md + 1, dtype=torch.long) + for d in ds: + h += torch.bincount(d, minlength=md + 1) + return h + + +def make_conv(conv, hidden, deg=None): + if conv == 'gin': + return GINConv(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True) + if 
conv == 'gcn': + return GCNConv(hidden, hidden) + if conv == 'sage': + return SAGEConv(hidden, hidden) + if conv == 'gat': + return GATConv(hidden, hidden, heads=4, concat=False, add_self_loops=True) + if conv == 'pna': + return PNAConv(hidden, hidden, aggregators=['mean', 'min', 'max', 'std'], + scalers=['identity', 'amplification', 'attenuation'], deg=deg, towers=1) + if conv == 'gps': + local = GINConv(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True) + return GPSConv(hidden, local, heads=4) + raise ValueError(conv) + + +class RecGINColor(nn.Module): + def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full', sigma=0.0, conv='gin', deg=None): + super().__init__() + self.conv_type = conv + self.lin_in = nn.Linear(in_dim, hidden) + self.convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(inner)]) + self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)]) + self.head = nn.Linear(hidden, k) + self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma + + def block(self, z, ei, batch=None): + if self.conv_type == 'gps' and batch is None: + batch = z.new_zeros(z.size(0), dtype=torch.long) + for conv, bn in zip(self.convs, self.bns): + z = conv(z, ei, batch) if self.conv_type == 'gps' else conv(z, ei) + z = bn(z).relu() + return z + + def _inner(self, z, h0, ei, noise, batch): + z = self.block(z + h0, ei, batch) + if noise and self.sigma > 0: + z = z + self.sigma * torch.randn_like(z) + return z + + def recurse(self, z, h0, ei, noise, batch, one_step=False): + if one_step: + with torch.no_grad(): + for _ in range(self.T - 1): + z = self._inner(z, h0, ei, noise, batch) + z = z.detach() + return self._inner(z, h0, ei, noise, batch) + for _ in range(self.T): + z = self._inner(z, h0, ei, noise, batch) + return z + + def forward(self, xin, ei, batch=None, noise=False): + h0 = self.lin_in(xin) + z = torch.zeros_like(h0) + outs = [] + for s in range(self.n_sup): 
def conflict_loss(logits, ei):
    """Soft edge-conflict loss for graph colouring.

    Args:
        logits: per-node colour scores; softmax is taken over the last
            dimension and the first dimension is indexed by the node ids
            stored in ``ei``.
        ei: edge index whose rows ``ei[0]`` / ``ei[1]`` hold the two endpoint
            ids of every edge.

    Returns:
        Scalar tensor: the mean over edges of ``sum_c p_u[c] * p_v[c]``, i.e.
        the probability that both endpoints get the same colour when each
        node's colour is drawn independently from its softmax distribution.
        For an empty edge set the result is a zero that stays attached to
        ``logits``.
    """
    if ei.numel() == 0:
        # A mean over zero edges would be NaN and poison the summed training
        # loss; an edgeless graph has no conflicts.  Multiplying by zero keeps
        # the value on the autograd graph so ``backward()`` still works.
        return logits.sum() * 0.0
    p = F.softmax(logits, dim=-1)
    src, dst = ei[0], ei[1]
    return (p[src] * p[dst]).sum(-1).mean()
f"SOLVED {s.mean() if len(s) else float('nan'):+.3f} UNSOLVED {f.mean() if len(f) else float('nan'):+.3f}" + f" sep={sep:+.3f} AUROC={auc:.3f} mean_lam={lams.mean():+.3f}") + return {'n': int(len(lams)), 'fail_rate': float(fails.mean()), 'auroc': float(auc), 'sep': float(sep), + 'lam_solved': (float(s.mean()) if len(s) else None), + 'lam_unsolved': (float(f.mean()) if len(f) else None), 'mean_lam': float(lams.mean())} + + +def lyap_penalty(model, x, ei, batch, target=-0.5): + h0 = model.lin_in(x) + with torch.no_grad(): + zr = model.recurse(torch.zeros_like(h0), h0.detach(), ei, False, batch) + v = torch.randn_like(zr); v = v / (v.norm() + 1e-12) + _, Jv = torch.autograd.functional.jvp(lambda zz: model.block(zz + h0, ei, batch), zr, v, create_graph=True) + return (torch.log(Jv.norm() + 1e-12) - target) ** 2 + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--mode', choices=['train', 'le'], default='train') + ap.add_argument('--conv', choices=['gin', 'gcn', 'sage', 'gat', 'pna', 'gps'], default='gin') + ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') + ap.add_argument('--pe', choices=['none', 'rwse', 'gsn', 'sub', 'lappe', 'all'], default='none') + ap.add_argument('--contract', action='store_true') + ap.add_argument('--rwse_k', type=int, default=16) + ap.add_argument('--ckpt', default=None) + ap.add_argument('--n', type=int, default=50); ap.add_argument('--k', type=int, default=3) + ap.add_argument('--p', type=float, default=0.2); ap.add_argument('--r', type=int, default=8) + ap.add_argument('--hidden', type=int, default=128); ap.add_argument('--T', type=int, default=3) + ap.add_argument('--n_sup', type=int, default=3); ap.add_argument('--epochs', type=int, default=150) + ap.add_argument('--lr', type=float, default=1e-3); ap.add_argument('--bs', type=int, default=32) + ap.add_argument('--seed', type=int, default=0) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if 
torch.cuda.is_available() else 'cpu' + os.makedirs(OUT, exist_ok=True) + + if args.mode == 'le': + ck = torch.load(args.ckpt, weights_only=False); c = ck['cfg'] + te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), + c.get('pe', 'none'), c.get('rwse_k', 16)) + deg = torch.tensor(c['deg']) if c.get('deg') else None + model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'], + grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev) + model.load_state_dict(ck['state']); model.eval() + res = run_le(model, te, dev, c['n_sup'] * c['T']) + base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '') + with open(os.path.join(OUT, f"le_{base}.json"), 'w') as fjs: + json.dump({'conv': c.get('conv', 'gin'), 'grad_mode': c['grad_mode'], 'pe': c.get('pe', 'none'), + 'contract': c.get('contract', False), 'seed': c.get('seed'), **res}, fjs, indent=2) + return + + te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), args.pe, args.rwse_k) + tr = featurize(make_split('train', args.n, args.k, args.p, args.r, 2000, 0), args.pe, args.rwse_k) + in_dim = tr[0]['xin'].shape[1] + data = [Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr] + trl = DataLoader(data, batch_size=args.bs, shuffle=True, drop_last=True) + deg = deg_hist(tr) if args.conv == 'pna' else None + model = RecGINColor(in_dim, args.hidden, args.k, args.T, args.n_sup, + grad_mode=args.grad_mode, conv=args.conv, deg=deg).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + + ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + t0 = time.time(); best_solve = -1; best = {}; best_state = None + for ep in range(args.epochs): + model.train() + for b in trl: + b = b.to(dev); opt.zero_grad() + outs = model(b.x, b.edge_index, b.batch, noise=False) + loss = sum(conflict_loss(o, b.edge_index) 
for o in outs) / len(outs) + if args.contract: + loss = loss + lyap_penalty(model, b.x, b.edge_index, b.batch) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + with torch.no_grad(): + for kk, v in model.state_dict().items(): + if torch.is_floating_point(v): + ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) + else: + ema[kk].copy_(v.detach()) + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + backup = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + model.load_state_dict(ema) + sr, mc = solve_stats(model, te, dev, sample=300) + if sr > best_solve: + best_solve = sr; best = {'ep': ep + 1, 'solve_rate': round(sr, 4), 'mean_conflicts': round(mc, 3)} + best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema} + model.load_state_dict(backup) + print(f"ep{ep+1} solve_rate={sr:.3f} mean_conflicts={mc:.2f}", flush=True) + + sfx = ('_ctr' if args.contract else '') + tag = f"color_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}" + rep = {'task': 'graph3coloring', 'tag': tag, **vars(args), 'in_dim': in_dim, + 'sec': round(time.time() - t0, 1), **best} + print(f"[{tag}] best solve_rate={best.get('solve_rate')} @ep{best.get('ep')} ({rep['sec']}s)") + with open(os.path.join(OUT, f"{tag}.json"), 'w') as f: + json.dump(rep, f, indent=2) + torch.save({'state': best_state, 'cfg': {'in_dim': in_dim, 'hidden': args.hidden, 'k': args.k, + 'T': args.T, 'n_sup': args.n_sup, 'grad_mode': args.grad_mode, 'pe': args.pe, + 'rwse_k': args.rwse_k, 'contract': args.contract, 'conv': args.conv, 'seed': args.seed, + 'deg': (deg.tolist() if deg is not None else None)}}, + os.path.join(OUT, f"ckpt_{tag}.pt")) + print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt")) + + +if __name__ == "__main__": + main() diff --git a/diag/train_cycle.py b/diag/train_cycle.py new file mode 100644 index 0000000..d2342f3 --- /dev/null +++ b/diag/train_cycle.py @@ -0,0 +1,183 
def rwse(edge_index, n, K=RWSE_K):
    """Random-walk structural encoding: K-step return probabilities per node.

    Args:
        edge_index: integer tensor whose rows ``[0]`` / ``[1]`` hold the
            endpoint ids of each edge; edges are symmetrised here, so either
            direction is enough.
        n: number of nodes (size of the dense adjacency matrix).
        K: number of random-walk steps, i.e. number of output columns.

    Returns:
        ``torch.float32`` tensor of shape ``(n, K)``; column ``k`` holds the
        diagonal of ``P ** (k + 1)``, where ``P`` is the row-normalised
        adjacency matrix.
    """
    A = np.zeros((n, n), dtype=np.float64)
    # Tensor.numpy() only works on host tensors; move first so an edge index
    # that already lives on an accelerator is accepted as well.
    ei = edge_index.cpu().numpy()
    A[ei[0], ei[1]] = 1.0
    A = np.maximum(A, A.T)  # make the adjacency symmetric
    deg = A.sum(1)
    # Divide rows by their degree; rows of isolated nodes are divided by 1 and
    # therefore stay all-zero instead of producing NaN.
    P = A / np.where(deg > 0, deg, 1.0)[:, None]
    out = np.zeros((n, K), dtype=np.float32)
    M = np.eye(n)
    for step in range(K):
        M = M @ P  # M == P ** (step + 1)
        out[:, step] = np.diag(M)
    return torch.from_numpy(out)
exist_ok=True) + fp = os.path.join(CACHE, f"{split}.pt") + if os.path.exists(fp): + return torch.load(fp, weights_only=False) + ds = ZINC(ROOT, subset=True, split=split) + out = [] + for g in ds: + out.append({'x': g.x.view(-1).long(), 'edge_index': g.edge_index, + 'rwse': rwse(g.edge_index, g.num_nodes), + 'y': torch.tensor(c56(g), dtype=torch.float)}) + torch.save(out, fp) + return out + + +class Net(nn.Module): + def __init__(self, n_atom, hidden, layers, conv='gin', rni=0, use_rwse=False): + super().__init__() + self.emb = nn.Embedding(n_atom, hidden) + self.rni, self.use_rwse = rni, use_rwse + din = hidden + rni + (RWSE_K if use_rwse else 0) + self.lin_in = nn.Linear(din, hidden) + self.convs, self.bns = nn.ModuleList(), nn.ModuleList() + for _ in range(layers): + if conv == 'gin': + self.convs.append(GINConv(nn.Sequential( + nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True)) + else: + self.convs.append(GCNConv(hidden, hidden)) + self.bns.append(nn.BatchNorm1d(hidden)) + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2)) + + def forward(self, x, edge_index, batch, rwse=None): + h = self.emb(x) + parts = [h] + if self.use_rwse: + parts.append(rwse) + if self.rni: + parts.append(torch.randn(h.size(0), self.rni, device=h.device)) + h = self.lin_in(torch.cat(parts, dim=1)) + for conv, bn in zip(self.convs, self.bns): + h = bn(conv(h, edge_index)).relu() + return self.head(global_add_pool(h, batch)) + + +def to_loader(recs, bs, shuffle, drop_last=False): + data = [Data(x=r['x'], edge_index=r['edge_index'], rwse=r['rwse'], + y=r['y'].view(1, 2), num_nodes=r['x'].numel()) for r in recs] + return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last) + + +@torch.no_grad() +def eval_mae(model, loader, dev, ymu, ysd, samples=1): + model.eval(); abs_err = torch.zeros(2); n = 0 + for b in loader: + b = b.to(dev) + ps = torch.stack([model(b.x, b.edge_index, b.batch, b.rwse) for _ in 
range(samples)]).mean(0) + pr = ps * ysd.to(dev) + ymu.to(dev) # un-standardize -> raw ring units + yr = b.y * ysd.to(dev) + ymu.to(dev) + abs_err += (pr - yr).abs().sum(0).cpu(); n += b.num_graphs + return (abs_err / n).tolist() + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin') + ap.add_argument('--feat', choices=['none', 'rni', 'rwse'], default='none') + ap.add_argument('--layers', type=int, default=5) + ap.add_argument('--hidden', type=int, default=128) + ap.add_argument('--rni_dim', type=int, default=16) + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--bs', type=int, default=128) + ap.add_argument('--seed', type=int, default=0) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + os.makedirs(OUT, exist_ok=True) + + tr, va, te = prepare('train'), prepare('val'), prepare('test') + n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1 + Ytr = torch.stack([r['y'] for r in tr]) + ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8 + for recs in (tr, va, te): + for r in recs: + r['y'] = (r['y'] - ymu) / ysd + + rni = args.rni_dim if args.feat == 'rni' else 0 + use_rwse = args.feat == 'rwse' + samples = 8 if rni else 1 + model = Net(n_atom, args.hidden, args.layers, args.conv, rni, use_rwse).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + lossf = nn.L1Loss() + trl = to_loader(tr, args.bs, True, drop_last=True) + trl_e, val, tel = to_loader(tr, 256, False), to_loader(va, 256, False), to_loader(te, 256, False) + + t0 = time.time(); best_val = 9e9; best = {} + for ep in range(args.epochs): + model.train() + for b in trl: + b = b.to(dev); opt.zero_grad() + loss = lossf(model(b.x, b.edge_index, b.batch, b.rwse), b.y) + loss.backward(); 
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + vm = eval_mae(model, val, dev, ymu, ysd, samples) + if sum(vm) < best_val: + best_val = sum(vm) + best = {'ep': ep + 1, 'train_mae': eval_mae(model, trl_e, dev, ymu, ysd, samples), + 'val_mae': vm, 'test_mae': eval_mae(model, tel, dev, ymu, ysd, samples)} + print(f"ep{ep+1} val_mae(c5,c6)={[round(x,3) for x in vm]}", flush=True) + + tag = f"{args.conv}_{args.feat}_L{args.layers}_s{args.seed}" + rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), + 'y_std_raw': ysd.tolist(), 'sec': round(time.time() - t0, 1), 'dev': dev, **best} + tm = best.get('test_mae'); trm = best.get('train_mae') + print(f"[{tag}] train_mae(c5,c6)={[round(x,3) for x in trm]} test_mae={[round(x,3) for x in tm]} " + f"(raw rings; std={ [round(x,2) for x in ysd.tolist()] }) @ep{best.get('ep')} ({rep['sec']}s)") + with open(os.path.join(OUT, f"cyc_{tag}.json"), 'w') as f: + json.dump(rep, f, indent=2) + print(" wrote", os.path.join(OUT, f"cyc_{tag}.json")) + + +if __name__ == "__main__": + main() diff --git a/diag/train_diag.py b/diag/train_diag.py new file mode 100644 index 0000000..a9c4ab2 --- /dev/null +++ b/diag/train_diag.py @@ -0,0 +1,161 @@ +"""Train a backbone, collect failures, attribute them via the 1-WL instrument. + +Node features are CONSTANT (all-ones) so the GIN starts from anonymous nodes -> its +expressivity ceiling is exactly the anonymous 1-WL partition the instrument computes +(wl_refine init = all-zero). GIN depth L == L WL rounds. Regression targets are +standardized (train stats) for stable training; all reported MSEs are in original units. +Train AND test metrics are reported so non-H2 error can be split into optimization +(can't even fit train) vs generalization (fits train, fails test). 
def split(n, frac, seed, y=None, stratify=False):
    """Split the indices ``0 .. n-1`` into train and test index lists.

    Args:
        n: total number of examples.
        frac: fraction of examples that goes to the test side.
        seed: seed for ``np.random.default_rng``; the same seed gives the same
            split.
        y: per-example labels, only used when ``stratify`` is true.
        stratify: when true (and ``y`` is given) the test set is drawn class by
            class, taking ``round(frac * class_size)`` examples but never fewer
            than one per class.

    Returns:
        ``(train, test)``: two lists of Python ints. ``test`` is sorted in both
        modes; ``train`` is in ascending index order.
    """
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    if stratify and y is not None:
        y = np.asarray(y)
        held_out = set()
        for c in np.unique(y):
            ci = idx[y == c]  # boolean indexing copies, so idx stays ordered
            rng.shuffle(ci)
            # max(1, ...) guarantees every class is represented in the test set
            held_out.update(ci[:max(1, int(round(frac * len(ci))))].tolist())
        test = sorted(held_out)
        # The membership set is built once; the previous version rebuilt
        # set(test) for every candidate index, which was quadratic.
        train = [i for i in idx.tolist() if i not in held_out]
    else:
        rng.shuffle(idx)
        k = int(frac * n)
        test = sorted(idx[:k].tolist())
        train = sorted(idx[k:].tolist())
    return train, test
default='/home/yurenh2/rrog/runs') + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + os.makedirs(args.out, exist_ok=True) + + if args.task == 'csl': + raw = DS.build_csl(n_per_class=15, seed=args.seed); task, out_dim = 'clf', 10 + y = [d['y'] for d in raw]; tr, te = split(len(raw), 0.34, args.seed, y, stratify=True) + else: + raw = DS.build_triangle_count(n_graphs=800, n_nodes=18, kind=args.kind, deg=3, seed=args.seed) + task, out_dim = 'reg', 1; tr, te = split(len(raw), 0.3, args.seed) + + ymu, ysd = 0.0, 1.0 + if task == 'reg': + ytr = np.array([raw[i]['y'] for i in tr], dtype=np.float64) + ymu, ysd = float(ytr.mean()), float(ytr.std() + 1e-8) + + pyg = to_pyg(raw, task, ymu, ysd) + trl = DataLoader([pyg[i] for i in tr], batch_size=32, shuffle=True, drop_last=True) + alll = DataLoader(pyg, batch_size=64) + + Model = M.GIN if args.model == 'gin' else M.GCN + model = Model(in_dim=1, hidden=args.hidden, layers=args.layers, out_dim=out_dim).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3) + lossf = torch.nn.CrossEntropyLoss() if task == 'clf' else torch.nn.MSELoss() + + t0 = time.time() + for ep in range(args.epochs): + model.train() + for b in trl: + b = b.to(dev); opt.zero_grad() + o = model(b.x, b.edge_index, b.batch) + loss = lossf(o, b.y) + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + + pred = predict(model, alll, task, dev) + if task == 'reg': + pred = pred * ysd + ymu + yv = np.array([d['y'] for d in raw], dtype=(np.float64 if task == 'reg' else np.int64)) + adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in raw] + _, ghist, conv = wl.wl_refine(adjs) + + rep = {'task': args.task, 'model': args.model, 'layers': args.layers, 'seed': args.seed, + 'kind': (args.kind if args.task == 'tri' else None), + 'n': len(raw), 'n_train': len(tr), 'n_test': len(te), 'conv_round': conv, + 'sec': 
round(time.time() - t0, 1), 'dev': dev} + + if task == 'clf': + test_pred, test_y = pred[te], yv[te] + acc = float((test_pred == test_y).mean()) + train_acc = float((pred[tr] == yv[tr]).mean()) + att = wl.attribute_classification(ghist, conv, args.layers, yv, tr, te) + fails = [te[k] for k in range(len(te)) if test_pred[k] != test_y[k]] + fb = Counter(att['buckets'][i] for i in fails) + rep.update({'train_acc': round(train_acc, 4), 'test_acc': round(acc, 4), + 'wl_ceiling_acc_converged': round(att['wl_optimal_acc_converged'], 4), + 'wl_ceiling_acc_Ldepth': round(att['wl_optimal_acc_Ldepth'], 4), + 'test_bucket_counts': att['counts'], + 'failure_bucket_counts': dict(fb), 'n_failures': len(fails)}) + print(f"[{args.task}/{args.model}] train_acc={train_acc:.3f} test_acc={acc:.3f} | " + f"1-WL ceiling(conv)={att['wl_optimal_acc_converged']:.3f} " + f"L-depth={att['wl_optimal_acc_Ldepth']:.3f} | failures={len(fails)} -> {dict(fb)}") + else: + test_pred, test_y = pred[te], yv[te] + mse = float(((test_pred - test_y) ** 2).mean()) + train_mse = float(((pred[tr] - yv[tr]) ** 2).mean()) + dec = wl.decompose_regression(ghist, conv, args.layers, yv, tr, te) + h2 = dec['mse_floor_oracle_H2'] + rep.update({'train_mse': round(train_mse, 4), 'test_mse_gin': round(mse, 4), + 'mse_floor_oracle_H2': round(h2, 4), + 'mse_floor_converged_train': round(dec['mse_floor_converged_train'], 4), + 'mse_floor_Ldepth_train': round(dec['mse_floor_Ldepth_train'], 4), + 'var_target_test': round(dec['var_target_eval'], 4), + 'frac_test_unseen_color': round(dec['frac_test_unseen_color'], 4), + 'frac_test_singleton_color': round(dec['frac_test_singleton_color'], 4), + 'learn_gap_test': round(max(0.0, mse - h2), 4)}) + print(f"[{args.task}/{args.model}/{args.kind}] train_mse={train_mse:.3f} test_mse={mse:.3f} | " + f"1-WL oracle floor(H2)={h2:.3f} | unseen={dec['frac_test_unseen_color']:.2f} " + f"singleton={dec['frac_test_singleton_color']:.2f} | learn_gap={max(0.0, mse - h2):.3f} " + 
f"var_y={dec['var_target_eval']:.3f}") + + tag = f"{args.task}_{args.kind}" if args.task == 'tri' else args.task + fn = os.path.join(args.out, f"diag_{tag}_{args.model}_L{args.layers}_s{args.seed}.json") + with open(fn, 'w') as f: + json.dump(rep, f, indent=2) + print(" wrote", fn) + + +if __name__ == "__main__": + main() diff --git a/diag/train_real.py b/diag/train_real.py new file mode 100644 index 0000000..336d86f --- /dev/null +++ b/diag/train_real.py @@ -0,0 +1,139 @@ +"""Training-based failure diagnosis on LRGB Peptides-struct (real, large, long-range). + +The WL partition instrument is vacuous here (graphs ~all distinguishable), so we diagnose +by TRAINING and comparing: + GIN(L) : standard 1-WL backbone at depth L + GIN(L)+RNI : random node features = noise = beyond-1-WL symmetry breaker + GCN(L) : sub-1-WL reference +Reads: deeper helps -> long-range/under-reaching; RNI helps -> a real >1-WL ceiling that +noise breaks; train<<test -> generalization; train high -> compute/optimization ceiling. +Targets z-scored per dim; metric = standardized MAE (lower better). 11 targets. 
@torch.no_grad()
def mae(model, loader, dev, ymu, ysd):
    """Mean absolute error of ``model`` over every target element in ``loader``.

    The error is measured directly against ``b.y`` as stored in the batches,
    so it is expressed in whatever units the loader's targets use.

    Args:
        model: callable as ``model(b.x, b.edge_index, b.batch)``; it is put in
            eval mode and left there.
        loader: iterable of batches exposing ``to(dev)``, ``x``,
            ``edge_index``, ``batch`` and ``y``.
        dev: device the batches are moved to.
        ymu, ysd: accepted for call-site compatibility; not used here.

    Returns:
        Sum of absolute errors divided by the number of target elements, as a
        Python float. ``nan`` when the loader yields no target elements.
    """
    model.eval()
    abs_sum = 0.0
    count = 0
    for b in loader:
        b = b.to(dev)
        o = model(b.x, b.edge_index, b.batch)
        abs_sum += (o - b.y).abs().sum().item()
        count += b.y.numel()
    if count == 0:
        # Undefined metric: report NaN rather than dividing by zero.
        return float('nan')
    return abs_sum / count
ap.add_argument('--epochs', type=int, default=150) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--bs', type=int, default=128) + ap.add_argument('--seed', type=int, default=0) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + os.makedirs(OUT, exist_ok=True) + + tr = LRGBDataset(root=ROOT, name='Peptides-struct', split='train') + va = LRGBDataset(root=ROOT, name='Peptides-struct', split='val') + te = LRGBDataset(root=ROOT, name='Peptides-struct', split='test') + + # per-column embedding sizes + target standardization (train stats) + col_max = None + Ytr = [] + for g in tr: + m = g.x.max(0).values + col_max = m if col_max is None else torch.maximum(col_max, m) + Ytr.append(g.y.view(-1)) + for ds in (va, te): + for g in ds: + col_max = torch.maximum(col_max, g.x.max(0).values) + col_sizes = (col_max + 2).tolist() + Ytr = torch.stack(Ytr) + ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8 + + def norm(ds): + out = [] + for g in ds: + g = g.clone(); g.y = (g.y.view(1, -1) - ymu) / ysd + out.append(g) + return out + trl = DataLoader(norm(tr), batch_size=args.bs, shuffle=True, drop_last=True) + val = DataLoader(norm(va), batch_size=256) + tel = DataLoader(norm(te), batch_size=256) + trl_eval = DataLoader(norm(tr), batch_size=256) + + model = Net(col_sizes, args.hidden, args.layers, out_dim=11, conv=args.conv, rni=args.rni).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + lossf = nn.L1Loss() + + t0 = time.time(); best_val = 9e9; best = {} + for ep in range(args.epochs): + model.train() + for b in trl: + b = b.to(dev); opt.zero_grad() + loss = lossf(model(b.x, b.edge_index, b.batch), b.y) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + sched.step() + if (ep + 1) % 15 == 0 or ep == args.epochs - 1: + vm = mae(model, val, 
dev, ymu, ysd) + if vm < best_val: + best_val = vm + best = {'ep': ep + 1, 'train_mae': mae(model, trl_eval, dev, ymu, ysd), + 'val_mae': vm, 'test_mae': mae(model, tel, dev, ymu, ysd)} + print(f"ep{ep+1} val_mae={vm:.4f}", flush=True) + + tag = f"{args.conv}_L{args.layers}_rni{args.rni}_s{args.seed}" + rep = {'dataset': 'Peptides-struct', 'tag': tag, **vars(args), + 'sec': round(time.time() - t0, 1), 'dev': dev, **best} + print(f"[{tag}] train_mae={best.get('train_mae'):.4f} val_mae={best.get('val_mae'):.4f} " + f"test_mae={best.get('test_mae'):.4f} @ep{best.get('ep')} ({rep['sec']}s)") + fn = os.path.join(OUT, f"real_{tag}.json") + with open(fn, 'w') as f: + json.dump(rep, f, indent=2) + print(" wrote", fn) + + +if __name__ == "__main__": + main() diff --git a/diag/train_rec.py b/diag/train_rec.py new file mode 100644 index 0000000..9866f28 --- /dev/null +++ b/diag/train_rec.py @@ -0,0 +1,174 @@ +"""Step-2: Recursive (TRM-ish) GNN on ZINC ring-counting + optional PTRM noise/selection. + +Recurrent shared-weight GIN block, deep-supervised over n_sup steps (TRM-style: carry latent +detached between steps). --grad_mode controls the LAST supervision step's recursion: + full : backprop through all T inner recursions (TRM) + 1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient) +Optional per-step Gaussian noise (sigma) + K stochastic rollouts selected by a value head +(best-Q@K) for the PTRM experiments. Saves a checkpoint for the LE diagnostic (diag/lyap.py). 
+ +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1 +""" +import argparse, json, os, time +import numpy as np +import torch +import torch.nn as nn +from torch_geometric.loader import DataLoader +from torch_geometric.data import Data +from torch_geometric.nn import GINConv, global_add_pool +from diag.train_cycle import prepare + +OUT = '/home/yurenh2/rrog/runs' + + +def loader(recs, bs, shuffle, drop_last=False): + data = [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2), + num_nodes=r['x'].numel()) for r in recs] + return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last) + + +class RecGIN(nn.Module): + def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, grad_mode='full'): + super().__init__() + self.emb = nn.Embedding(n_atom, hidden) + self.convs = nn.ModuleList([GINConv(nn.Sequential( + nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True) + for _ in range(inner)]) + self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)]) + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2)) + self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1)) + self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode + + def block(self, z, ei): + for conv, bn in zip(self.convs, self.bns): + z = bn(conv(z, ei)).relu() + return z + + def _inner(self, z, h0, ei, noise): + z = self.block(z + h0, ei) + if noise and self.sigma > 0: + z = z + self.sigma * torch.randn_like(z) + return z + + def recurse(self, z, h0, ei, noise, one_step=False): + if one_step: # HRM 1-step gradient + with torch.no_grad(): + for _ in range(self.T - 1): + z = self._inner(z, h0, ei, noise) + z = z.detach() + return self._inner(z, h0, ei, noise) # only last inner carries grad + for _ in range(self.T): # TRM full recursion + z = self._inner(z, h0, ei, noise) + return z + + def forward(self, x, ei, 
batch, noise=False): + h0 = self.emb(x) + z = torch.zeros_like(h0) + preds = [] + for s in range(self.n_sup): + if s < self.n_sup - 1: + with torch.no_grad(): + z = self.recurse(z, h0, ei, noise) + z = z.detach() + else: + z = self.recurse(z, h0, ei, noise, one_step=(self.grad_mode == '1step')) + preds.append(self.head(global_add_pool(z, batch))) + q = self.qhead(global_add_pool(z, batch)).view(-1) + return preds, q + + +@torch.no_grad() +def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'): + model.eval() + ysd_d, ymu_d = ysd.to(dev), ymu.to(dev) + ae = torch.zeros(2); ae_or = torch.zeros(2); n = 0 + for b in ld: + b = b.to(dev) + if K == 1: + preds, _ = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0) + chosen = oracle = preds[-1] + else: + P, Q = [], [] + for _ in range(K): + preds, q = model(b.x, b.edge_index, b.batch, noise=True) + P.append(preds[-1]); Q.append(q) + P = torch.stack(P); Q = torch.stack(Q) + ar = torch.arange(P.size(1), device=dev) + chosen = P[Q.argmax(0), ar] if select == 'bestq' else P.mean(0) + oracle = P[(P - b.y.unsqueeze(0)).abs().sum(-1).argmin(0), ar] + ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + ae_or += ((oracle * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + n += b.num_graphs + return (ae / n).tolist(), (ae_or / n).tolist() + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') + ap.add_argument('--sigma', type=float, default=0.0) + ap.add_argument('--K', type=int, default=1) + ap.add_argument('--select', choices=['none', 'bestq'], default='bestq') + ap.add_argument('--T', type=int, default=3) + ap.add_argument('--n_sup', type=int, default=3) + ap.add_argument('--hidden', type=int, default=128) + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--bs', type=int, default=128) + ap.add_argument('--lam_q', type=float, default=1.0) + 
ap.add_argument('--seed', type=int, default=0) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + os.makedirs(OUT, exist_ok=True) + + tr, va, te = prepare('train'), prepare('val'), prepare('test') + n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1 + Ytr = torch.stack([r['y'] for r in tr]); ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8 + for recs in (tr, va, te): + for r in recs: + r['y'] = (r['y'] - ymu) / ysd + trl = loader(tr, args.bs, True, drop_last=True) + val, tel = loader(va, 256, False), loader(te, 256, False) + + model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + l1 = nn.L1Loss() + + t0 = time.time(); best_val = 9e9; best = {}; best_state = None + for ep in range(args.epochs): + model.train() + for b in trl: + b = b.to(dev); opt.zero_grad() + preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0) + loss = sum(l1(p, b.y) for p in preds) / len(preds) + with torch.no_grad(): + tq = -(preds[-1] - b.y).abs().mean(1) + loss = loss + args.lam_q * nn.functional.mse_loss(q, tq) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select) + if sum(vm) < best_val: + best_val = sum(vm) + tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select) + best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo} + best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True) + + tag = f"rec_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}_s{args.seed}" + rep = {'dataset': 
'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1), + 'dev': dev, 'y_std_raw': ysd.tolist(), **best} + print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} " + f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)") + with open(os.path.join(OUT, f"{tag}.json"), 'w') as f: + json.dump(rep, f, indent=2) + torch.save({'state': best_state or model.state_dict(), + 'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup, + 'sigma': args.sigma, 'grad_mode': args.grad_mode}, + 'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt")) + print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt")) + + +if __name__ == "__main__": + main() diff --git a/diag/wl.py b/diag/wl.py new file mode 100644 index 0000000..26ab0a3 --- /dev/null +++ b/diag/wl.py @@ -0,0 +1,166 @@ +"""1-WL color-refinement instrument for diagnosing GNN failures (H1 vs H2). + +A GIN with L layers == L rounds of 1-WL refinement (injective sum aggregation). +A failure on sample i is attributed by label purity of its WL color classes: + + converged-WL class IMPURE (train labels conflict under same color) + -> H2 : 1-WL ceiling. No MPNN at ANY depth separates -> needs >1-WL (noise). + converged pure, but L-round class impure + -> H1a_depth : separable only with MORE rounds -> deterministic RR-on-graph / depth helps. + L-round class pure (info present at depth L) but model wrong + -> H1b_opt : optimization / capacity. Train better. + +Refinement is dataset-global (shared per-round signature->label map) so node colors and +graph-color histograms are comparable across graphs. +""" +from collections import Counter, defaultdict +import numpy as np + + +def edges_to_adj(n, edge_index): + adj = [[] for _ in range(n)] + ei = np.asarray(edge_index) + for a, b in zip(ei[0].tolist(), ei[1].tolist()): + adj[a].append(b) + return adj + + +def wl_refine(adjs, inits=None, max_rounds=None): + """Dataset-level 1-WL. 
Returns (node_rounds, ghist_rounds, conv_round). + node_rounds[r][g] = int color array (global labels) of graph g after r rounds. + ghist_rounds[r][g] = canonical color histogram (hashable) of graph g after r rounds. + conv_round = round index at which the global partition stabilized. + """ + if inits is None: + inits = [np.zeros(len(a), dtype=np.int64) for a in adjs] + else: + inits = [np.asarray(x, dtype=np.int64) for x in inits] + if max_rounds is None: + max_rounds = max((len(a) for a in adjs), default=0) + 2 + + d = {} + def lab(s): + v = d.get(s) + if v is None: + v = len(d); d[s] = v + return v + + cur = [np.array([lab(('i', int(c))) for c in init], dtype=np.int64) for init in inits] + node_rounds = [cur] + nclasses = [len(d)] + + for _r in range(max_rounds): + d = {} + nxt = [] + for adj in adjs: + c = cur_g = node_rounds[-1][len(nxt)] + arr = np.empty(len(adj), dtype=np.int64) + for v in range(len(adj)): + sig = (int(c[v]), tuple(sorted(int(c[u]) for u in adj[v]))) + arr[v] = lab(sig) + nxt.append(arr) + node_rounds.append(nxt) + nclasses.append(len(d)) + if nclasses[-1] == nclasses[-2]: # global #classes stopped growing -> converged + break + + conv_round = len(node_rounds) - 1 + ghist_rounds = [[_hist(c) for c in nr] for nr in node_rounds] + return node_rounds, ghist_rounds, conv_round + + +def _hist(colors): + return tuple(sorted(Counter(colors.tolist()).items())) + + +def graph_colors_at(ghist_rounds, conv_round, L): + return ghist_rounds[min(L, conv_round)] + + +# ---------- classification attribution ---------- +def attribute_classification(ghist_rounds, conv_round, L, y, train_idx, eval_idx): + y = np.asarray(y) + conv = ghist_rounds[conv_round] + Lr = min(L, conv_round) + lr = ghist_rounds[Lr] + conv_train, lr_train = defaultdict(list), defaultdict(list) + for i in train_idx: + conv_train[conv[i]].append(int(y[i])) + lr_train[lr[i]].append(int(y[i])) + + def pure(dct, key): + labs = dct.get(key) + return labs is not None and len(set(labs)) == 1 + + 
def majority(dct, key): + labs = dct.get(key) + return Counter(labs).most_common(1)[0][0] if labs else None + + buckets = {} + wl_opt = lr_opt = 0 + for i in eval_idx: + if conv[i] not in conv_train: + buckets[i] = 'novel' + elif not pure(conv_train, conv[i]): + buckets[i] = 'H2' + elif not pure(lr_train, lr[i]): + buckets[i] = 'H1a_depth' + else: + buckets[i] = 'H1b_opt' + if majority(conv_train, conv[i]) == int(y[i]): + wl_opt += 1 + if majority(lr_train, lr[i]) == int(y[i]): + lr_opt += 1 + n = len(eval_idx) + return { + 'buckets': buckets, + 'counts': dict(Counter(buckets.values())), + 'wl_optimal_acc_converged': wl_opt / n, # best ANY MPNN can do + 'wl_optimal_acc_Ldepth': lr_opt / n, # best L-layer MPNN can do + 'L_used': Lr, 'conv_round': conv_round, + } + + +# ---------- regression decomposition ---------- +def decompose_regression(ghist_rounds, conv_round, L, y, train_idx, eval_idx): + """H2 floor = ORACLE within-color variance on FULL data (best possible function of the WL + color: do same-color graphs share the target?). This is the true information ceiling and is + NOT confounded by train/test coverage. 
The train-fitted floors are also reported to expose + how much apparent error is really novel-color generalization, plus coverage fractions.""" + y = np.asarray(y, dtype=np.float64) + conv = ghist_rounds[conv_round] + Lr = min(L, conv_round) + lr = ghist_rounds[Lr] + full_idx = list(range(len(y))) + + # oracle: best constant per converged color over ALL data -> irreducible by any MPNN + conv_mean_full = _group_mean(conv, y, full_idx) + e_oracle = np.array([conv_mean_full[conv[i]] - y[i] for i in eval_idx]) + + # train-fitted (achievable with this split); fallback to global mean on unseen colors + conv_mean_tr = _group_mean(conv, y, train_idx) + lr_mean_tr = _group_mean(lr, y, train_idx) + gmean = float(y[list(train_idx)].mean()) + e_conv_tr = np.array([conv_mean_tr.get(conv[i], gmean) - y[i] for i in eval_idx]) + e_lr_tr = np.array([lr_mean_tr.get(lr[i], gmean) - y[i] for i in eval_idx]) + + conv_count = Counter(conv[i] for i in full_idx) + train_colors = set(conv[i] for i in train_idx) + frac_unseen = float(np.mean([conv[i] not in train_colors for i in eval_idx])) + frac_singleton = float(np.mean([conv_count[conv[i]] == 1 for i in eval_idx])) + return { + 'mse_floor_oracle_H2': float((e_oracle ** 2).mean()), # TRUE 1-WL ceiling + 'mse_floor_converged_train': float((e_conv_tr ** 2).mean()), + 'mse_floor_Ldepth_train': float((e_lr_tr ** 2).mean()), + 'frac_test_unseen_color': frac_unseen, + 'frac_test_singleton_color': frac_singleton, + 'L_used': Lr, 'conv_round': conv_round, + 'var_target_eval': float(y[list(eval_idx)].var()), + } + + +def _group_mean(colors, y, idx): + acc = defaultdict(list) + for i in idx: + acc[colors[i]].append(float(y[i])) + return {k: float(np.mean(v)) for k, v in acc.items()} diff --git a/papers/gram_2605.19376.pdf b/papers/gram_2605.19376.pdf Binary files differnew file mode 100644 index 0000000..b6637ef --- /dev/null +++ b/papers/gram_2605.19376.pdf diff --git a/papers/hrm_2506.21734.pdf b/papers/hrm_2506.21734.pdf Binary files differnew 
file mode 100644 index 0000000..9c87a62 --- /dev/null +++ b/papers/hrm_2506.21734.pdf diff --git a/papers/ptrm_2605.19943.pdf b/papers/ptrm_2605.19943.pdf Binary files differnew file mode 100644 index 0000000..ce01e6c --- /dev/null +++ b/papers/ptrm_2605.19943.pdf diff --git a/papers/ptrm_2605.19943.txt b/papers/ptrm_2605.19943.txt new file mode 100644 index 0000000..49213a0 --- /dev/null +++ b/papers/ptrm_2605.19943.txt @@ -0,0 +1,1418 @@ +Probabilistic Tiny Recursive Model + +Ali Parviz +Mila – Quebec AI Institute + +Alexia Jolicoeur-Martineau +Independent + +{amin.sghaier, ali.parviz}@mila.quebec +alexia.jolicoeur-martineau@mail.mcgill.ca + +Abstract +Tiny Recursive Models (TRM) solve complex reasoning tasks with a fraction of +the parameters of modern large language models (LLMs) by iteratively refining a +latent state and final answer. While powerful, their deterministic recursion can lead +to convergence at suboptimal solutions, without an escape mechanism. A common +workaround relies on task-specific input perturbations at test time combined with +answer aggregation via voting. We introduce Probabilistic TRM (PTRM), a task-agnostic framework for test-time compute scaling that addresses this limitation +through stochastic exploration. PTRM injects Gaussian noise at each deep recursion +step, enabling parallel trajectories to explore diverse solution basins, and selects +among them using the model’s existing Q head (used for early stopping in the +original TRM). Without requiring retraining or task-specific augmentations, PTRM +enables substantial accuracy gains across benchmarks, including Sudoku-Extreme +(87.4% to 98.75%) and on various puzzles from Pencil Puzzle Bench (62.6% to +91.2%). On the latter, PTRM achieves nearly double the accuracy of frontier LLMs +(91.2% vs. 55.1%) at less than 0.0001x the cost, using only 7M parameters. 
+ +PPBench Puzzles + +sudoku, lightup, nurikabe, heyawake, and tapa + +91.2 + +80 +55.1 +34.7 + +Direct prediction +Deterministic recursive prediction + +0 + +0 + +PTRM (ours) + +0 + +TRM + +PTRM (ours) + +TRM + +LLM ensemble + +claude-opus-4-6 + +Chain-of-thought, pretrained +LLM ensemble + +0 + +HRM + +20 + +Direct pred + +40 +0 + +gemini-3.1-pro + +Direct pred + +0 + +24.5 + +98.75 + +55 + +60 + +2 +gpt-5.2@xhigh + +20 + +24.5 + +80 + +62.6 + +o3-mini-high + +40 + +87.4 + +Claude 3.7 8K + +60 + +Sudoku-Extreme +100 + +Deepseek R1 + +100 + +Accuracy (%) + +arXiv:2605.19943v1 [cs.AI] 19 May 2026 + +Amin Sghaier +Mila – Quebec AI Institute +ILLS & ETS Montreal + +Probabilistic recursive prediction (ours) + +Best of 7 strongest LLMs. Assumes access to a perfect verifier. + +Figure 1: PTRM performance comparison. On various PPBench puzzles, PTRM boosts TRM +performance by 28.6 points without any retraining. It outperforms the strongest single frontier LLMs +by 56.5 points and an ensemble of the seven strongest LLMs (assuming a perfect verifier) by 36 +points. On Sudoku-Extreme, PTRM reaches a state of the art 98.75%. + +1 + +Introduction + +Tiny Recursive Models (TRM) [1] achieve strong performance on complex reasoning puzzles with +orders of magnitude fewer parameters than the large language models (LLMs) they outperform on +tasks like Sudoku-Extreme [2] and ARC-AGI [3, 4]. TRM and its predecessor Hierarchical Reasoning +Model (HRM) [2] represent an emerging architectural alternative to standard autoregressive reasoning +models. Rather than autoregressively generating chains of token-level reasoning, they recursively +refine a latent state. This approach produces a single deterministic answer per input, fitting well with +tasks where the answer is unique. +Despite their strong performance, their deterministic inference does not make full use of their +capabilities. 
We show that many of TRM’s incorrect answers are from rollouts trapped in bad latent +space basins (i.e., regions of the latent space which decode to incorrect answers and from which the +deterministic recursions cannot escape). This observation, which aligns with recent mechanistic work +on related models [5], suggests that TRM has the capabilities to solve significantly more problems +but is limited by its standard inference procedure. +Although each puzzle has a unique correct answer, many distinct latent trajectories can reach it. This +is analogous to reasoning LLMs, where many reasoning trajectories can lead to the same unique +answer. However, being non-deterministic, LLMs can be randomly sampled in order to form different +trajectories (including Chains of Thought and actual answer). By then selecting a trajectory using +a voting mechanism or based on the answer’s projected value (via a verifier), LLMs can leverage +test-time compute to achieve very high accuracy [6]. We propose a way to achieve similar test-time +scaling performance gains by sampling stochastic latent trajectories, each producing a deterministic +decoded answer, and selecting among the answers using the model’s own Q head. +TRM’s Q head is trained jointly (as a correctness classifier) with the rest of the network and is +conventionally used only at training time for adaptive computation (ACT) [7]. It carries valuable +information that the standard inference procedure discards. +We propose Probabilistic TRM (PTRM), a test-time compute scaling framework that introduces a +new width scaling axis. At inference we run K parallel rollouts per puzzle, each receiving Gaussian +noise injected into the latent at every deep recursion step. The noise causes rollouts to follow different +latent trajectories and settle in different basins. Among the resulting candidate answers, the Q head +is used to select the one most likely to be correct. 
PTRM requires no training changes and no +task-specific test-time augmentation, yet, as illustrated in Figure 1, delivers substantial accuracy +gains across diverse reasoning benchmarks. + +2 + +Background: Tiny Recursive Model + +Tiny Recursive Model (TRM) is a single network that iteratively refines a predicted answer y to a +question x through recursive updates of a reasoning latent z. Specifically, a single latent recursion +consists of n updates to the latent state z followed by one update to the predicted answer y, all using +the same two-layer network fθ : z ← fθ (x + y + z) n times, then y ← fθ (y + z). +fθ distinguishes the two update types by whether the input includes x. A deep recursion runs T +latent recursions in sequence, with only the final one retaining gradients, allowing the model to +leverage a large effective depth while keeping training efficient. +Rather than doing one optimization step per sample, TRM is trained via deep supervision, which +consists of keeping the previous latent state z and answer y as initialization (after being detached from +the computational graph) for the next supervision step. This is done for up to Nsup supervision steps. +The loss at each step is calculated using cross entropy between the predicted answer logits fO (y) +(where fO is a linear output head) and the ground truth ytrue . This trains the network to progressively +refine its prediction across reasoning steps. At inference, the recurrence can be unrolled for more +steps than during training, providing a depth axis for test-time compute scaling (additional steps may +correct otherwise-incorrect answers). +Without a halting mechanism during training, each puzzle stays in the mini-batch for Nsup supervision +steps rather than being replaced after each one. To avoid wasting compute on already-solved samples, +an Adaptive Computation Time (ACT) halting mechanism is used. 
This is done by adding a binary +cross entropy loss between a halting logit q̂ = fQ (y) (where fQ is a linear Q head) and the binary exact +2 + +Correct answer + +Incorrect answer + +PC 1 (58% var) +1.0 + +5.0 +0.9 2.5 +0.0 +0.8 2.5 +5.0 +0.7 + +2.5 +0.0 +2.5 +5 + +10 + +Supervision step + +15 + +0.9 + +1.0 + +5 + +Cell accuracy + +Q value + +PC 1 (85% var) + +1.0 + +5.0 + +0 + +Failure + +End + +PC 2 (8% var) + +PC 2 (36% var) +PC 1 (84% var) + +5.0 + +Start + +Delayed success + +PC 2 (15% var) + +Quick success + +Cell accuracy + +0.8 + +0 + +0.8 +0.7 +0 + +5 + +10 + +Supervision step + +15 + +0.6 + +5 +0 + +5 + +10 + +Supervision step + +15 + +Figure 2: TRM Trajectory Modes. PCA projection of y (top) and Q value (solid, left axis) with cell +accuracy (dashed, right axis) across supervision steps (bottom) for three PPBench puzzles, illustrating +three trajectory modes (left to right): quick success, delayed success, and failure (Sec. 3). Latents are +projected into the principal plane per puzzle, so PC axes are not comparable across plots. Trajectories +fade from light (early steps) to dark (later steps). Circle marks the start and square marks the end. + +correctness of the predicted answer $\hat{y} = \arg\max f_O(y)$: $\mathcal{L}_{\text{step}} = \mathrm{CE}(f_O(y), y_{\text{true}}) + \mathrm{BCE}(\hat{q}, \mathbb{1}[\hat{y} = y_{\text{true}}])$. +The Q head thus allows the supervision loop to halt early on samples where sigmoid(q̂) > 0.5, +improving data efficiency. During inference, the Q head is not used, and the model performs Nsup +supervision steps to maximize answer correctness. +While TRM is powerful, it sometimes gets stuck in incorrect solutions. In the next section, we will +investigate such failure cases in order to determine a way to remedy them. + +3 + +Problem: When Does TRM Fail? + +3.1 + +Analysis of failures and successes + +We present observations about TRM that motivate our method. 
In this section, we train a TRM on +multiple Pencil Puzzle Bench (PPBench) [8] puzzles and inspect the latent dynamics and Q head +behavior across supervision steps on a held-out validation set. For each puzzle, we record the latent +yt and the Q logit q̂t = fQ (yt ) at every supervision step t = 1, . . . , Nsup , project the latents into +the principal plane (PCA per puzzle), and jointly plot the Q value alongside cell accuracy (fraction +of correct cells in the predicted answer) over supervision steps. Figure 2 shows paired PCA and +Q/cell-accuracy plots for three representative puzzles, illustrating three trajectory modes we observe: +Quick success: the trajectory transitions in a few steps from its starting location to a convergence +region and remains there. Cell accuracy and the Q value rise together and saturate near their maxima +within the same few steps. +Delayed success: the trajectory initially oscillates around one region and remains there for multiple +supervision steps before sharply escaping to a different region where it converges. During the initial +3 + +phase, the Q value is negative, and at the step where the trajectory escapes, both Q value and cell +accuracy spike together. +Failure: the trajectory oscillates in a bounded region without converging. Cell accuracy never reaches +near 100%, and the Q value stays negative for all supervision steps. +We refer to latent space regions that trajectories remain in across multiple supervision steps and +exhibit similar cell accuracy throughout as basins. Basins where cell accuracy is near-maximal are +good basins and basins where it is not are bad basins. Initially, failures and delayed successes behave +similarly (both are caught in bad basins with negative Q). They diverge only later in their trajectories, +when delayed successes find an escape to a good basin while failures remain stuck. 
+3.2 + +The Q head tracks trajectory quality +6 +4 + +Cell accuracy + +Q value + +2 + +1.00 +0.95 +0.90 +0.85 +Incorrect (28) +Correct (69) +0.80 +Cell accuracy (right axis) +0.75 +0.70 +0.65 +0.60 +10 +12 +14 + +0 +2 +4 +6 +0 + +2 + +4 + +6 +8 +Supervision step + +Figure 3: Q value follows cell accuracy across reasoning. Mean +Q value (solid, left axis) and mean +cell accuracy (dashed, right axis) +over supervision steps, aggregated +over 100 PPBench validation puzzles, separated by final correctness +(green: correct, red: incorrect). + +Across all three modes (failures, delayed successes, and quick successes), we find that the Q head’s +value closely tracks cell accuracy at every supervision step. To further confirm this, Figure 3 +aggregates trajectories from 100 PPBench validation puzzles, separating them by final-answer +correctness. The aggregate view corroborates the per-puzzle observation: mean Q and mean cell +accuracy rise together on correct trajectories and remain mostly flat on incorrect ones. Moreover, at +convergence, the Q logit sharply separates the two populations where q̂ ≈ +6 (sigmoid ≈ 1) for +correct trajectories and q̂ ≈ −6 (sigmoid ≈ 0) for incorrect ones. The Q head is therefore a reliable +learned indicator of whether a trajectory has reached a good basin. +Given the Q head’s ability to distinguish good from bad trajectories, a natural question follows: +can we leverage the Q head to identify better trajectories? The main challenge is that the standard +TRM is inherently deterministic, and thus cannot be used to sample different trajectories for a given +problem. In the next section, we will show that by simply adding Gaussian noise to the latent state, +we can sample different parallel trajectories and leverage the Q head to pick the best one. 
+ +4 + +Method: Test-Time Compute Scaling via Stochastic Rollouts + +We propose Probabilistic TRM (PTRM), an inference-time procedure that makes the TRM recursion +stochastic and selects the best of K resulting trajectories. PTRM requires no special training and +can be readily applied to any pretrained TRM model. Furthermore, it requires no task-specific +augmentations. PTRM works as follows: at each supervision step, we add Gaussian noise (scaled by +σ) to the latent state input. The Q head fQ scores each candidate latent output, and the one with the +highest Q value is selected and then decoded using the model’s output head fO . The algorithm in +Figure 4 (left) states this formally. PTRM offers two complementary benefits: 1) it enables trajectories +to escape bad basins where deterministic TRM remains stuck, and 2) it introduces width as a new +axis for test-time scaling. +4.1 + +Escaping bad basins + +In Sec. 3, we found that some failed deterministic trajectories are caught in bad solution basins in +latent space, with no way to escape. PTRM lets us test whether stochastic perturbations are enough +for some of the rollouts of a previously failed puzzle to reach a good solution basin. Figure 5 shows +K=100 independent rollouts, from the same failed puzzle used in Figure 2 (which fails at K=1), +4 + +PTRM Inference +1: Input: puzzle $x$, rollouts $K$, +2: supervision steps $D$, noise scale $\sigma$ +3: for $k = 1, \dots, K$ in parallel do +4: Initialize $z_0^{(k)}, y_0^{(k)}$ +5: for $t = 1, \dots, D$ do +6: $z_{t-1}^{(k)} \mathrel{+}= \epsilon,\ \epsilon \sim \mathcal{N}(0, \sigma^2 I)$ +7: $z_t^{(k)}, y_t^{(k)} \leftarrow \mathrm{rec}(x, z_{t-1}^{(k)}, y_{t-1}^{(k)})$ +8: end for +9: $\hat{y}^{(k)} \leftarrow \arg\max f_O(y_D^{(k)})$ +10: $\hat{q}^{(k)} \leftarrow f_Q(y_D^{(k)})$ +11: end for +12: return $\hat{y}^{(k^*)}$, $k^* = \arg\max_k \hat{q}^{(k)}$ + +(a) Standard TRM (deterministic) [diagram labels: puzzle; ···; answer; depth axis: D deep recursion steps] + +(b) PTRM (ours): K stochastic rollouts + Q-head selection [diagram labels: puzzle; k=1, k=2, ···, k=K; +ϵ at each deep recursion step; width axis: K rollouts; arg maxk Qk; final answer; Gaussian noise injection] + 
+Figure 4: Left: PTRM inference procedure (the rec() function refers to a deep recursion step). Right: +PTRM mechanism. (a) Standard TRM: a single deterministic rollout. (b) PTRM: K stochastic latent +rollouts with Gaussian noise ϵ at each deep recursion step, with the Q head selecting the final answer. +projected into the principal plane. Most rollouts (92%) remain stuck in the same bad basin, while +a minority (8%) escape to a distinct region in latent space and produce correct answers. We also +observe that recurrent noise creates a per-rollout probability of escape: at K = 5 no rollouts escape, +at K = 25 one does, and at K = 100 eight do. This confirms that noise provides the stochasticity +needed to occasionally find an escape trajectory. +4.2 + +Width scaling + +Since more rollouts per puzzle compound the chance that at least one reaches a good basin, the +number of rollouts K is a natural quantity to scale. Given K independent rollouts, pass@K (any +rollout correct) is the oracle upper bound and best-Q@K (the rollout with highest q̂ is correct) is a +metric available at inference without a correctness oracle. The choice of Q as selector is motivated by +Sec. 3’s observation that Q accurately separates correct from incorrect trajectories (Figure 3). +Figure 6 shows pass@K and best-Q@K as K grows, averaged over 3 seeds on the held-out PPBench +validation set (sudoku, nurikabe, tapa, lightup, and heyawake). 
Both metrics rise from 76.4% at +K = 1 to 89.5% at K = 100, a gain of 13 percentage points. Across all tested K, the gap between +pass@K and best-Q@K stays under 1pp, making the Q head a strong verifier on this validation set. +By contrast, mode@K (most frequent answer across rollouts) rises by only 1.3pp over the same +range, showing that the width-scaling gains come mostly from the Q head’s ability to identify correct +solutions even when they are rare. +Interaction with depth scaling. Depth is another scaling axis already supported by TRM, which +consists of running more deep recursions (supervision steps) at inference than the Nsup the model +was trained on. On the deterministic baseline (K=1), tripling the depth from 16 to 48 steps raises +PPBench validation accuracy from 76.4% to 79.5% (+3.1pp). At higher K, depth scaling only +provides additional gains on specific puzzle types such as sudoku (+4pp at K = 100). Both depth +and width scaling can be seen as ways to explore the model’s solution space. Since rollouts are +independent and parallelizable while extra depth is sequential, width is the more practical scaling +axis. +PTRM unlocks a simple and task-agnostic recipe for scaling TRM test-time compute. The next +section evaluates the method across multiple benchmarks and against several baselines, including +frontier LLMs. + +5 + +Experiments + +This section evaluates PTRM’s performance on diverse reasoning benchmarks. We compare against +the deterministic TRM baseline, a non-recursive direct-prediction baseline, and frontier LLMs. +Across several PPBench puzzles [8], Sudoku-Extreme [2], Maze-Hard [2], and ARC-AGI 2 [4], +PTRM substantially boosts the performance of each pretrained TRM using only inference compute. 
+5 + +Correct (8) +Incorrect (92) +Start +End + +10 +8 + +92.5 + +PPBench accuracy (%) + +PC 2 (34% var) + +6 +4 +2 +0 + +85.0 +82.5 +80.0 +77.5 +72.5 + +2.5 + +0.0 + +2.5 +5.0 +7.5 +PC 1 (53% var) + +10.0 + +12.5 + +Figure 5: Stochastic rollouts escape bad +basins. Principal plane projection of K = +100 independent rollouts of the same failed +puzzle as in Figure 2 (right). 92 rollouts +remain caught in the bad basin (red). 8 +escape to a good basin and produce correct +answers (green). + +5.1 + +87.5 + +75.0 + +2 +4 + +pass@K +best-Q@K +mode@K + +90.0 + +1 + +5 +10 +25 +Rollouts per puzzle K (log scale) + +100 + +Figure 6: Width scaling. pass@K, best-Q@K, +and mode@K as K grows, averaged over 3 +seeds on a held-out PPBench validation set. The +Q head is a strong verifier on the tested puzzles, +consistently outperforming selection of the most +frequent answer. + +Setup + +Datasets. Pencil Puzzle Bench (PPBench) [8] consists of 62,231 constraint-satisfaction pencil puzzles +(from 94 puzzle types). From the full PPBench dataset, 300 puzzles (15 puzzles from 20 types) +selected by Waugh [8] are held out to form the golden set. From the remainder we hold out a +fixed-size validation set of 100 puzzles per puzzle type (50 for tapa, due to its smaller base size), +and the rest forms the training set. We filter all three sets to puzzles of six types (sudoku, lightup, +nurikabe, shakashaka, heyawake, and tapa) of grid size 9×9 for sudoku, and 10×10 for the rest. +We use the validation set to track performance during training and select the final checkpoint. We +report per-puzzle accuracy on five of these types on the golden set (TRM already reaches 100% on +shakashaka, so we omit it from the reported results), with aggregate scores sample-weighted across +types. We also report results on the Sudoku-Extreme, Maze-Hard, and ARC-AGI 2 datasets. +Models and inference. For each benchmark we use a standard TRM checkpoint. 
For Sudoku-Extreme we use the TRM-MLP variant (which the TRM paper showed to be stronger on Sudoku), +and for the other datasets, we use TRM-Att. PTRM inference uses K parallel rollouts each running +D supervision steps with Gaussian noise of scale σ added to the latent state at each supervision step. +The selected configuration (K, D, σ) varies by benchmark and is given alongside each result. Metrics +are averaged across three seeds. +Baselines. To isolate the contribution of PTRM’s stochastic rollouts from the underlying backbone, +we report standard TRM performance (the same checkpoint as PTRM, run deterministically). For +each dataset, we report the performance of frontier LLMs. For Sudoku-Extreme, Maze-Hard, and +ARC-AGI-2 we additionally report the published direct prediction and TRM baselines from [1]. +Cost estimation. PPBench provides the dollar cost per attempt for each LLM. We convert PTRM’s +wall-clock time to a comparable dollar figure using a single H100 at $2.50/hr (standard cloud pricing [9]) +so that cost = $2.50 · t_puzzle / 3600, where t_puzzle is the time (in seconds) to complete a puzzle. +5.2 Pencil Puzzle Bench +5.2.1 Per-puzzle accuracy + +Table 1 reports per-puzzle accuracy on the PPBench golden set. PTRM at K=100, D=48, σ=0.2 +raises aggregate best-Q@K from 62.6% to 91.2%. Increasing supervision depth alone (K=1, D=48) +gives a small boost over the standard TRM baseline (K=1, D=16). Most of the gain comes +from scaling width (stochastic rollouts). The largest improvements are on puzzle types where +6 + +the deterministic baseline performed the worst (most headroom): sudoku improves from 46.7% to +97.8% and tapa from 40.0% to 80.0%. 
+% accuracy | # Params | sudoku | lightup | nurikabe | heyawake | tapa | agg. +Direct prediction | 27M | 0.0 | 0.0 | 0.0 | 14.3 | 0.0 | 2.0 +TRM (K=1, D=16) | 7M | 46.7 | 87.5 | 74.1 | 85.7 | 40.0 | 62.6 +TRM (K=1, D=48) | 7M | 57.8 | 87.5 | 74.1 | 85.7 | 40.0 | 66.0 +PTRM, best-Q@K (K=100, D=16) | 7M | 93.3 | 100 | 88.9 | 85.7 | 80.0 | 89.8 +PTRM, best-Q@K (K=100, D=48) | 7M | 97.8 | 100 | 88.9 | 85.7 | 80.0 | 91.2 + +Table 1: PPBench per-puzzle accuracy on the golden set. PTRM uses the same backbone as +the deterministic TRM. Scaling depth alone (K=1, D=48) lifts aggregate accuracy by 3.4 points +over the standard D=16 baseline. Combining depth with K=100 stochastic (σ=0.2) rollouts raises +accuracy by 28.6 percentage points overall. The direct-prediction baseline is a larger transformer +trained on the same data. + +5.2.2 Comparison with frontier LLMs on golden set + +PPBench reported per-puzzle results for several frontier LLMs using two strategies: 1) direct response +from a single prompt, and 2) a multi-turn agentic strategy with verification. We report results for direct +and any (best of any strategy attempted, including agentic). The agentic strategy gives the LLM +substantially more resources than PTRM has access to. It provides the LLM the ability to iteratively +verify each move with a perfect verifier. The direct strategy is the fairer comparison since, while +it may use the model provider’s reasoning harness, it does not have direct access to a multi-turn +verifier (the LLM could still self-verify by writing verification code within the same response). We +additionally observe that the agentic strategy was applied selectively in the published PPBench data: +across the LLMs we compare against, only 9.6% of direct failures on the golden set were retried +with agentic. 
We restrict the comparison to the 7 strongest LLMs that attempted every puzzle in our +golden set: claude-opus-4-6@thinking, gpt-5.2@xhigh, gemini-3.1-pro, gpt-5.2@high, +claude-sonnet-4-6@thinking, gpt-5.2@medium, and kimi-k2.5. Table 2 lists the top 3 in +each strategy block. +We additionally report an ensemble score formed from these 7 LLMs where a puzzle counts as solved +if at least one of them solved it via any strategy. This ensemble setup is deliberately stacked against +PTRM. It assumes a perfect verifier since, if any of the 7 LLMs produced a correct answer under +any strategy, the ensemble counts it as solved, even though in practice we would not have access +to an oracle verifier. Although it is not deployable, we include the ensemble to demonstrate that +even under these heavily favorable conditions, frontier LLMs fall well short of PTRM. Ensemble +cost-per-attempt averages over the attempts of all 7 models on each puzzle, and cost-per-correct +divides total cost by the number of puzzles the ensemble solved. +Table 2 reports the comparison. PTRM exceeds the strongest single LLM (direct strategy) by 57 +points aggregate (91.2% vs. 34.7%), and exceeds the LLM ensemble by 36 points (91.2% vs. 55.1%) +despite the ensemble’s stacked advantages. Cost per attempt is several orders of magnitude higher for +LLMs than PTRM. +5.3 + +Sudoku-Extreme, Maze-Hard, and ARC-AGI-2 + +For each benchmark we use the standard TRM checkpoint trained as described in [1] without +modification (TRM-MLP for Sudoku-Extreme and TRM-Att for Maze-Hard and ARC-AGI-2). +Table 3 summarizes results on all three. +On Sudoku-Extreme, PTRM at K=100, D=64, σ=0.3 raises the deterministic baseline of 87.3% to +99.06% pass@K and 98.75% best-Q@K, achieving state of the art. +On Maze-Hard, PTRM at K=100, D=16, σ=1.0 reaches 95.63% pass@K, an 11.83 point gain +over the 83.8% deterministic baseline. 
mode@K gives the best PTRM accuracy here at 86.73% +(+2.93 points), with best-Q@K slightly behind at 85.17% (+1.37 points). While pass@K shows +that PTRM is able to unlock several correct answers, the Q head identifies them less reliably than on +the previous benchmarks. +7 + +% accuracy + +tapa + +agg. + +$/att. + +$/corr. + +30.0 +50.0 +60.0 + +24.5 +24.5 +34.7 + +$0.40 +$1.79 +$2.91 + +$1.62 +$7.29 +$8.40 + +0.0 +0.0 +0.0 + +40.0 +60.0 +70.0 + +30.6 +34.7 +36.7 + +$10.38 +$3.09 +$4.38 + +$33.91 +$8.90 +$11.92 + +0.0 + +80.0 + +55.1 + +$2.66 + +$38.51 + +sudoku lightup nurikabe heyawake +Direct + +gemini-3.1-pro +gpt-5.2@xhigh +claude-opus-4-6@thinking + +6.7 +20.0 +0.0 + +75.0 +50.0 +87.5 + +22.2 +0.0 +44.4 + +0.0 +0.0 +0.0 + +Any strategy (direct or agentic)† +gemini-3.1-pro +gpt-5.2@xhigh +claude-opus-4-6@thinking + +6.7 +33.3 +0.0 + +87.5 +75.0 +87.5 + +33.3 +0.0 +44.4 + +LLM ensemble† +Any strategy (direct or agentic) + +46.7 + +100 + +44.4 + +Ours, trained from scratch, 7M parameters +PTRM, best-Q@K + +97.8 + +100 + +88.9 + +85.7 + +80.0 91.2 $0.001 $0.001 + +Table 2: PTRM vs. frontier LLMs on PPBench golden. Per-puzzle accuracy and per-attempt / +per-correct cost on the golden set. LLM costs are from PPBench. PTRM cost is estimated from H100 +wall-clock (Sec. 5.1). The direct and agentic blocks list the 3 highest scoring LLMs on aggregate, +and the ensemble row uses all 7 listed in Sec. 5.2.2. † Assumes access to a perfect verifier. + +On ARC-AGI-2, the standard inference pipeline applies data augmentations and votes across them. +PTRM adds K stochastic rollouts per augmentation. For selection, we pick the rollout with the +highest Q value within each augmentation, then vote across augmentations as in the standard pipeline. +With K=25 and σ=0.2, PTRM lifts pass@1 from 7.36% to 8.47% and pass@100 from 14.31% to +15.97% over our deterministic TRM baseline, while matching it at pass@2. + +Sudoku-Extreme Maze-Hard +ARC-AGI-2 +Acc. (%) +Acc. 
(%) pass@1 pass@2 pass@100 + +Method + +# Params + +HRM +TRM + +27M +5M / 7M† + +55.0 +87.4 + +74.5 +85.3 + +– +– + +5.0 +7.8 + +– +– + +Ours +Standard TRM, our reproduction 5M / 7M† +PTRM +5M / 7M† + +87.28 +98.75 + +83.80 +86.73 + +7.36 +8.47 + +9.72 +9.72 + +14.31 +15.97 + +Table 3: Sudoku-Extreme, Maze-Hard, and ARC-AGI-2 results. For Sudoku-Extreme, K=100, +D=64, σ=0.3. For Maze-Hard, K=100, D=16, σ=1.0. For ARC-AGI-2, K=25, D=16, σ=0.2. +pass@k for ARC-AGI-2 reports the top-k predictions from the augmentation-voting pipeline. PTRM +shows an accuracy improvement over standard TRM across all 3 benchmarks. † Following [1], 5M +for Sudoku-Extreme (TRM-MLP), 7M for Maze-Hard and ARC-AGI-2 (TRM-Att). + +5.4 + +Q head selection as σ grows + +With a higher σ value, PTRM finds many correct solutions that the deterministic inference misses. +For instance, on Maze-Hard, the deterministic model solves 83.8% of puzzles, but PTRM raises +pass@K to nearly 96%. The extent to which PTRM helps depends on the task, but on every dataset +we tested, it unlocks correct solutions well beyond the deterministic model’s reach. +TRM’s jointly trained Q head serves as a strong verifier on most tasks. On PPBench and SudokuExtreme, best-Q@K reaches values within a point of the saturated pass@K, so PTRM’s exploration +translates directly into accuracy gains. On Maze-Hard, more exploration (higher σ) produces +significantly more correct rollouts, but the existing Q head is not able to identify them, leaving +performance on the table. The gap between best-Q@K and pass@K represents headroom for a +stronger verifier which is left for future work. Appendix B reports the full σ sweep. +8 + +6 + +Related Work + +A long line of work explores recursive computation for iterative reasoning and representation refinement. 
Early examples include Universal Transformers [10], Mixture-of-Recursions [11], Deep +Thinking models [12, 13, 14], and HRM [2], all of which investigate the use of repeated computation +steps to improve reasoning performance. More recent work has introduced methods to substantially +accelerate TRM training [15], while TRM-style recursive architectures have also been extended to +language modeling tasks [16]. +Building on this broader perspective of recursive computation, a growing body of work studies +latent-space reasoning through the reuse of hidden states. Hao et al. [17] propose continuous +“thinking tokens” derived from Chain-of-Thought (CoT) traces [18], which are autoregressively +generated and appended to the model context, enabling reasoning directly in latent space without +producing intermediate textual outputs. Similarly, Zhu et al. [19] formalize learning by superposition +and demonstrate improvements on tasks such as graph reachability. By avoiding explicit token +sampling and implicitly representing multiple reasoning trajectories, these approaches may mitigate +the unfaithfulness and backtracking often observed in standard autoregressive reasoning [20, 21]. +Related to our work, Baek et al. [22] propose a generative version of TRM where the hidden state +z is sampled instead of deterministic. This improves performance on multiple tasks, but requires +retraining. Efstathiou and Balwani [23] (concurrent work) propose a similar test-time compute +method where they only apply noise in the initial hidden state z, while we apply noise at every +supervision step. Furthermore, they test their method on a small subset of the Sudoku-Extreme +dataset, and treat it as a proof-of-concept that needs to be developed and tested further. Note that +Baek et al. [22] also tested applying noise to the initial z with TRM and obtained negative results (no +improvement in accuracy on two datasets). +Our observations in Sec. 
3 are consistent with the mechanistic analysis of Ren and Liu [5], who +identify spurious fixed points in HRM’s latent dynamics on Sudoku-Extreme. Their method mitigates +these attractors through a combination of task-specific training data augmentation, inference-time +input perturbations, and model bootstrapping across training checkpoints, thereby effectively increasing test-time compute. However, these interventions are comparatively less general and less +computationally efficient. In contrast, we observe analogous basin structure in TRM across multiple +puzzle types and achieve attractor escape using a substantially simpler, task-agnostic mechanism: +injecting Gaussian noise into the latent state at each supervision step while using a single deterministic +checkpoint. + +7 + +Conclusion + +In this work, we introduced Probabilistic TRM (PTRM), a novel test-time scaling paradigm for +Tiny Recursive Models (TRM) through parallel exploration and selection. This approach scales +test-time compute using width (K parallel rollouts), yielding substantially larger gains than depth +scaling (increasing deep recursion steps) alone. PTRM requires no retraining and does not rely on +task-specific data augmentations, making it extremely easy to use and versatile. +By scaling both width and depth, PTRM obtains significant gains in accuracy when tested on a wide +selection of puzzles. On PPBench (Sudoku, Lightup, Nurikabe, Heyawake, Tapa puzzles), PTRM +obtains nearly twice the accuracy (91.2%; $0.001 cost) of an ensemble of SOTA LLMs (55.1%; $38.51 +cost) at less than 0.0001x the cost. Furthermore, PTRM improves accuracy on Sudoku-Extreme (from 87.4% +to 98.75%), Maze-Hard (from 83.80% to 86.73%), and ARC-AGI-2 (from 7.8% to 8.47% pass@1). +Limitations. Our experiments focus on reasoning puzzles rather than general tasks. We only test +on a subset of PPBench puzzles. We are limited to puzzles with a small grid size due to limited +computational resources.
It is not guaranteed that the method works as well for all types of problems +(e.g., accuracy gains on ARC-AGI-2 and Heyawake are smaller). +Future work. It would be interesting to understand why some puzzles benefit from test-time scaling +more than others. We suspect that problems that are harder to verify (e.g., ARC-AGI-2) benefit less +from PTRM because the Q head may struggle to distinguish correct solutions from incorrect ones. +Developing stronger verifiers than the existing Q head is an interesting direction for future work. +9 + +References +[1] Alexia Jolicoeur-Martineau. Less is more: Recursive reasoning with tiny networks. arXiv +preprint arXiv:2510.04871, 2025. +[2] Guan Wang, Jin Li, Yuhao Sun, Xing Chen, Changling Liu, Yue Wu, Meng Lu, Sen Song, and +Yasin Abbasi Yadkori. Hierarchical reasoning model. arXiv preprint arXiv:2506.21734, 2025. +[3] François Chollet. On the measure of intelligence. arXiv preprint arXiv:1911.01547, 2019. +[4] Francois Chollet, Mike Knoop, Gregory Kamradt, Bryan Landers, and Henry Pinkard. Arcagi-2: A new challenge for frontier ai reasoning systems. arXiv preprint arXiv:2505.11831, +2025. +[5] Zirui Ren and Ziming Liu. Are your reasoning models reasoning or guessing? a mechanistic +analysis of hierarchical reasoning models. arXiv preprint arXiv:2601.10679, 2026. +[6] Charlie Snell, Jaehoon Lee, Kelvin Xu, and Aviral Kumar. Scaling llm test-time compute optimally can be more effective than scaling model parameters. arXiv preprint arXiv:2408.03314, +2024. +[7] Alex Graves. Adaptive computation time for recurrent neural networks. arXiv preprint +arXiv:1603.08983, 2016. +[8] Justin Waugh. Pencil puzzle bench: A benchmark for multi-step verifiable reasoning. arXiv +preprint arXiv:2603.02119, 2026. +[9] Vast.ai. Rent h100 pcie gpus on vast.ai. https://vast.ai/pricing/gpu/H100-PCIE, 2026. +Accessed: 2026-05-01. +[10] Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Łukasz Kaiser. Universal transformers. 
arXiv preprint arXiv:1807.03819, 2018. +[11] Sangmin Bae, Yujin Kim, Reza Bayat, Sungnyun Kim, Jiyoun Ha, Tal Schuster, Adam Fisch, +Hrayr Harutyunyan, Ziwei Ji, Aaron Courville, et al. Mixture-of-recursions: Learning dynamic +recursive depths for adaptive token-level computation. arXiv preprint arXiv:2507.10524, 2025. +[12] Avi Schwarzschild, Eitan Borgnia, Arjun Gupta, Furong Huang, Uzi Vishkin, Micah Goldblum, +and Tom Goldstein. Can you learn an algorithm? generalizing from easy to hard problems with +recurrent networks. Advances in Neural Information Processing Systems, 34:6695–6706, 2021. +[13] Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang, Micah Goldblum, +and Tom Goldstein. End-to-end algorithm synthesis with recurrent networks: Extrapolation +without overthinking. Advances in Neural Information Processing Systems, 35:20232–20242, +2022. +[14] Jay Bear, Adam Prugel-Bennett, and Jonathon Hare. Rethinking deep thinking: Stable learning +of algorithms using lipschitz constraints. Advances in Neural Information Processing Systems, +37:97027–97052, 2024. +[15] Navid Hakimi. Form follows function: Recursive stem model. arXiv preprint arXiv:2603.15641, +2026. +[16] Yinxi Li, Jiaao Chen, Fang Wu, Jiakai Yu, Heli Qi, Weihao Xuan, Haokai Zhao, Pengyu Nie, +Di Jin, and Xiangru Tang. Learning multi-step reasoning via persistent latent state propagation. +In Workshop on Latent & Implicit Thinking – Going Beyond CoT Reasoning, +2026. +[17] Shibo Hao, Sainbayar Sukhbaatar, DiJia Su, Xian Li, Zhiting Hu, Jason Weston, and Yuandong +Tian. Training large language models to reason in a continuous latent space. arXiv preprint +arXiv:2412.06769, 2024. +[18] Jason Wei, Xuezhi Wang, Dale Schuurmans, Maarten Bosma, Fei Xia, Ed Chi, Quoc V Le, +Denny Zhou, et al. Chain-of-thought prompting elicits reasoning in large language models. +Advances in Neural Information Processing Systems, 35:24824–24837, 2022.
+10 + +[19] Hanlin Zhu, Shibo Hao, Zhiting Hu, Jiantao Jiao, Stuart Russell, and Yuandong Tian. Reasoning +by superposition: A theoretical perspective on chain of continuous thought. arXiv preprint +arXiv:2505.12514, 2025. +[20] Tamera Lanham, Anna Chen, Ansh Radhakrishnan, Benoit Steiner, Carson Denison, Danny +Hernandez, Dustin Li, Esin Durmus, Evan Hubinger, Jackson Kernion, et al. Measuring +faithfulness in chain-of-thought reasoning. arXiv preprint arXiv:2307.13702, 2023. +[21] Yanda Chen, Joe Benton, Ansh Radhakrishnan, Jonathan Uesato, Carson Denison, John Schulman, Arushi Somani, Peter Hase, Misha Wagner, Fabien Roger, et al. Reasoning models don’t +always say what they think. arXiv preprint arXiv:2505.05410, 2025. +[22] Junyeob Baek, Mingyu Jo, Minsu Kim, Yoshua Bengio, and Sungjin Ahn. Generative recursive +reasoning models. ICLR 2026 Workshop on AI with Recursive Self-Improvement, 2026. +[23] Andreas Efstathiou and Aishwarya Balwani. Recursive reasoning as attractor landscape search: +Mechanistic dynamics of the tiny recursive model. Workshop on Latent & Implicit Thinking – Going Beyond CoT Reasoning, 2026. URL https://openreview.net/forum?id=kKps9W1K7n. + +11 + +A + +Implementation Details + +A.1 + +Compute + +We train and evaluate all models on a single NVIDIA H100 80GB GPU. PTRM introduces no +additional training cost over standard TRM since it operates entirely at inference time. +A.2 + +Models + +All experiments use the standard TRM backbone [1] with the released architecture and training recipes. +Following the TRM paper, we use the MLP variant (TRM-MLP, 5M parameters) for Sudoku-Extreme +and the attention variant (TRM-Att, 7M parameters) for Maze-Hard, ARC-AGI-2, and PPBench. +Layout and hyperparameters are unchanged from TRM. +A.3 + +PPBench dataset construction + +Sudoku-Extreme, Maze-Hard, and ARC-AGI-2 use the same checkpoints and data splits as TRM.
+The PPBench dataset is more recent and has previously been used only with frontier LLMs, so we +detail how we built our training, validation, and golden splits. +Source. PPBench contains 62,231 constraint-satisfaction pencil puzzles spanning 94 puzzle types. +Of these, 300 puzzles (15 puzzles × 20 types) are held out as the golden benchmark set by Waugh [8]. +Filtering. From the remaining 61,931 puzzles we hold out a validation set by sampling 100 puzzles +from each puzzle type (50 for tapa, due to its smaller base size), and the rest forms the training +set. We then filter all three sets (training, validation, golden) to retain only puzzles of six types +(sudoku, lightup, nurikabe, shakashaka, heyawake, tapa) at fixed grid sizes: 9×9 for sudoku +and 10×10 for the others. Sudoku grids are padded with a pad token to 10×10, giving a uniform +sequence length of seq_len = 100 across all six puzzle types. The deterministic TRM baseline +reaches 100% accuracy on shakashaka, so we exclude it from per-puzzle accuracy reporting (no +headroom to compare against PTRM). +Augmentation. Each training puzzle is expanded into 10 examples using two augmentations: 1) +trajectory sampling, where the input is set to a random intermediate solve state along the puzzle’s +solution trajectory rather than always the empty initial grid, while the label is always the fully solved +grid; and 2) dihedral transformation, where a random dihedral transformation of a square grid, among +the 8 possibilities given by 4 rotations × 2 {identity, reflection}, is applied to both the input and the +label. For each puzzle, the first example is the unaugmented (initial state, solved) pair. The remaining +9 are randomly sampled (trajectory and dihedral transform). Validation and golden splits are not +augmented. +Resulting splits. The merged multi-type splits use a unified vocabulary of 294 tokens and seq_len = +100. Per-type sample counts are reported in Table 4. 
+puzzle type + +train + +val + +golden + +sudoku +lightup +nurikabe +heyawake +tapa +shakashaka∗ + +7,810 +9,504 +15,180 +42,108 +3,663 +20,702 + +97 +65 +55 +70 +26 +62 + +15 +8 +9 +7 +10 +12 + +total + +98,967 + +375 + +61 + +Table 4: Per-puzzle-type sample counts in the PPBench splits used in training and evaluation. +∗ +Shakashaka is included in training but excluded from per-puzzle accuracy reporting because deterministic TRM already solves all evaluated shakashaka puzzles. + +12 + +B + +Noise Ablation + +We ablate the inference noise level σ on three benchmarks at K=25 (K=100 for Maze-Hard) and +D=16 to keep the sweep tractable. For Sudoku-Extreme we randomly sample 1000 puzzles from the +test set for the same reason. Figure 7 shows pass@K, best-Q@K, and mode@K as a function of σ, +averaged over three random seeds. +pass@K + +Sudoku-Extreme + +100 + +mode@K + +K = 1 baseline + +Maze-Hard + +ARC-AGI-2 (within-aug) +5.5 + +96 + +90 + +accuracy (%) + +best-Q@K + +94 + +80 + +5.0 + +92 + +70 + +90 + +60 + +88 + +50 + +86 + +40 + +84 + +30 +0.0 + +0.2 + +0.4 + +0.6 + +0.8 + +1.0 + +82 +0.0 + +4.5 +4.0 +3.5 +0.2 + +0.4 + +0.6 + +0.8 + +1.0 + +0.0 + +0.2 + +0.4 + +0.6 + +0.8 + +1.0 + +Figure 7: pass@K, best-Q@K, and mode@K across σ per rollout batch. On every task, +increasing the inference noise consistently produces more correct rollouts (pass@K, blue) up to +a task-dependent σ value. The Q head (best-Q@K, orange) tracks the pass@K ceiling closely +on Sudoku-Extreme and leaves a larger gap on Maze-Hard and ARC-AGI-2. The shaded region +represents the verifier headroom (accuracy that a better verifier could extract). mode@K (green) has +the edge over the Q head only on Maze-Hard. For ARC-AGI-2, metrics are per puzzle/augmentation +to isolate the Q head’s verification abilities from the augmentation pipeline. +On Maze-Hard pass@K climbs from 83.8% (deterministic) to nearly 96% by σ≈1.0 and then +plateaus. 
On Sudoku-Extreme it is already near its ceiling at σ=0.1 and stays roughly flat across the +sweep. On ARC-AGI-2 it peaks near σ=0.6 before declining. Q head selection nearly matches the +ceiling (maximum pass@K) on Sudoku-Extreme while best-Q@K peaks at 98.5% (within a point of +pass@K’s peak of 99.3%). On the other hand, the gap between best-Q@K and maximum pass@K +is more pronounced on Maze-Hard and ARC-AGI-2 (headroom a stronger verifier could close). + +C + +Q-guided Langevin sampling + +We initially explored Langevin sampling (using the Q head gradient) as a more principled exploration +mechanism than the Gaussian noise injection used in PTRM. The idea is to better guide the stochastic +search by additionally steering each rollout (using the Q head gradient) toward regions of high Q +value. We ultimately found that the gain from this approach was entirely attributable to the Langevin +noise term, with the gradient component contributing nothing measurable on top of the equivalent +recurrent noise of Sec. 4. We document the approach here as a negative result. +Motivation. The Q head is trained as a correctness predictor over latent states. Let fQ (z) denote +the head’s scalar output. We treated E(z) = − log sigmoid(fQ (z)) as an energy function over latent +space. Empirical observations during early experiments suggested that regions of low E correspond +to good basins from which the decoded answer is likely correct. PCA visualizations of the latent +dynamics showed that ∇z fQ points toward the good-basin region from both good-basin (correct) and +bad-basin (incorrect) latents (Figure 8). This made ∇z fQ look like a valuable direction along which +to push latents. +Method. We sample from the target distribution p(z) ∝ e−E(z) = sigmoid(fQ (z)) via Langevin +dynamics where at the end of each deep recursion step t = 1, . . . 
, D we apply N Langevin steps to +the latent, +$$z \leftarrow z - \eta\,\nabla_z E(z) + \sqrt{2\eta}\,\xi, \qquad \xi \sim \mathcal{N}(0, I).$$ +The number of Langevin steps N is the additional scaling axis under this scheme. +13 + +t=0 + +t=5 + +t = 10 + +t = 15 + +Correct (21) +Incorrect (4) +Q + +Figure 8: y latents and their ∇z fQ gradients projected into the principal plane at several recursive/supervision steps, for multiple rollouts (using recurrent noise) of a single puzzle (correct rollouts +in green, incorrect in red). Arrows are drawn at each latent in the direction of ∇z fQ . From both +good-basin and bad-basin latents, gradients point toward the good-basin region. This visualization +motivated the Langevin sampling experiment described below. +Tractable gradient computation. TRM’s original Q head is a linear projection on a single token, +$f_Q(y) = w^\top y[:, 0] + b$, so its gradient with respect to this head’s input is a constant vector independent +of z. For ∇z fQ to be input-dependent, the gradient must flow back through the last latent recursion. +This works but requires backpropagating through a full latent recursion at every Langevin step, which +scales poorly with N. To make guidance tractable for large N, we replaced the linear Q head with +an attention-pooled variant that reads the full latent and produces a scalar through a small nonlinear +network. With this head, ∇z fQ can be computed by backpropagating through the head alone, which +is ∼8× faster per step and does not sacrifice accuracy. +The gain came from the noise, not the gradient. Comparing Langevin sampling against a noise-only ablation (with the same $\sqrt{2\eta}\,\xi$, but with the $-\eta\,\nabla_z E(z)$ term zeroed out) produced essentially +identical accuracy at matched N. The gradient component contributed nothing measurable on +top of the equivalent recurrent noise. This prompted us to focus on the noise-only formulation in +Sec.
4, which is much more impactful since it is: 1) significantly simpler (no retraining, no test-time +backpropagation), 2) applicable to any TRM checkpoint out of the box, and 3) equally effective. + +D + +Per-puzzle accuracy on the PPBench validation set + +The main paper reports per-puzzle accuracy on the PPBench golden set (Table 1) for direct comparability with the LLM evaluations from Waugh [8] who used that set. For a lower-variance complement, +Table 5 reports results on our validation set (313 puzzles across the five reported types vs. 49 for +golden). Trends match the golden-set results: depth scaling alone (K=1, D=48) provides a small lift, +and combining depth with stochastic rollouts (K=100, D=48, σ=0.2) raises aggregate best-Q@K +from 76.4% to 90.4%, a 14.0 percentage-point improvement. The biggest gains again are on puzzles +where the deterministic baseline has the most headroom (tapa ∼ 40% to 71.8%, sudoku ∼ 69% +to 93.3%). Types where the baseline is already near ceiling (heyawake at 96.7%) increase only +marginally. +% accuracy +Direct prediction +TRM (K=1, D=16) +TRM (K=1, D=48) +PTRM, best-Q@K (K=100, D=48) + +# Params sudoku lightup nurikabe heyawake +27M +7M +7M +7M + +0.0 +68.7 +74.0 +93.3 + +10.0 +83.3 +84.0 +93.3 + +4.0 +76.0 +76.7 +84.7 + +14.0 +96.7 +98.0 +100 + +tapa + +agg. + +0.0 +6.2 +39.7 76.4 +41.0 78.3 +71.8 90.4 + +Table 5: PPBench per-puzzle accuracy on the validation set. PTRM uses the same backbone as the +deterministic TRM. Results on the larger validation set follow the same trends as on the golden set. + +14 + +
\ No newline at end of file diff --git a/papers/trm_2510.04871.pdf b/papers/trm_2510.04871.pdf Binary files differnew file mode 100644 index 0000000..4307790 --- /dev/null +++ b/papers/trm_2510.04871.pdf |
