summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-17 11:19:27 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-17 11:19:27 -0500
commitd12722525fc010a3910b5152c72654a2ade5eac4 (patch)
treeadbfd19bef487e1c959856d9bf7f881709684a7e
Initial import
-rw-r--r--.gitignore10
-rw-r--r--diag/__init__.py0
-rw-r--r--diag/aggregate.py54
-rw-r--r--diag/cin_color.py144
-rw-r--r--diag/datasets.py70
-rw-r--r--diag/esan_color.py145
-rw-r--r--diag/ff_color.py75
-rw-r--r--diag/lyap.py83
-rw-r--r--diag/models.py42
-rw-r--r--diag/peptides_depth.py66
-rw-r--r--diag/ppgn_color.py209
-rw-r--r--diag/ptrm_color.py85
-rw-r--r--diag/run_archA.sh16
-rw-r--r--diag/run_archB.sh21
-rw-r--r--diag/run_cin.sh8
-rw-r--r--diag/run_color.sh15
-rw-r--r--diag/run_cycle.sh13
-rw-r--r--diag/run_diag.sh16
-rw-r--r--diag/run_esan.sh8
-rw-r--r--diag/run_ff.sh8
-rw-r--r--diag/run_le.sh12
-rw-r--r--diag/run_pe.sh24
-rw-r--r--diag/run_pe2.sh18
-rw-r--r--diag/run_pe3.sh23
-rw-r--r--diag/run_pna.sh13
-rw-r--r--diag/run_ppgn.sh8
-rw-r--r--diag/run_real.sh12
-rw-r--r--diag/run_rec.sh13
-rw-r--r--diag/run_seeds.sh22
-rw-r--r--diag/selftest_wl.py53
-rw-r--r--diag/train_color.py347
-rw-r--r--diag/train_cycle.py183
-rw-r--r--diag/train_diag.py161
-rw-r--r--diag/train_real.py139
-rw-r--r--diag/train_rec.py174
-rw-r--r--diag/wl.py166
-rw-r--r--papers/gram_2605.19376.pdfbin0 -> 5136087 bytes
-rw-r--r--papers/hrm_2506.21734.pdfbin0 -> 2268976 bytes
-rw-r--r--papers/ptrm_2605.19943.pdfbin0 -> 635907 bytes
-rw-r--r--papers/ptrm_2605.19943.txt1418
-rw-r--r--papers/trm_2510.04871.pdfbin0 -> 427299 bytes
41 files changed, 3874 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1a9b622
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+__pycache__/
+*.py[cod]
+
+.venv/
+venv/
+
+data/
+runs/
+
+.DS_Store
diff --git a/diag/__init__.py b/diag/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/diag/__init__.py
diff --git a/diag/aggregate.py b/diag/aggregate.py
new file mode 100644
index 0000000..b0f737a
--- /dev/null
+++ b/diag/aggregate.py
@@ -0,0 +1,54 @@
+"""Aggregate multi-seed coloring results -> mean+/-std per (grad_mode, pe, contract)."""
+import glob, json
+import numpy as np
+from collections import defaultdict
+
+R = '/home/yurenh2/rrog/runs'
+
+
+def ms(xs):
+ a = np.array([x for x in xs if x is not None], dtype=float)
+ return f"{a.mean():.3f}±{a.std():.3f} (n={len(a)})" if len(a) else "—"
+
+
+def load(pat):
+ out = []
+ for f in glob.glob(pat):
+ try:
+ out.append(json.load(open(f)))
+ except Exception:
+ pass
+ return out
+
+
+def key(d):
+ return (d.get('conv', 'gin'), d.get('pe'), d.get('grad_mode'), 'ctr' if d.get('contract') else '-')
+
+
+solve, le, ml = defaultdict(list), defaultdict(list), defaultdict(list)
+pk, ls, au, det = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
+
+for d in load(f"{R}/color_*.json"):
+ if 'solve_rate' in d and d.get('pe') is not None:
+ solve[key(d)].append(d['solve_rate'])
+for d in load(f"{R}/le_color_*.json"):
+ le[key(d)].append(d.get('auroc'))
+ ml[key(d)].append(d.get('mean_lam'))
+for d in load(f"{R}/ptrm_color_*.json"):
+ k = key(d); det[k].append(d.get('det'))
+ s2 = d.get('sigmas', {}).get('0.2')
+ if s2:
+ pk[k].append(s2.get('passk')); ls[k].append(s2.get('lamsel')); au[k].append(s2.get('auroc'))
+
+print("=== best solve_rate (deterministic, EMA) ===")
+for k in sorted(solve, key=str):
+ print(f" {k}: {ms(solve[k])}")
+print("=== LE AUROC(fail|lambda1) ===")
+for k in sorted(le, key=str):
+ print(f" {k}: {ms(le[k])}")
+print("=== LE mean_lambda1 (forced-contraction dose) ===")
+for k in sorted(ml, key=str):
+ print(f" {k}: {ms(ml[k])}")
+print("=== PTRM sigma=0.2 ===")
+for k in sorted(pk, key=str):
+ print(f" {k}: det {ms(det[k])} | pass@K {ms(pk[k])} | lambda-sel {ms(ls[k])} | AUROC {ms(au[k])}")
diff --git a/diag/cin_color.py b/diag/cin_color.py
new file mode 100644
index 0000000..324215f
--- /dev/null
+++ b/diag/cin_color.py
@@ -0,0 +1,144 @@
+"""H4: Recursive CIN-lite (topological / cell-complex) backbone on graph 3-coloring.
+
+Augment each graph with ring 2-cells from the cycle basis: add one hypernode per basis cycle,
+connected to its member nodes. Messages flow node->ring->node (topological message passing over
+rings). Run the shared recursive GIN on the augmented (nodes + ring-cells) graph; decode colors
+on the ORIGINAL nodes only. Self-contained: train (EMA, best solve) + LE + PTRM; writes
+color_/le_/ptrm_ JSON (conv='cin') for diag/aggregate.py.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/cin_color.py --grad_mode full --seed 0
+"""
+import argparse, json, os, time
+import numpy as np
+import networkx as nx
+import torch
+import torch.nn.functional as F
+from torch_geometric.data import Data, Batch
+from diag.train_color import make_split, featurize, RecGINColor, lyap1, OUT
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+N = 50
+
+
+def dense_A(edge_index, n):
+ A = torch.zeros(n, n)
+ if edge_index.shape[1]:
+ A[edge_index[0], edge_index[1]] = 1.0
+ return torch.maximum(A, A.t())
+
+
+def augment(g):
+ n = g['n']; ei = g['edge_index']
+ G = nx.Graph(); G.add_nodes_from(range(n))
+ if ei.shape[1]:
+ G.add_edges_from(ei.t().tolist())
+ rings = nx.cycle_basis(G)
+ R = len(rings)
+ src, dst = ei[0].tolist(), ei[1].tolist()
+ for r, cyc in enumerate(rings):
+ rn = n + r
+ for node in cyc:
+ src += [node, rn]; dst += [rn, node]
+ aug_ei = torch.tensor([src, dst], dtype=torch.long) if src else torch.zeros((2, 0), dtype=torch.long)
+ d = g['rfeat'].shape[1]
+ nf = torch.cat([g['rfeat'], torch.tensor([[1.0, 0.0]]).repeat(n, 1)], dim=1)
+ rf = torch.cat([torch.zeros(R, d), torch.tensor([[0.0, 1.0]]).repeat(R, 1)], dim=1) if R else torch.zeros(0, d + 2)
+ return Data(x=torch.cat([nf, rf], dim=0), edge_index=aug_ei, num_nodes=n + R)
+
+
+def conf_of(logits, A):
+ col = logits.argmax(-1)
+ return int(((col.unsqueeze(0) == col.unsqueeze(1)) & (A > 0)).sum().item() // 2)
+
+
+@torch.no_grad()
+def solve_rate(model, graphs, dev, sample=300):
+ model.eval(); solved = 0
+ for g in graphs[:sample]:
+ d = g['aug'].to(dev)
+ lg = model(d.x, d.edge_index)[-1][:N]
+ solved += int(conf_of(lg, g['A'].to(dev)) == 0)
+ return solved / len(graphs[:sample])
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+ ap.add_argument('--epochs', type=int, default=150); ap.add_argument('--M', type=int, default=16)
+ ap.add_argument('--seed', type=int, default=0); ap.add_argument('--sigma', type=float, default=0.2)
+ ap.add_argument('--K', type=int, default=16); ap.add_argument('--n_graphs', type=int, default=150)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ rng = np.random.default_rng(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
+ te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
+ for g in tr + te:
+ g['A'] = dense_A(g['edge_index'], g['n']); g['aug'] = augment(g)
+ in_dim = tr[0]['rfeat'].shape[1] + 2
+ model = RecGINColor(in_dim, 128, 3, grad_mode=args.grad_mode, conv='gin').to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+
+ t0 = time.time(); best = -1; best_state = None
+ for ep in range(args.epochs):
+ model.train(); order = rng.permutation(len(tr))
+ for s0 in range(0, len(order) - args.M, args.M):
+ sel = order[s0:s0 + args.M]
+ b = Batch.from_data_list([tr[i]['aug'] for i in sel]).to(dev)
+ opt.zero_grad()
+ logits = model(b.x, b.edge_index, b.batch)[-1]
+ orig = torch.stack([logits[b.ptr[gi]:b.ptr[gi] + N] for gi in range(len(sel))]) # [M,N,k]
+ A = torch.stack([tr[i]['A'] for i in sel]).to(dev)
+ p = F.softmax(orig, -1)
+ loss = (torch.einsum('bik,bjk->bij', p, p) * A).sum() / (A.sum() + 1e-9)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ with torch.no_grad():
+ for kk, v in model.state_dict().items():
+ ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v)
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ model.load_state_dict(ema); sr = solve_rate(model, te, dev)
+ if sr > best:
+ best = sr; best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
+ model.load_state_dict(bk)
+ print(f"ep{ep+1} solve={sr:.3f}", flush=True)
+ model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state}); model.eval()
+
+ nstep = model.n_sup * model.T
+ lams, fails = [], []; passk = lamsel = rand = 0; Lr, Sr = [], []
+ for gi, g in enumerate(te[:args.n_graphs]):
+ d = g['aug'].to(dev); A = g['A'].to(dev)
+ lam0 = lyap1(model, d.x, d.edge_index, nstep, dev, seed=gi)
+ c0 = conf_of(model(d.x, d.edge_index)[-1][:N], A)
+ lams.append(lam0); fails.append(int(c0 > 0))
+ confs, rl = [], []
+ for j in range(args.K):
+ confs.append(conf_of(model(d.x, d.edge_index, noise=True)[-1][:N], A))
+ rl.append(lyap1(model, d.x, d.edge_index, nstep, dev, seed=1000 * gi + j))
+ confs, rl = np.array(confs), np.array(rl); sv = confs == 0
+ passk += int(sv.any()); lamsel += int(sv[rl.argmin()]); rand += int(sv[0])
+ Lr += rl.tolist(); Sr += sv.tolist()
+ lams, fails = np.array(lams), np.array(fails); s, f = lams[fails == 0], lams[fails == 1]
+ auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
+ n = len(te[:args.n_graphs])
+ print(f"[cin/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} "
+ f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")
+ base = f"cin_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
+ com = {'conv': 'cin', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}
+ json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w'))
+ json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())}, open(os.path.join(OUT, f"le_color_{base}.json"), 'w'))
+ json.dump({**com, 'det': 1 - float(fails.mean()),
+ 'sigmas': {'0.2': {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n}}},
+ open(os.path.join(OUT, f"ptrm_color_{base}.json"), 'w'))
+ print(" wrote", base)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/datasets.py b/diag/datasets.py
new file mode 100644
index 0000000..cbd236e
--- /dev/null
+++ b/diag/datasets.py
@@ -0,0 +1,70 @@
+"""Synthetic graph datasets for the 1-WL diagnosis."""
+import numpy as np
+import networkx as nx
+
+
+def _nx_to_edge_index(G):
+ G = nx.convert_node_labels_to_integers(G)
+ n = G.number_of_nodes()
+ if G.number_of_edges() == 0:
+ return n, np.zeros((2, 0), dtype=np.int64)
+ e = np.array(list(G.edges()), dtype=np.int64).T
+ ei = np.concatenate([e, e[::-1]], axis=1) # undirected -> both directions
+ return n, ei
+
+
+def circulant(N, offsets):
+ G = nx.Graph()
+ G.add_nodes_from(range(N))
+ for i in range(N):
+ for s in offsets:
+ G.add_edge(i, (i + s) % N)
+ return G
+
+
+CSL_SKIPS = [2, 3, 4, 5, 6, 9, 11, 12, 13, 16] # 10 classes, N=41 (Murphy et al. 2019)
+
+
+def build_csl(n_per_class=15, N=41, seed=0):
+ """Circular Skip Links: all graphs 4-regular -> 1-WL collapses to one color (pure H2 anchor)."""
+ rng = np.random.default_rng(seed)
+ data = []
+ for cls, s in enumerate(CSL_SKIPS):
+ for _ in range(n_per_class):
+ G = circulant(N, [1, s])
+ perm = rng.permutation(N)
+ G = nx.relabel_nodes(G, {i: int(perm[i]) for i in range(N)})
+ n, ei = _nx_to_edge_index(G)
+ data.append({'n': n, 'edge_index': ei, 'y': cls})
+ return data
+
+
+def build_triangle_count(n_graphs=600, n_nodes=20, kind='regular', deg=3, p=0.2, seed=0):
+ """Graph-level triangle-count regression. 1-WL cannot count triangles -> measurable H2 floor."""
+ rng = np.random.default_rng(seed)
+ data, tries = [], 0
+ while len(data) < n_graphs and tries < n_graphs * 30:
+ tries += 1
+ sd = int(rng.integers(1 << 30))
+ try:
+ G = (nx.random_regular_graph(deg, n_nodes, seed=sd) if kind == 'regular'
+ else nx.gnp_random_graph(n_nodes, p, seed=sd))
+ except Exception:
+ continue
+ tri = sum(nx.triangles(G).values()) // 3
+ n, ei = _nx_to_edge_index(G)
+ data.append({'n': n, 'edge_index': ei, 'y': float(tri)})
+ return data
+
+
+def canonical_pairs():
+ """Graphs for instrument self-test (known 1-WL outcomes)."""
+ pairs = [('C6', nx.cycle_graph(6)),
+ ('2C3', nx.disjoint_union(nx.cycle_graph(3), nx.cycle_graph(3))),
+ ('P4', nx.path_graph(4)),
+ ('K1,3', nx.star_graph(3))]
+ out = {}
+ for name, G in pairs:
+ n, ei = _nx_to_edge_index(G)
+ out[name] = {'n': n, 'edge_index': ei, 'tri': sum(nx.triangles(G).values()) // 3}
+ return out
diff --git a/diag/esan_color.py b/diag/esan_color.py
new file mode 100644
index 0000000..7be050c
--- /dev/null
+++ b/diag/esan_color.py
@@ -0,0 +1,145 @@
+"""H3: Recursive ESAN (subgraph GNN, DS-GNN node-marking bag) on graph 3-coloring.
+
+Per graph, pick S anchor nodes; each view = graph + a 1-hot mark on the anchor. Run the SHARED
+recursive GIN on all views, average node-logits over views (DeepSets). Marking breaks node
+symmetry -> >1-WL. Self-contained: train (EMA, best solve) + LE (lambda on a marked view,
+bucket by aggregate solve) + PTRM (K noisy aggregate forwards); writes color_/le_/ptrm_ JSON
+(conv='esan') for diag/aggregate.py.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/esan_color.py --grad_mode full --seed 0
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch_geometric.data import Data, Batch
+from diag.train_color import make_split, featurize, RecGINColor, lyap1, OUT
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+S = 4
+
+
+def dense_A(edge_index, n):
+ A = torch.zeros(n, n)
+ if edge_index.shape[1]:
+ A[edge_index[0], edge_index[1]] = 1.0
+ return torch.maximum(A, A.t())
+
+
+def views_of(g, anchors):
+ out = []
+ for a in anchors:
+ mark = torch.zeros(g['n'], 1); mark[a] = 1.0
+ out.append(Data(x=torch.cat([g['rfeat'], mark], dim=1), edge_index=g['edge_index'], num_nodes=g['n']))
+ return out
+
+
+def anchors_for(g, rng):
+ return rng.choice(g['n'], size=min(S, g['n']), replace=False)
+
+
+def esan_logits(model, g, dev, anchors, noise=False):
+ b = Batch.from_data_list(views_of(g, anchors)).to(dev)
+ out = model(b.x, b.edge_index, b.batch, noise=noise)[-1] # [S*n, k]
+ return out.view(len(anchors), g['n'], -1).mean(0) # [n, k]
+
+
+def conf_of(logits, A):
+ col = logits.argmax(-1)
+ return int(((col.unsqueeze(0) == col.unsqueeze(1)) & (A > 0)).sum().item() // 2)
+
+
+@torch.no_grad()
+def solve_rate(model, graphs, dev, rng, sample=300):
+ model.eval(); solved = 0
+ for g in graphs[:sample]:
+ solved += int(conf_of(esan_logits(model, g, dev, anchors_for(g, rng)), g['A'].to(dev)) == 0)
+ return solved / len(graphs[:sample])
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+ ap.add_argument('--epochs', type=int, default=150); ap.add_argument('--M', type=int, default=16)
+ ap.add_argument('--seed', type=int, default=0); ap.add_argument('--sigma', type=float, default=0.2)
+ ap.add_argument('--K', type=int, default=16); ap.add_argument('--n_graphs', type=int, default=150)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ rng = np.random.default_rng(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
+ te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
+ for g in tr + te:
+ g['A'] = dense_A(g['edge_index'], g['n'])
+ in_dim = tr[0]['rfeat'].shape[1] + 1
+ model = RecGINColor(in_dim, 128, 3, grad_mode=args.grad_mode, conv='gin').to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+
+ t0 = time.time(); best = -1; best_state = None
+ for ep in range(args.epochs):
+ model.train(); order = rng.permutation(len(tr))
+ for s0 in range(0, len(order) - args.M, args.M):
+ sel = order[s0:s0 + args.M]
+ views, As = [], []
+ for i in sel:
+ g = tr[i]; views += views_of(g, anchors_for(g, rng)); As.append(g['A'])
+ b = Batch.from_data_list(views).to(dev)
+ opt.zero_grad()
+ logits = model(b.x, b.edge_index, b.batch, noise=False)[-1].view(args.M, S, 50, 3).mean(1)
+ Ab = torch.stack(As).to(dev)
+ p = F.softmax(logits, -1)
+ loss = (torch.einsum('bik,bjk->bij', p, p) * Ab).sum() / (Ab.sum() + 1e-9)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ with torch.no_grad():
+ for kk, v in model.state_dict().items():
+ ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v)
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ model.load_state_dict(ema); sr = solve_rate(model, te, dev, rng)
+ if sr > best:
+ best = sr; best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
+ model.load_state_dict(bk)
+ print(f"ep{ep+1} solve={sr:.3f}", flush=True)
+ model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state}); model.eval()
+
+ lams, fails = [], []; passk = lamsel = rand = 0; Lr, Sr = [], []
+ nstep = model.n_sup * model.T
+ for gi, g in enumerate(te[:args.n_graphs]):
+ anc = anchors_for(g, rng)
+ mark = torch.zeros(g['n'], 1); mark[anc[0]] = 1.0
+ xin = torch.cat([g['rfeat'], mark], dim=1).to(dev); ei = g['edge_index'].to(dev)
+ lam0 = lyap1(model, xin, ei, nstep, dev, seed=gi)
+ c0 = conf_of(esan_logits(model, g, dev, anc), g['A'].to(dev))
+ lams.append(lam0); fails.append(int(c0 > 0))
+ confs, rl = [], []
+ for j in range(args.K):
+ confs.append(conf_of(esan_logits(model, g, dev, anc, noise=True), g['A'].to(dev)))
+ rl.append(lyap1(model, xin, ei, nstep, dev, seed=1000 * gi + j))
+ confs, rl = np.array(confs), np.array(rl); sv = confs == 0
+ passk += int(sv.any()); lamsel += int(sv[rl.argmin()]); rand += int(sv[0])
+ Lr += rl.tolist(); Sr += sv.tolist()
+ lams, fails = np.array(lams), np.array(fails); s, f = lams[fails == 0], lams[fails == 1]
+ auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
+ n = len(te[:args.n_graphs])
+ print(f"[esan/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} "
+ f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")
+ base = f"esan_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
+ com = {'conv': 'esan', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}
+ json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w'))
+ json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())},
+ open(os.path.join(OUT, f"le_color_{base}.json"), 'w'))
+ json.dump({**com, 'det': 1 - float(fails.mean()),
+ 'sigmas': {'0.2': {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n}}},
+ open(os.path.join(OUT, f"ptrm_color_{base}.json"), 'w'))
+ print(" wrote", base)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/ff_color.py b/diag/ff_color.py
new file mode 100644
index 0000000..29c3b45
--- /dev/null
+++ b/diag/ff_color.py
@@ -0,0 +1,75 @@
+"""Control: plain FEEDFORWARD (independent-weight, NO recursion, NO deep-supervision) GNN on
+graph 3-coloring, to test whether the recursive gcn/gat collapse (.003/.04) is caused by the
+RECURSION (shared-weight repeated -> oversmoothing) or is intrinsic to gcn/gat on coloring.
+
+Sweep conv in {gcn,gat} x depth L in {4,8,16}. L=4 ~ normal usage; L=16 tests whether deep
+feedforward also oversmooths. If shallow FF colors well but recursive collapses -> recursion's
+fault. If FF is also ~0 at all L -> intrinsic.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ff_color.py --conv gcn --L 4 --seed 0
+"""
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from diag.train_color import make_split, featurize, make_conv, conflict_loss, solve_stats
+
+
+class FF(nn.Module):
+ def __init__(self, in_dim, hidden, k, L, conv):
+ super().__init__()
+ self.conv_type = conv
+ self.lin_in = nn.Linear(in_dim, hidden)
+ self.convs = nn.ModuleList([make_conv(conv, hidden) for _ in range(L)])
+ self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(L)])
+ self.head = nn.Linear(hidden, k)
+
+ def forward(self, xin, ei, batch=None, noise=False):
+ h = self.lin_in(xin)
+ for conv, bn in zip(self.convs, self.bns):
+ h = bn(conv(h, ei)).relu()
+ return [self.head(h)]
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--conv', choices=['gcn', 'gat', 'gin', 'sage'], default='gcn')
+ ap.add_argument('--L', type=int, default=4)
+ ap.add_argument('--epochs', type=int, default=150)
+ ap.add_argument('--seed', type=int, default=0)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
+ te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
+ in_dim = tr[0]['xin'].shape[1]
+ trl = DataLoader([Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr],
+ batch_size=32, shuffle=True, drop_last=True)
+ model = FF(in_dim, 128, 3, args.L, args.conv).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ best = -1
+ for ep in range(args.epochs):
+ model.train()
+ for b in trl:
+ b = b.to(dev); opt.zero_grad()
+ loss = conflict_loss(model(b.x, b.edge_index)[-1], b.edge_index)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ with torch.no_grad():
+ for kk, v in model.state_dict().items():
+ ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v)
+ sched.step()
+ if (ep + 1) % 30 == 0 or ep == args.epochs - 1:
+ bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ model.load_state_dict(ema); sr, _ = solve_stats(model, te, dev, sample=300)
+ best = max(best, sr); model.load_state_dict(bk)
+ print(f"[ff/{args.conv}/L{args.L}/s{args.seed}] solve={best:.3f}", flush=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/lyap.py b/diag/lyap.py
new file mode 100644
index 0000000..93b90bf
--- /dev/null
+++ b/diag/lyap.py
@@ -0,0 +1,83 @@
+"""LE diagnostic for the recursive (TRM-ish) GNN — ports the flossing finding to graphs.
+
+Per-graph top Lyapunov exponent lambda1 of the recursion z <- block(z+h0), via Benettin
+power-iteration on a single tangent vector (JVP + renormalize, accumulate log-growth) over
+the model's n_sup*T recursion steps. Bucket graphs by success/failure (rounded ring counts
+exact) and compare lambda1 distributions + AUROC(fail | lambda1) — mirroring
+plot_trm_lyap_hist.py. Hypothesis: failed graphs are MORE chaotic (higher lambda1).
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_full_..._s0.pt
+"""
+import argparse
+import numpy as np
+import torch
+from diag.train_rec import RecGIN
+from diag.train_cycle import prepare
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+
+def build(ck, dev):
+ c = ck['cfg']
+ m = RecGIN(c['n_atom'], c['hidden'], c['T'], c['n_sup'], 0.0, grad_mode=c['grad_mode']).to(dev)
+ m.load_state_dict(ck['state']); m.eval()
+ return m, c
+
+
+def lyap1(model, x, ei, n_steps, dev, seed=0):
+ g = torch.Generator(device=dev).manual_seed(seed)
+ h0 = model.emb(x).detach()
+ z = torch.zeros_like(h0)
+ v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
+ def step_fn(zz):
+ return model.block(zz + h0, ei)
+ lam = 0.0
+ for _ in range(n_steps):
+ z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v)
+ z = z_next.detach()
+ nv = Jv.norm()
+ lam += torch.log(nv + 1e-12).item()
+ v = (Jv / (nv + 1e-12)).detach()
+ return lam / n_steps
+
+
+@torch.no_grad()
+def predict(model, x, ei, dev):
+ batch = torch.zeros(x.size(0), dtype=torch.long, device=dev)
+ preds, _ = model(x, ei, batch, noise=False)
+ return preds[-1].view(-1)
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--ckpt', required=True)
+ ap.add_argument('--n_graphs', type=int, default=300)
+ args = ap.parse_args()
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ ck = torch.load(args.ckpt, weights_only=False)
+ model, cfg = build(ck, dev)
+ ymu, ysd = ck['ymu'].to(dev), ck['ysd'].to(dev)
+ te = prepare('test')
+ n_steps = cfg['n_sup'] * cfg['T']
+
+ lams, fails = [], []
+ for i, r in enumerate(te[:args.n_graphs]):
+ x = r['x'].to(dev); ei = r['edge_index'].to(dev)
+ p = predict(model, x, ei, dev) * ysd + ymu # raw [2]
+ y = r['y'].to(dev) # raw [2]
+ fails.append(int(not torch.all(p.round() == y.round()).item()))
+ lams.append(lyap1(model, x, ei, n_steps, dev, seed=i))
+ lams, fails = np.array(lams), np.array(fails)
+ s, f = lams[fails == 0], lams[fails == 1]
+ auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
+ print(f"[{cfg['grad_mode']}] n={len(lams)} fail_rate={fails.mean():.2f} | "
+ f"lambda1 SUCC mean {s.mean():+.4f} std {s.std():.4f} (n={len(s)}) | "
+ f"FAIL mean {f.mean():+.4f} std {f.std():.4f} (n={len(f)}) | "
+ f"sep(fail-succ)={f.mean()-s.mean() if len(s) and len(f) else float('nan'):+.4f} | "
+ f"AUROC(fail|lambda1)={auc:.3f} | mean_lambda1={lams.mean():+.4f}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/models.py b/diag/models.py
new file mode 100644
index 0000000..0aa3106
--- /dev/null
+++ b/diag/models.py
@@ -0,0 +1,42 @@
+"""GIN (1-WL-tight) and GCN (<1-WL) backbones for the diagnosis."""
+import torch.nn as nn
+from torch_geometric.nn import GINConv, GCNConv, global_add_pool, global_mean_pool
+
+
+def _mlp(d_in, d_hid, d_out):
+ return nn.Sequential(nn.Linear(d_in, d_hid), nn.BatchNorm1d(d_hid), nn.ReLU(),
+ nn.Linear(d_hid, d_out))
+
+
+class GIN(nn.Module):
+ """Sum aggregation + MLP update -> injective on multisets -> matches 1-WL."""
+ def __init__(self, in_dim, hidden=64, layers=4, out_dim=10):
+ super().__init__()
+ self.convs = nn.ModuleList()
+ d = in_dim
+ for _ in range(layers):
+ self.convs.append(GINConv(_mlp(d, hidden, hidden), train_eps=True))
+ d = hidden
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))
+
+ def forward(self, x, edge_index, batch):
+ for conv in self.convs:
+ x = conv(x, edge_index).relu()
+ return self.head(global_add_pool(x, batch))
+
+
+class GCN(nn.Module):
+ """Mean (normalized) aggregation -> non-injective -> strictly below 1-WL (reference baseline)."""
+ def __init__(self, in_dim, hidden=64, layers=4, out_dim=10):
+ super().__init__()
+ self.convs = nn.ModuleList()
+ d = in_dim
+ for _ in range(layers):
+ self.convs.append(GCNConv(d, hidden))
+ d = hidden
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))
+
+ def forward(self, x, edge_index, batch):
+ for conv in self.convs:
+ x = conv(x, edge_index).relu()
+ return self.head(global_mean_pool(x, batch))
diff --git a/diag/peptides_depth.py b/diag/peptides_depth.py
new file mode 100644
index 0000000..d751b31
--- /dev/null
+++ b/diag/peptides_depth.py
@@ -0,0 +1,66 @@
+"""Depth-resolution analysis on LRGB Peptides-struct (long-range, real, large graphs).
+
+Best-achievable MSE for a depth-L MPNN == within-(L-round-WL-color) variance of the target.
+The curve over L localizes failure cause WITHOUT training:
+ floor(L=small) high -> under-reached signal still hidden
+ floor(L) - floor(converged) -> H1a: depth-recoverable (more iteration/depth helps)
+ floor(converged) -> H2 : 1-WL ceiling (irreducible by any MPNN)
+Targets are z-scored per dim, so floor is the fraction of variance unexplained (var=1 baseline).
+"""
+import numpy as np
+from collections import defaultdict
+from torch_geometric.datasets import LRGBDataset
+from diag import wl
+
+S = 4000 # graph subsample (for tractable pure-python WL)
+MAX_ROUNDS = 40 # cap (>> typical GIN depth; chains converge near their diameter)
+
+
+def floor_at(ghist, Y):
+ groups = defaultdict(list)
+ for i, c in enumerate(ghist):
+ groups[c].append(i)
+ sse = 0.0
+ for idxs in groups.values():
+ yy = Y[idxs]
+ sse += ((yy - yy.mean(0)) ** 2).sum()
+ return sse / (len(ghist) * Y.shape[1]), len(groups)
+
+
+def main():
+ ds = LRGBDataset(root='/home/yurenh2/rrog/data/lrgb', name='Peptides-struct', split='train')
+ graphs = [ds[i] for i in range(min(S, len(ds)))]
+ fmap = {}
+ def fid(row):
+ t = tuple(row)
+ if t not in fmap:
+ fmap[t] = len(fmap)
+ return fmap[t]
+ adjs, inits, Y = [], [], []
+ for g in graphs:
+ adjs.append(wl.edges_to_adj(g.num_nodes, g.edge_index.numpy()))
+ inits.append(np.array([fid(r) for r in g.x.tolist()], dtype=np.int64))
+ Y.append(g.y.numpy().reshape(-1))
+ Y = np.stack(Y).astype(np.float64)
+ Y = (Y - Y.mean(0)) / (Y.std(0) + 1e-8) # z-score per target
+ print(f"subsample={len(graphs)} graphs, {len(fmap)} distinct node-feature ids, targets={Y.shape[1]}")
+
+ import time; t0 = time.time()
+ node_rounds, ghist_rounds, conv = wl.wl_refine(adjs, inits=inits, max_rounds=MAX_ROUNDS)
+ print(f"WL refined to round {conv} (cap {MAX_ROUNDS}) in {time.time()-t0:.1f}s")
+
+ print(f"{'L':>4} {'floor_MSE(std)':>14} {'%var_unexpl':>12} {'#graph_colors':>14}")
+ floors = {}
+ for L in [0, 1, 2, 3, 4, 5, 8, 16, 32, conv]:
+ r = min(L, conv)
+ f, nc = floor_at(ghist_rounds[r], Y)
+ floors[L] = f
+ print(f"{L:>4} {f:>14.4f} {100*f:>11.1f}% {nc:>14}")
+ h2 = floors[conv]
+ for Lg in [4, 5]:
+ print(f"\nAt GIN depth L={Lg}: H2 ceiling={h2:.3f} | depth-recoverable H1a (floor[{Lg}]-H2)"
+ f"={floors[Lg]-h2:.3f} | already-reachable={1-floors[Lg]:.3f} of var")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/ppgn_color.py b/diag/ppgn_color.py
new file mode 100644
index 0000000..25db01b
--- /dev/null
+++ b/diag/ppgn_color.py
@@ -0,0 +1,209 @@
+"""H2: Recursive PPGN (higher-order, 3-WL) backbone on graph 3-coloring.
+
+State is a pair-tensor X[i,j,:] (not node features). The powerful block multiplies two
+channel-wise n x n matrices (the 3-WL operation): P = M1 @ M2; out = m3([X, P]). Recurse the
+pair-tensor (TRM-style deep supervision), pool to nodes (mean over j), decode colors.
+Self-contained: trains (EMA, best solve), then LE diagnostic + PTRM noise/lambda-select;
+writes color_/le_/ptrm_ JSONs (conv='ppgn') so diag/aggregate.py folds it into the big table.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ppgn_color.py --grad_mode full --seed 0
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diag.train_color import make_split, featurize, OUT
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+
+def dense_A(edge_index, n):
+ A = torch.zeros(n, n)
+ if edge_index.shape[1]:
+ A[edge_index[0], edge_index[1]] = 1.0
+ return torch.maximum(A, A.t())
+
+
+def mlp(di, dh, do):
+ return nn.Sequential(nn.Linear(di, dh), nn.ReLU(), nn.Linear(dh, do))
+
+
+class RecPPGN(nn.Module):
+ def __init__(self, in_dim, hidden=64, k=3, T=3, n_sup=3, grad_mode='full', sigma=0.0):
+ super().__init__()
+ self.node_in = nn.Linear(in_dim, hidden)
+ self.adj_emb = nn.Embedding(2, hidden)
+ self.m1 = mlp(hidden, hidden, hidden); self.m2 = mlp(hidden, hidden, hidden)
+ self.m3 = mlp(2 * hidden, hidden, hidden)
+ self.ln = nn.LayerNorm(hidden)
+ self.head = nn.Linear(hidden, k)
+ self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma
+
+ def init_pair(self, xin, A): # xin [B,n,in], A [B,n,n]
+ h = self.node_in(xin)
+ X = h.unsqueeze(2) + h.unsqueeze(1) # [B,n,n,hidden]
+ return X + self.adj_emb(A.long())
+
+ def block(self, X): # X [B,n,n,h]
+ M1, M2 = self.m1(X), self.m2(X)
+ P = torch.einsum('bikc,bkjc->bijc', M1, M2) / X.shape[1]
+ return self.ln(self.m3(torch.cat([X, P], dim=-1)))
+
+ def _inner(self, X, X0, noise):
+ X = self.block(X + X0)
+ if noise and self.sigma > 0:
+ X = X + self.sigma * torch.randn_like(X)
+ return X
+
+ def recurse(self, X, X0, noise, one_step=False):
+ if one_step:
+ with torch.no_grad():
+ for _ in range(self.T - 1):
+ X = self._inner(X, X0, noise)
+ X = X.detach()
+ return self._inner(X, X0, noise)
+ for _ in range(self.T):
+ X = self._inner(X, X0, noise)
+ return X
+
+ def forward(self, xin, A, noise=False):
+ X0 = self.init_pair(xin, A)
+ X = torch.zeros_like(X0)
+ outs = []
+ for s in range(self.n_sup):
+ X = self.recurse(X, X0, noise, one_step=(self.grad_mode == '1step'))
+ outs.append(self.head(X.mean(2))) # pool over j -> [B,n,k]
+ X = X.detach()
+ return outs
+
+
+def conflict_loss(logits, A): # logits [B,n,k], A [B,n,n]
+ p = F.softmax(logits, dim=-1)
+ return (torch.einsum('bik,bjk->bij', p, p) * A).sum() / (A.sum() + 1e-9)
+
+
+def batches(graphs, bs, shuffle, dev):
+ idx = np.arange(len(graphs))
+ if shuffle:
+ np.random.shuffle(idx)
+ for s in range(0, len(idx) - (len(idx) % bs if shuffle else 0), bs):
+ sel = idx[s:s + bs]
+ if len(sel) == 0:
+ continue
+ X = torch.stack([graphs[i]['xin'] for i in sel]).to(dev)
+ A = torch.stack([graphs[i]['A'] for i in sel]).to(dev)
+ yield X, A
+
+
+@torch.no_grad()
+def solve_stats(model, graphs, dev, sample=300):
+ model.eval(); solved = 0; tot = 0
+ for g in graphs[:sample]:
+ A = g['A'].unsqueeze(0).to(dev)
+ col = model(g['xin'].unsqueeze(0).to(dev), A)[-1][0].argmax(-1)
+ conf = ((col.unsqueeze(0) == col.unsqueeze(1)) & (A[0] > 0)).sum().item() // 2
+ solved += int(conf == 0); tot += 1
+ return solved / tot
+
+
+def lyap1_and_solved(model, g, dev, seed, sigma=0.0):
+ gen = torch.Generator(device=dev).manual_seed(seed)
+ xin = g['xin'].unsqueeze(0).to(dev); A = g['A'].unsqueeze(0).to(dev)
+ X0 = model.init_pair(xin, A)
+ X = torch.zeros_like(X0)
+ v = torch.randn(X.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12)
+ lam = 0.0
+ for _ in range(model.n_sup * model.T):
+ Xd, Jv = torch.autograd.functional.jvp(lambda XX: model.block(XX + X0), X, v)
+ nv = Jv.norm(); lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach()
+ X = Xd.detach()
+ if sigma > 0:
+ X = X + sigma * torch.randn(X.shape, generator=gen, device=dev)
+ col = model.head(X.mean(2))[0].argmax(-1)
+ conf = ((col.unsqueeze(0) == col.unsqueeze(1)) & (A[0] > 0)).sum().item() // 2
+ return lam / (model.n_sup * model.T), conf
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+ ap.add_argument('--epochs', type=int, default=150); ap.add_argument('--bs', type=int, default=16)
+ ap.add_argument('--hidden', type=int, default=64); ap.add_argument('--seed', type=int, default=0)
+ ap.add_argument('--sigma', type=float, default=0.2); ap.add_argument('--K', type=int, default=16)
+ ap.add_argument('--n_graphs', type=int, default=150)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
+ te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
+ for g in tr + te:
+ g['A'] = dense_A(g['edge_index'], g['n'])
+ in_dim = tr[0]['xin'].shape[1]
+
+ model = RecPPGN(in_dim, args.hidden, 3, grad_mode=args.grad_mode).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ t0 = time.time(); best = -1; best_state = None
+ for ep in range(args.epochs):
+ model.train()
+ for X, A in batches(tr, args.bs, True, dev):
+ opt.zero_grad()
+ outs = model(X, A, noise=False)
+ loss = sum(conflict_loss(o, A) for o in outs) / len(outs)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ with torch.no_grad():
+ for kk, v in model.state_dict().items():
+ ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v)
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ model.load_state_dict(ema); sr = solve_stats(model, te, dev)
+ if sr > best:
+ best = sr; best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
+ model.load_state_dict(bk)
+ print(f"ep{ep+1} solve={sr:.3f}", flush=True)
+ model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state}); model.eval()
+
+ # LE + PTRM (noise + lambda-select) on test
+ lams, fails = [], []
+ passk = lamsel = rand = 0
+ Lr, Sr = [], []
+ for gi, g in enumerate(te[:args.n_graphs]):
+ lam0, c0 = lyap1_and_solved(model, g, dev, seed=gi, sigma=0.0)
+ lams.append(lam0); fails.append(int(c0 > 0))
+ res = [lyap1_and_solved(model, g, dev, seed=1000 * gi + j, sigma=args.sigma) for j in range(args.K)]
+ confs = np.array([c for _, c in res]); rl = np.array([l for l, _ in res])
+ solved = confs == 0
+ passk += int(solved.any()); lamsel += int(solved[rl.argmin()]); rand += int(solved[0])
+ Lr += rl.tolist(); Sr += solved.tolist()
+ lams, fails = np.array(lams), np.array(fails)
+ s, f = lams[fails == 0], lams[fails == 1]
+ auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
+ Lr, Sr = np.array(Lr), np.array(Sr)
+ pauc = (roc_auc_score(Sr.astype(int), -Lr) if roc_auc_score and Sr.any() and (~Sr).any() else float('nan'))
+ n = len(te[:args.n_graphs])
+ print(f"[ppgn/{args.grad_mode}] solve={best:.3f} | LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} | "
+ f"PTRM det={1 - fails.mean():.3f} passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")
+
+ base = f"ppgn_full_none_n50_k3_p0.2_T3_ns3_s{args.seed}" if args.grad_mode == 'full' \
+ else f"ppgn_1step_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
+ com = {'conv': 'ppgn', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}
+ json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w'), indent=2)
+ json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean()),
+ 'lam_solved': (float(s.mean()) if len(s) else None),
+ 'lam_unsolved': (float(f.mean()) if len(f) else None)},
+ open(os.path.join(OUT, f"le_color_{base}.json"), 'w'), indent=2)
+ json.dump({**com, 'det': 1 - float(fails.mean()),
+ 'sigmas': {'0.2': {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n,
+ 'perRoll': float(Sr.mean()), 'auroc': float(pauc)}}},
+ open(os.path.join(OUT, f"ptrm_color_{base}.json"), 'w'), indent=2)
+ print(" wrote color_/le_color_/ptrm_color_", base)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/ptrm_color.py b/diag/ptrm_color.py
new file mode 100644
index 0000000..4004297
--- /dev/null
+++ b/diag/ptrm_color.py
@@ -0,0 +1,85 @@
+"""Step-2(a): PTRM test-time noise + lambda-based selection on a trained coloring model
+(any backbone feature set via cfg pe). Writes a JSON per ckpt for multi-seed aggregation.
+
+deterministic / pass@K (conflict-min, ground truth) / lambda-select (min lambda1) / random.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_...pt
+"""
+import argparse, json, os
+import numpy as np
+import torch
+from diag.train_color import RecGINColor, make_split, featurize
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+OUT = '/home/yurenh2/rrog/runs'
+
+
+def rollout(model, xin, ei, sigma, n_sup, T, dev, seed):
+ gen = torch.Generator(device=dev).manual_seed(seed)
+ h0 = model.lin_in(xin)
+ z = torch.zeros_like(h0)
+ v = torch.randn(h0.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12)
+ def step(zz):
+ return model.block(zz + h0, ei)
+ lam = 0.0
+ for _ in range(n_sup * T):
+ z_det, Jv = torch.autograd.functional.jvp(step, z, v)
+ nv = Jv.norm(); lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach()
+ z = z_det.detach()
+ if sigma > 0:
+ z = z + sigma * torch.randn(z.shape, generator=gen, device=dev)
+ lam /= (n_sup * T)
+ col = model.head(z).argmax(-1)
+ conf = (col[ei[0]] == col[ei[1]]).sum().item() // 2
+ return conf, lam
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--ckpt', required=True)
+ ap.add_argument('--K', type=int, default=16)
+ ap.add_argument('--n_graphs', type=int, default=150)
+ ap.add_argument('--sigmas', type=float, nargs='+', default=[0.05, 0.1, 0.2, 0.4])
+ args = ap.parse_args()
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ ck = torch.load(args.ckpt, weights_only=False); c = ck['cfg']
+ deg = torch.tensor(c['deg']) if c.get('deg') else None
+ model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
+ grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev)
+ model.load_state_dict(ck['state']); model.eval()
+ nsup, T = c['n_sup'], c['T']
+ te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), c.get('pe', 'none'), c.get('rwse_k', 16))
+ te = te[:args.n_graphs]; n = len(te)
+
+ det = sum(rollout(model, r['xin'].to(dev), r['edge_index'].to(dev), 0.0, nsup, T, dev, 0)[0] == 0
+ for r in te) / n
+ out = {'conv': c.get('conv', 'gin'), 'pe': c.get('pe', 'none'), 'seed': c.get('seed'),
+ 'grad_mode': c['grad_mode'], 'contract': c.get('contract', False), 'det': det, 'sigmas': {}}
+ print(f"[pe={out['pe']} s{out['seed']}] deterministic solve_rate = {det:.3f} (n={n}, K={args.K})")
+ print(f"{'sigma':>6} {'pass@K':>8} {'lam-sel':>8} {'random':>8} {'perRoll':>8} {'AUROC(s|-lam)':>14}")
+ for sigma in args.sigmas:
+ passk = lamsel = rand = 0
+ L, S = [], []
+ for gi, r in enumerate(te):
+ xin = r['xin'].to(dev); ei = r['edge_index'].to(dev)
+ res = [rollout(model, xin, ei, sigma, nsup, T, dev, 1000 * gi + j) for j in range(args.K)]
+ confs = np.array([c0 for c0, _ in res]); lams = np.array([l for _, l in res])
+ solved = confs == 0
+ passk += int(solved.any()); lamsel += int(solved[lams.argmin()]); rand += int(solved[0])
+ L += lams.tolist(); S += solved.tolist()
+ L, S = np.array(L), np.array(S)
+ auc = (roc_auc_score(S.astype(int), -L) if roc_auc_score and S.any() and (~S).any() else float('nan'))
+ out['sigmas'][str(sigma)] = {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n,
+ 'perRoll': float(S.mean()), 'auroc': float(auc)}
+ print(f"{sigma:>6} {passk/n:>8.3f} {lamsel/n:>8.3f} {rand/n:>8.3f} {S.mean():>8.3f} {auc:>14.3f}")
+
+ base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
+ with open(os.path.join(OUT, f"ptrm_{base}.json"), 'w') as f:
+ json.dump(out, f, indent=2)
+ print(" wrote", os.path.join(OUT, f"ptrm_{base}.json"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/run_archA.sh b/diag/run_archA.sh
new file mode 100644
index 0000000..5385152
--- /dev/null
+++ b/diag/run_archA.sh
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+# Conv axis (pe=none): gcn, sage, gat. 5 seeds, train+LE+PTRM. (pin GPU at launch.)
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "A start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+ for conv in gcn sage gat; do
+ ck=runs/ckpt_color_${conv}_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
+ echo "== A s$s conv=$conv =="
+ python3 diag/train_color.py --mode train --conv "$conv" --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train $conv s$s"
+ python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le $conv s$s"
+ python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $conv s$s"
+ done
+done
+echo "doneA=$(date -Is)"
diff --git a/diag/run_archB.sh b/diag/run_archB.sh
new file mode 100644
index 0000000..317638f
--- /dev/null
+++ b/diag/run_archB.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+# GPS transformer backbone + feature axis (gin + lappe / all). 5 seeds. (pin GPU at launch.)
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "B start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+ ck=runs/ckpt_color_gps_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
+ echo "== B s$s conv=gps =="
+ python3 diag/train_color.py --mode train --conv gps --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train gps s$s"
+ python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le gps s$s"
+ python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm gps s$s"
+ for pe in lappe all; do
+ ck2=runs/ckpt_color_gin_full_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt
+ echo "== B s$s gin pe=$pe =="
+ python3 diag/train_color.py --mode train --conv gin --pe "$pe" --p 0.2 --epochs 150 --seed "$s" || echo "!! train $pe s$s"
+ python3 diag/train_color.py --mode le --ckpt "$ck2" || echo "!! le $pe s$s"
+ python3 diag/ptrm_color.py --ckpt "$ck2" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $pe s$s"
+ done
+done
+echo "doneB=$(date -Is)"
diff --git a/diag/run_cin.sh b/diag/run_cin.sh
new file mode 100644
index 0000000..6be362d
--- /dev/null
+++ b/diag/run_cin.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+set -uo pipefail
+cd /home/yurenh2/rrog; export PYTHONPATH=/home/yurenh2/rrog
+echo "CIN start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+ echo "== cin s$s =="; python3 diag/cin_color.py --grad_mode full --seed "$s" || echo "!! cin s$s"
+done
+echo "doneCIN=$(date -Is)"
diff --git a/diag/run_color.sh b/diag/run_color.sh
new file mode 100644
index 0000000..ad3c406
--- /dev/null
+++ b/diag/run_color.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+# Step-2 (TRM regime, large output): graph 3-coloring, TRM full vs HRM 1-step + LE diagnostic.
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+for gm in full 1step; do
+ echo "===== train $gm ====="
+ python3 diag/train_color.py --mode train --grad_mode "$gm" --p 0.2 --epochs 150 --seed 0 \
+ || echo "!! train $gm failed"
+done
+echo "===== LE diagnostic (lambda1: solved vs unsolved) ====="
+python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_full_n50_k3_p0.2_T3_ns3_s0.pt || echo "!! le full failed"
+python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_1step_n50_k3_p0.2_T3_ns3_s0.pt || echo "!! le 1step failed"
+echo "done=$(date -Is)"
diff --git a/diag/run_cycle.sh b/diag/run_cycle.sh
new file mode 100644
index 0000000..f85a6fd
--- /dev/null
+++ b/diag/run_cycle.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+# Ring-counting (ZINC, [#5-cycles,#6-cycles]) diagnosis: does a real 1-WL ceiling exist,
+# and does noise (RNI) vs structured >1-WL (RWSE) break it?
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+run() { echo "===== $* ====="; python3 diag/train_cycle.py "$@" --layers 5 --epochs 200 --seed 0 || echo "!! FAILED $*"; }
+run --conv gin --feat none # 1-WL baseline (should fail to count)
+run --conv gcn --feat none # sub-1-WL reference
+run --conv gin --feat rni # noise / PTRM-style crude symmetry break
+run --conv gin --feat rwse # structured >1-WL positive control
+echo "done=$(date -Is)"
diff --git a/diag/run_diag.sh b/diag/run_diag.sh
new file mode 100644
index 0000000..4a6f31a
--- /dev/null
+++ b/diag/run_diag.sh
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+# Step-1 diagnosis sweep. CSL = pure-H2 anchor (classification).
+# Triangle counting on ER (graphs 1-WL-distinguishable -> H2~0, the H1 end) and
+# regular (graphs collapse to one WL color -> H2~var, the ceiling end).
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+run() { echo "===== $* ====="; python3 diag/train_diag.py "$@" --layers 4 --epochs 300 --seed 0 \
+ || echo "!! FAILED $*"; }
+for model in gin gcn; do
+ run --task csl --model "$model"
+ run --task tri --model "$model" --kind er
+ run --task tri --model "$model" --kind regular
+done
+echo "done=$(date -Is)"
diff --git a/diag/run_esan.sh b/diag/run_esan.sh
new file mode 100644
index 0000000..f87b379
--- /dev/null
+++ b/diag/run_esan.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+set -uo pipefail
+cd /home/yurenh2/rrog; export PYTHONPATH=/home/yurenh2/rrog
+echo "ESAN start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+ echo "== esan s$s =="; python3 diag/esan_color.py --grad_mode full --seed "$s" || echo "!! esan s$s"
+done
+echo "doneESAN=$(date -Is)"
diff --git a/diag/run_ff.sh b/diag/run_ff.sh
new file mode 100644
index 0000000..e2b7b0a
--- /dev/null
+++ b/diag/run_ff.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+set -uo pipefail
+cd /home/yurenh2/rrog; export PYTHONPATH=/home/yurenh2/rrog
+echo "FF start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for conv in gcn gat; do for L in 4 8 16; do for s in 0 1 2; do
+ python3 diag/ff_color.py --conv "$conv" --L "$L" --seed "$s" || echo "!! ff $conv L$L s$s"
+done; done; done
+echo "doneFF=$(date -Is)"
diff --git a/diag/run_le.sh b/diag/run_le.sh
new file mode 100644
index 0000000..8fc31ea
--- /dev/null
+++ b/diag/run_le.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+# Step-2(iii): TRM-ish GIN full-recursion vs 1-step-gradient, then LE diagnostic on each.
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1 --epochs 200 --seed 0 || echo "!! train full failed"
+python3 diag/train_rec.py --grad_mode 1step --sigma 0 --K 1 --epochs 200 --seed 0 || echo "!! train 1step failed"
+echo "===== LE diagnostic (lambda1: success vs failure) ====="
+python3 diag/lyap.py --ckpt runs/ckpt_rec_full_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap full failed"
+python3 diag/lyap.py --ckpt runs/ckpt_rec_1step_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap 1step failed"
+echo "done=$(date -Is)"
diff --git a/diag/run_pe.sh b/diag/run_pe.sh
new file mode 100644
index 0000000..8350160
--- /dev/null
+++ b/diag/run_pe.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+# Roadmap #1: RRoG on PE-augmented backbone. GIN vs GIN+RWSE on coloring:
+# does RRoG-noise add headroom on top of static structural encoding, or is it redundant?
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+for pe in none rwse; do
+ echo "===== train full pe=$pe ====="
+ python3 diag/train_color.py --mode train --grad_mode full --pe "$pe" --p 0.2 --epochs 150 --seed 0 \
+ || echo "!! train $pe failed"
+done
+echo "===== LE (full, both pe) ====="
+for pe in none rwse; do
+ python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \
+ || echo "!! le $pe failed"
+done
+echo "===== PTRM noise + lambda-select (both pe) ====="
+for pe in none rwse; do
+ echo "--- pe=$pe ---"
+ python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \
+ --K 16 --n_graphs 150 --sigmas 0.1 0.2 0.4 || echo "!! ptrm $pe failed"
+done
+echo "done=$(date -Is)"
diff --git a/diag/run_pe2.sh b/diag/run_pe2.sh
new file mode 100644
index 0000000..db7bd8a
--- /dev/null
+++ b/diag/run_pe2.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+# Roadmap #2: RRoG on a GSN-style motif backbone (per-node K3/wedge substructure counts).
+# full-recursion, 5 seeds; train + LE + PTRM(sigma=0.2). Aggregate vs none/rwse.
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+for s in 0 1 2 3 4; do
+ ck=runs/ckpt_color_full_gsn_n50_k3_p0.2_T3_ns3_s${s}.pt
+ echo "===== seed=$s full gsn ====="
+ python3 diag/train_color.py --mode train --grad_mode full --pe gsn --p 0.2 --epochs 150 --seed "$s" \
+ || echo "!! train gsn s$s failed"
+ python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le gsn s$s failed"
+ python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm gsn s$s failed"
+done
+echo "===== AGGREGATE (all pe: none/rwse/gsn) ====="
+python3 diag/aggregate.py
+echo "done=$(date -Is)"
diff --git a/diag/run_pe3.sh b/diag/run_pe3.sh
new file mode 100644
index 0000000..a978393
--- /dev/null
+++ b/diag/run_pe3.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+# Roadmap #3 (subgraph/ego features) + #4 (IGNN-style forced contraction), 5 seeds.
+# #3: full --pe sub. #4: full --pe none --contract (vs free full/none baseline).
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+for s in 0 1 2 3 4; do
+ ck=runs/ckpt_color_full_sub_n50_k3_p0.2_T3_ns3_s${s}.pt
+ echo "===== seed=$s #3 full sub ====="
+ python3 diag/train_color.py --mode train --grad_mode full --pe sub --p 0.2 --epochs 150 --seed "$s" || echo "!! train sub s$s failed"
+ python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le sub s$s failed"
+ python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm sub s$s failed"
+
+ ck2=runs/ckpt_color_full_none_ctr_n50_k3_p0.2_T3_ns3_s${s}.pt
+ echo "===== seed=$s #4 full none --contract ====="
+ python3 diag/train_color.py --mode train --grad_mode full --pe none --contract --p 0.2 --epochs 150 --seed "$s" || echo "!! train ctr s$s failed"
+ python3 diag/train_color.py --mode le --ckpt "$ck2" || echo "!! le ctr s$s failed"
+ python3 diag/ptrm_color.py --ckpt "$ck2" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm ctr s$s failed"
+done
+echo "===== AGGREGATE ====="
+python3 diag/aggregate.py
+echo "done=$(date -Is)"
diff --git a/diag/run_pna.sh b/diag/run_pna.sh
new file mode 100644
index 0000000..897315d
--- /dev/null
+++ b/diag/run_pna.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "PNA start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+ ck=runs/ckpt_color_pna_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
+ echo "== pna s$s =="
+ python3 diag/train_color.py --mode train --conv pna --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train pna s$s"
+ python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le pna s$s"
+ python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm pna s$s"
+done
+echo "donePNA=$(date -Is)"
diff --git a/diag/run_ppgn.sh b/diag/run_ppgn.sh
new file mode 100644
index 0000000..2cc27e9
--- /dev/null
+++ b/diag/run_ppgn.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+set -uo pipefail
+cd /home/yurenh2/rrog; export PYTHONPATH=/home/yurenh2/rrog
+echo "PPGN start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+ echo "== ppgn s$s =="; python3 diag/ppgn_color.py --grad_mode full --seed "$s" || echo "!! ppgn s$s"
+done
+echo "donePPGN=$(date -Is)"
diff --git a/diag/run_real.sh b/diag/run_real.sh
new file mode 100644
index 0000000..c29426a
--- /dev/null
+++ b/diag/run_real.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+# Peptides-struct training diagnosis: depth sweep + RNI(noise/>1-WL) + GCN(<1-WL) reference.
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+run() { echo "===== $* ====="; python3 diag/train_real.py "$@" --epochs 150 --seed 0 || echo "!! FAILED $*"; }
+run --conv gin --layers 5 --rni 0 # baseline 1-WL
+run --conv gin --layers 10 --rni 0 # depth / long-range
+run --conv gin --layers 5 --rni 16 # noise / beyond-1-WL ceiling test
+run --conv gcn --layers 5 --rni 0 # sub-1-WL reference
+echo "done=$(date -Is)"
diff --git a/diag/run_rec.sh b/diag/run_rec.sh
new file mode 100644
index 0000000..632498d
--- /dev/null
+++ b/diag/run_rec.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+# Step-2: recursive GNN + full PTRM on ZINC ring-counting. Does per-step noise + best-Q@K
+# selection break the 1-WL counting ceiling that input-RNI couldn't? Averaging vs selection.
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+run() { echo "===== $* ====="; python3 diag/train_rec.py "$@" --epochs 200 --seed 0 || echo "!! FAILED $*"; }
+run --sigma 0 --K 1 --select bestq # deterministic recursive baseline (~1-WL ceiling)
+run --sigma 0.1 --K 8 --select none # averaging over rollouts (RNI-style null)
+run --sigma 0.1 --K 8 --select bestq # PTRM-proper: per-step noise + best-Q@K selection
+run --sigma 0.2 --K 16 --select bestq # scaled noise + rollouts
+echo "done=$(date -Is)"
diff --git a/diag/run_seeds.sh b/diag/run_seeds.sh
new file mode 100644
index 0000000..77f22c8
--- /dev/null
+++ b/diag/run_seeds.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+# Harden coloring results with 5 seeds: full-vs-1step solve, LE AUROC, PTRM(sigma=0.2) for none/rwse.
+set -uo pipefail
+cd /home/yurenh2/rrog
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+for s in 0 1 2 3 4; do
+ for cfg in "full none" "full rwse" "1step none"; do
+ set -- $cfg; gm=$1; pe=$2
+ ck=runs/ckpt_color_${gm}_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt
+ echo "===== seed=$s $gm $pe ====="
+ python3 diag/train_color.py --mode train --grad_mode "$gm" --pe "$pe" --p 0.2 --epochs 150 --seed "$s" \
+ || echo "!! train $gm $pe s$s failed"
+ python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le $gm $pe s$s failed"
+ if [ "$gm" = "full" ]; then
+ python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $pe s$s failed"
+ fi
+ done
+done
+echo "===== AGGREGATE ====="
+python3 diag/aggregate.py
+echo "done=$(date -Is)"
diff --git a/diag/selftest_wl.py b/diag/selftest_wl.py
new file mode 100644
index 0000000..6c08310
--- /dev/null
+++ b/diag/selftest_wl.py
@@ -0,0 +1,53 @@
+"""Validate the 1-WL instrument on canonical graphs BEFORE trusting any decomposition.
+Run: PYTHONPATH=/home/yurenh2/rrog python3 /home/yurenh2/rrog/diag/selftest_wl.py
+"""
+import numpy as np
+from diag import wl, datasets
+
+
+def run():
+ P = datasets.canonical_pairs()
+ names = list(P.keys())
+ adjs = [wl.edges_to_adj(P[k]['n'], P[k]['edge_index']) for k in names]
+ node_rounds, ghist_rounds, conv = wl.wl_refine(adjs)
+ color = {names[i]: ghist_rounds[conv][i] for i in range(len(names))}
+ print("converged round:", conv)
+ for k in names:
+ print(f" {k:5s} tri={P[k]['tri']} wlcolor={color[k]}")
+
+ # (1) C6 == 2C3 under 1-WL (both 2-regular) yet differ in triangles -> counting is H2
+ assert color['C6'] == color['2C3'], "C6 vs 2C3 should be 1-WL-equal"
+ assert P['C6']['tri'] != P['2C3']['tri']
+ # (2) P4 != K1,3 (different degree multiset)
+ assert color['P4'] != color['K1,3'], "P4 vs K1,3 should be 1-WL-distinct"
+ print("OK canonical: C6==2C3 (WL blind to triangles), P4!=K1,3")
+
+ # (3) regression H2 floor on {C6,2C3} == target variance
+ sub = [wl.edges_to_adj(P['C6']['n'], P['C6']['edge_index']),
+ wl.edges_to_adj(P['2C3']['n'], P['2C3']['edge_index'])]
+ y = [P['C6']['tri'], P['2C3']['tri']]
+ _, gh, cv = wl.wl_refine(sub)
+ dec = wl.decompose_regression(gh, cv, L=10, y=y, train_idx=[0, 1], eval_idx=[0, 1])
+ print(f" triangle-count H2 floor MSE on {{C6,2C3}} = {dec['mse_floor_oracle_H2']:.4f} "
+ f"(target var = {np.var(y):.4f})")
+ assert abs(dec['mse_floor_oracle_H2'] - np.var(y)) < 1e-9
+
+ # (4) CSL: 4-regular -> 1 node color, 1 graph color, WL-optimal acc = chance (0.1) -> 100% H2
+ csl = datasets.build_csl(n_per_class=15, seed=0)
+ adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in csl]
+ nr, gh, cv = wl.wl_refine(adjs)
+ n_node_colors = len(set(nr[cv][0].tolist()))
+ n_graph_colors = len(set(gh[cv]))
+ y = [d['y'] for d in csl]
+ idx = list(range(len(csl)))
+ att = wl.attribute_classification(gh, cv, L=4, y=y, train_idx=idx, eval_idx=idx)
+ print(f" CSL: node-colors={n_node_colors}, distinct graph-colors={n_graph_colors}, "
+ f"WL-optimal acc={att['wl_optimal_acc_converged']:.3f} (chance 0.1), buckets={att['counts']}")
+ assert n_node_colors == 1 and n_graph_colors == 1
+ assert abs(att['wl_optimal_acc_converged'] - 0.1) < 1e-6
+ assert att['counts'].get('H2', 0) == len(csl)
+ print("OK CSL: fully 1-WL-collapsed -> 100% of failures are H2. Instrument VALIDATED.")
+
+
+if __name__ == "__main__":
+ run()
diff --git a/diag/train_color.py b/diag/train_color.py
new file mode 100644
index 0000000..36f8496
--- /dev/null
+++ b/diag/train_color.py
@@ -0,0 +1,347 @@
+"""Recursive (TRM-ish) GNN graph 3-coloring with swappable BACKBONE for the RRoG roadmap.
+
+--conv gin|gcn|sage|gat|gps : message-passing operator (gps = GraphGPS local MPNN + global
+ attention = TRM's original transformer backbone, on the graph).
+--pe none|rwse|gsn|sub|lappe|all : input structural features (random sym-break [+ encoding]).
+--contract : reverse-flossing lambda-penalty during training (force contraction; roadmap #4).
+--grad_mode full|1step : TRM full recursion vs HRM 1-step gradient.
+Self-supervised conflict/Potts loss; success = zero-conflict; EMA; deep supervision.
+Modes: --mode train (saves ckpt + JSON) / --mode le.
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import GINConv, GCNConv, SAGEConv, GATConv, GPSConv, PNAConv
+from torch_geometric.utils import degree as _pyg_degree
+
+# GPS uses scaled_dot_product_attention; force the MATH kernel so torch.autograd.functional.jvp
+# (LE diagnostic / PTRM rollouts) has a double-backward-able implementation.
+for _f in ('enable_flash_sdp', 'enable_mem_efficient_sdp', 'enable_math_sdp'):
+ try:
+ getattr(torch.backends.cuda, _f)(_f == 'enable_math_sdp')
+ except Exception:
+ pass
+
+OUT = '/home/yurenh2/rrog/runs'
+CACHE = '/home/yurenh2/rrog/data/color_cache'
+
+
+def gen(n, k, p, r, seed):
+ rng = np.random.default_rng(seed)
+ part = rng.integers(0, k, n)
+ src, dst = [], []
+ for i in range(n):
+ for j in range(i + 1, n):
+ if part[i] != part[j] and rng.random() < p:
+ src += [i, j]; dst += [j, i]
+ ei = torch.tensor([src, dst], dtype=torch.long) if src else torch.zeros((2, 0), dtype=torch.long)
+ rf = torch.tensor(rng.standard_normal((n, r)), dtype=torch.float)
+ return {'n': n, 'edge_index': ei, 'rfeat': rf}
+
+
+def make_split(split, n, k, p, r, count, seed0):
+ os.makedirs(CACHE, exist_ok=True)
+ fp = os.path.join(CACHE, f"{split}_n{n}_k{k}_p{p}_r{r}.pt")
+ if os.path.exists(fp):
+ return torch.load(fp, weights_only=False)
+ data = [gen(n, k, p, r, seed0 + i) for i in range(count)]
+ torch.save(data, fp)
+ return data
+
+
+def _adj(edge_index, n):
+ A = np.zeros((n, n), dtype=np.float64)
+ ei = edge_index.numpy()
+ if ei.shape[1]:
+ A[ei[0], ei[1]] = 1.0
+ return np.maximum(A, A.T)
+
+
+def rwse(edge_index, n, K):
+ A = _adj(edge_index, n); deg = A.sum(1)
+ P = A / np.where(deg > 0, deg, 1.0)[:, None]
+ out = np.zeros((n, K), dtype=np.float32); M = np.eye(n)
+ for j in range(K):
+ M = M @ P; out[:, j] = np.diag(M)
+ return torch.from_numpy(out)
+
+
+def gsn_feats(edge_index, n):
+ A = _adj(edge_index, n); deg = A.sum(1)
+ tri = (A @ A @ A).diagonal() / 2.0
+ wedge = deg * (deg - 1) / 2.0
+ return torch.tensor(np.stack([np.log1p(tri), np.log1p(wedge)], axis=1), dtype=torch.float)
+
+
+def sub_feats(edge_index, n):
+ A = _adj(edge_index, n); deg = A.sum(1); A2 = A @ A
+ M = (A2 > 0).astype(np.float64) * (A == 0).astype(np.float64); np.fill_diagonal(M, 0.0)
+ d2 = M.sum(1)
+ tri = (A @ A2).diagonal() / 2.0
+ clus = np.where(deg > 1, tri / (deg * (deg - 1) / 2.0), 0.0)
+ return torch.tensor(np.stack([np.log1p(deg), np.log1p(d2), clus], axis=1), dtype=torch.float)
+
+
+def lappe_feats(edge_index, n, kpe=8):
+ A = _adj(edge_index, n); deg = A.sum(1)
+ di = np.where(deg > 0, 1.0 / np.sqrt(deg), 0.0)
+ L = np.eye(n) - di[:, None] * A * di[None, :]
+ _, V = np.linalg.eigh(L)
+ pe = V[:, 1:kpe + 1]
+ if pe.shape[1] < kpe:
+ pe = np.pad(pe, ((0, 0), (0, kpe - pe.shape[1])))
+ return torch.tensor(pe, dtype=torch.float)
+
+
+def featurize(graphs, pe, rwse_k):
+ def feat(g):
+ ei, n = g['edge_index'], g['n']
+ if pe == 'rwse': return rwse(ei, n, rwse_k)
+ if pe == 'gsn': return gsn_feats(ei, n)
+ if pe == 'sub': return sub_feats(ei, n)
+ if pe == 'lappe': return lappe_feats(ei, n)
+ if pe == 'all': return torch.cat([rwse(ei, n, rwse_k), gsn_feats(ei, n), sub_feats(ei, n)], dim=1)
+ return None
+ for g in graphs:
+ e = feat(g)
+ g['xin'] = torch.cat([g['rfeat'], e], dim=1) if e is not None else g['rfeat']
+ return graphs
+
+
+def deg_hist(graphs):
+ md = 0; ds = []
+ for g in graphs:
+ d = _pyg_degree(g['edge_index'][1], g['n'], dtype=torch.long)
+ md = max(md, int(d.max()) if d.numel() else 0); ds.append(d)
+ h = torch.zeros(md + 1, dtype=torch.long)
+ for d in ds:
+ h += torch.bincount(d, minlength=md + 1)
+ return h
+
+
+def make_conv(conv, hidden, deg=None):
+ if conv == 'gin':
+ return GINConv(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True)
+ if conv == 'gcn':
+ return GCNConv(hidden, hidden)
+ if conv == 'sage':
+ return SAGEConv(hidden, hidden)
+ if conv == 'gat':
+ return GATConv(hidden, hidden, heads=4, concat=False, add_self_loops=True)
+ if conv == 'pna':
+ return PNAConv(hidden, hidden, aggregators=['mean', 'min', 'max', 'std'],
+ scalers=['identity', 'amplification', 'attenuation'], deg=deg, towers=1)
+ if conv == 'gps':
+ local = GINConv(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True)
+ return GPSConv(hidden, local, heads=4)
+ raise ValueError(conv)
+
+
+class RecGINColor(nn.Module):
+ def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full', sigma=0.0, conv='gin', deg=None):
+ super().__init__()
+ self.conv_type = conv
+ self.lin_in = nn.Linear(in_dim, hidden)
+ self.convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(inner)])
+ self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)])
+ self.head = nn.Linear(hidden, k)
+ self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma
+
+ def block(self, z, ei, batch=None):
+ if self.conv_type == 'gps' and batch is None:
+ batch = z.new_zeros(z.size(0), dtype=torch.long)
+ for conv, bn in zip(self.convs, self.bns):
+ z = conv(z, ei, batch) if self.conv_type == 'gps' else conv(z, ei)
+ z = bn(z).relu()
+ return z
+
+ def _inner(self, z, h0, ei, noise, batch):
+ z = self.block(z + h0, ei, batch)
+ if noise and self.sigma > 0:
+ z = z + self.sigma * torch.randn_like(z)
+ return z
+
+ def recurse(self, z, h0, ei, noise, batch, one_step=False):
+ if one_step:
+ with torch.no_grad():
+ for _ in range(self.T - 1):
+ z = self._inner(z, h0, ei, noise, batch)
+ z = z.detach()
+ return self._inner(z, h0, ei, noise, batch)
+ for _ in range(self.T):
+ z = self._inner(z, h0, ei, noise, batch)
+ return z
+
+ def forward(self, xin, ei, batch=None, noise=False):
+ h0 = self.lin_in(xin)
+ z = torch.zeros_like(h0)
+ outs = []
+ for s in range(self.n_sup):
+ z = self.recurse(z, h0, ei, noise, batch, one_step=(self.grad_mode == '1step'))
+ outs.append(self.head(z))
+ z = z.detach()
+ return outs
+
+
+def conflict_loss(logits, ei):
+ p = F.softmax(logits, dim=-1)
+ return (p[ei[0]] * p[ei[1]]).sum(-1).mean()
+
+
+@torch.no_grad()
+def solve_stats(model, recs, dev, sample=None):
+ model.eval()
+ solved = 0; conf = 0.0; tot = 0
+ for r in (recs[:sample] if sample else recs):
+ ei = r['edge_index'].to(dev)
+ col = model(r['xin'].to(dev), ei)[-1].argmax(-1)
+ c = (col[ei[0]] == col[ei[1]]).sum().item() // 2
+ solved += int(c == 0); conf += c; tot += 1
+ return solved / tot, conf / tot
+
+
+def lyap1(model, xin, ei, n_steps, dev, seed=0):
+ g = torch.Generator(device=dev).manual_seed(seed)
+ h0 = model.lin_in(xin).detach()
+ z = torch.zeros_like(h0)
+ v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
+ def step_fn(zz):
+ return model.block(zz + h0, ei)
+ lam = 0.0
+ for _ in range(n_steps):
+ z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v)
+ z = z_next.detach(); nv = Jv.norm()
+ lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach()
+ return lam / n_steps
+
+
+def run_le(model, recs, dev, n_steps, n_graphs=300):
+ try:
+ from sklearn.metrics import roc_auc_score
+ except Exception:
+ roc_auc_score = None
+ model.eval()
+ lams, fails = [], []
+ for i, r in enumerate(recs[:n_graphs]):
+ ei = r['edge_index'].to(dev); xin = r['xin'].to(dev)
+ with torch.no_grad():
+ col = model(xin, ei)[-1].argmax(-1)
+ c = (col[ei[0]] == col[ei[1]]).sum().item()
+ fails.append(int(c > 0))
+ lams.append(lyap1(model, xin, ei, n_steps, dev, seed=i))
+ lams, fails = np.array(lams), np.array(fails)
+ s, f = lams[fails == 0], lams[fails == 1]
+ auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
+ sep = (f.mean() - s.mean()) if len(s) and len(f) else float('nan')
+ print(f"[{model.conv_type}/{model.grad_mode}] LE n={len(lams)} fail={fails.mean():.2f} | "
+ f"SOLVED {s.mean() if len(s) else float('nan'):+.3f} UNSOLVED {f.mean() if len(f) else float('nan'):+.3f}"
+ f" sep={sep:+.3f} AUROC={auc:.3f} mean_lam={lams.mean():+.3f}")
+ return {'n': int(len(lams)), 'fail_rate': float(fails.mean()), 'auroc': float(auc), 'sep': float(sep),
+ 'lam_solved': (float(s.mean()) if len(s) else None),
+ 'lam_unsolved': (float(f.mean()) if len(f) else None), 'mean_lam': float(lams.mean())}
+
+
+def lyap_penalty(model, x, ei, batch, target=-0.5):
+ h0 = model.lin_in(x)
+ with torch.no_grad():
+ zr = model.recurse(torch.zeros_like(h0), h0.detach(), ei, False, batch)
+ v = torch.randn_like(zr); v = v / (v.norm() + 1e-12)
+ _, Jv = torch.autograd.functional.jvp(lambda zz: model.block(zz + h0, ei, batch), zr, v, create_graph=True)
+ return (torch.log(Jv.norm() + 1e-12) - target) ** 2
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--mode', choices=['train', 'le'], default='train')
+ ap.add_argument('--conv', choices=['gin', 'gcn', 'sage', 'gat', 'pna', 'gps'], default='gin')
+ ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+ ap.add_argument('--pe', choices=['none', 'rwse', 'gsn', 'sub', 'lappe', 'all'], default='none')
+ ap.add_argument('--contract', action='store_true')
+ ap.add_argument('--rwse_k', type=int, default=16)
+ ap.add_argument('--ckpt', default=None)
+ ap.add_argument('--n', type=int, default=50); ap.add_argument('--k', type=int, default=3)
+ ap.add_argument('--p', type=float, default=0.2); ap.add_argument('--r', type=int, default=8)
+ ap.add_argument('--hidden', type=int, default=128); ap.add_argument('--T', type=int, default=3)
+ ap.add_argument('--n_sup', type=int, default=3); ap.add_argument('--epochs', type=int, default=150)
+ ap.add_argument('--lr', type=float, default=1e-3); ap.add_argument('--bs', type=int, default=32)
+ ap.add_argument('--seed', type=int, default=0)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ os.makedirs(OUT, exist_ok=True)
+
+ if args.mode == 'le':
+ ck = torch.load(args.ckpt, weights_only=False); c = ck['cfg']
+ te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000),
+ c.get('pe', 'none'), c.get('rwse_k', 16))
+ deg = torch.tensor(c['deg']) if c.get('deg') else None
+ model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
+ grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev)
+ model.load_state_dict(ck['state']); model.eval()
+ res = run_le(model, te, dev, c['n_sup'] * c['T'])
+ base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
+ with open(os.path.join(OUT, f"le_{base}.json"), 'w') as fjs:
+ json.dump({'conv': c.get('conv', 'gin'), 'grad_mode': c['grad_mode'], 'pe': c.get('pe', 'none'),
+ 'contract': c.get('contract', False), 'seed': c.get('seed'), **res}, fjs, indent=2)
+ return
+
+ te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), args.pe, args.rwse_k)
+ tr = featurize(make_split('train', args.n, args.k, args.p, args.r, 2000, 0), args.pe, args.rwse_k)
+ in_dim = tr[0]['xin'].shape[1]
+ data = [Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr]
+ trl = DataLoader(data, batch_size=args.bs, shuffle=True, drop_last=True)
+ deg = deg_hist(tr) if args.conv == 'pna' else None
+ model = RecGINColor(in_dim, args.hidden, args.k, args.T, args.n_sup,
+ grad_mode=args.grad_mode, conv=args.conv, deg=deg).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+
+ ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ t0 = time.time(); best_solve = -1; best = {}; best_state = None
+ for ep in range(args.epochs):
+ model.train()
+ for b in trl:
+ b = b.to(dev); opt.zero_grad()
+ outs = model(b.x, b.edge_index, b.batch, noise=False)
+ loss = sum(conflict_loss(o, b.edge_index) for o in outs) / len(outs)
+ if args.contract:
+ loss = loss + lyap_penalty(model, b.x, b.edge_index, b.batch)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ with torch.no_grad():
+ for kk, v in model.state_dict().items():
+ if torch.is_floating_point(v):
+ ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001)
+ else:
+ ema[kk].copy_(v.detach())
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ backup = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ model.load_state_dict(ema)
+ sr, mc = solve_stats(model, te, dev, sample=300)
+ if sr > best_solve:
+ best_solve = sr; best = {'ep': ep + 1, 'solve_rate': round(sr, 4), 'mean_conflicts': round(mc, 3)}
+ best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
+ model.load_state_dict(backup)
+ print(f"ep{ep+1} solve_rate={sr:.3f} mean_conflicts={mc:.2f}", flush=True)
+
+ sfx = ('_ctr' if args.contract else '')
+ tag = f"color_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}"
+ rep = {'task': 'graph3coloring', 'tag': tag, **vars(args), 'in_dim': in_dim,
+ 'sec': round(time.time() - t0, 1), **best}
+ print(f"[{tag}] best solve_rate={best.get('solve_rate')} @ep{best.get('ep')} ({rep['sec']}s)")
+ with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
+ json.dump(rep, f, indent=2)
+ torch.save({'state': best_state, 'cfg': {'in_dim': in_dim, 'hidden': args.hidden, 'k': args.k,
+ 'T': args.T, 'n_sup': args.n_sup, 'grad_mode': args.grad_mode, 'pe': args.pe,
+ 'rwse_k': args.rwse_k, 'contract': args.contract, 'conv': args.conv, 'seed': args.seed,
+ 'deg': (deg.tolist() if deg is not None else None)}},
+ os.path.join(OUT, f"ckpt_{tag}.pt"))
+ print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/train_cycle.py b/diag/train_cycle.py
new file mode 100644
index 0000000..d2342f3
--- /dev/null
+++ b/diag/train_cycle.py
@@ -0,0 +1,183 @@
+"""GNN-native ring-counting on real molecules (ZINC): regress [#5-cycles, #6-cycles].
+
+k-cycle counts (k>=3) are provably NOT computable by 1-WL/MPNN (Chen et al. 2020) -> a REAL
+H2 ceiling on REAL graphs. Training-based diagnosis (partition instrument is vacuous on
+feature-rich graphs):
+ GIN(L) 1-WL baseline -> should FAIL to count
+ GCN(L) sub-1-WL reference
+ GIN+RNI random feats = NOISE -> PTRM-style crude symmetry break (eval-averaged)
+ GIN+RWSE random-walk return probs-> structured >1-WL positive control
+Reads: GIN high error + RWSE fixes it = real ceiling exists; RNI also fixes = crude noise
+breaks it (bridge cashed); only RWSE = bridge needs STRUCTURED stochasticity (GRAM>PTRM).
+Targets z-scored for training; per-target MAE reported in RAW ring units.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_cycle.py --conv gin --feat none
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import networkx as nx
+from torch_geometric.datasets import ZINC
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from torch_geometric.utils import to_networkx
+from torch_geometric.nn import GINConv, GCNConv, global_add_pool
+
+ROOT = '/home/yurenh2/rrog/data/zinc'
+CACHE = '/home/yurenh2/rrog/data/cycle_cache'
+OUT = '/home/yurenh2/rrog/runs'
+RWSE_K = 16
+
+
+def rwse(edge_index, n, K=RWSE_K):
+ A = np.zeros((n, n), dtype=np.float64)
+ ei = edge_index.numpy()
+ A[ei[0], ei[1]] = 1.0
+ A = np.maximum(A, A.T)
+ deg = A.sum(1)
+ P = A / np.where(deg > 0, deg, 1.0)[:, None]
+ out = np.zeros((n, K), dtype=np.float32)
+ M = np.eye(n)
+ for k in range(K):
+ M = M @ P
+ out[:, k] = np.diag(M)
+ return torch.from_numpy(out)
+
+
+def c56(data):
+ G = to_networkx(data, to_undirected=True)
+ c = {5: 0, 6: 0}
+ for cyc in nx.simple_cycles(G, length_bound=6):
+ L = len(cyc)
+ if L in c:
+ c[L] += 1
+ return [float(c[5]), float(c[6])]
+
+
+def prepare(split):
+ os.makedirs(CACHE, exist_ok=True)
+ fp = os.path.join(CACHE, f"{split}.pt")
+ if os.path.exists(fp):
+ return torch.load(fp, weights_only=False)
+ ds = ZINC(ROOT, subset=True, split=split)
+ out = []
+ for g in ds:
+ out.append({'x': g.x.view(-1).long(), 'edge_index': g.edge_index,
+ 'rwse': rwse(g.edge_index, g.num_nodes),
+ 'y': torch.tensor(c56(g), dtype=torch.float)})
+ torch.save(out, fp)
+ return out
+
+
+class Net(nn.Module):
+ def __init__(self, n_atom, hidden, layers, conv='gin', rni=0, use_rwse=False):
+ super().__init__()
+ self.emb = nn.Embedding(n_atom, hidden)
+ self.rni, self.use_rwse = rni, use_rwse
+ din = hidden + rni + (RWSE_K if use_rwse else 0)
+ self.lin_in = nn.Linear(din, hidden)
+ self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
+ for _ in range(layers):
+ if conv == 'gin':
+ self.convs.append(GINConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True))
+ else:
+ self.convs.append(GCNConv(hidden, hidden))
+ self.bns.append(nn.BatchNorm1d(hidden))
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
+
+ def forward(self, x, edge_index, batch, rwse=None):
+ h = self.emb(x)
+ parts = [h]
+ if self.use_rwse:
+ parts.append(rwse)
+ if self.rni:
+ parts.append(torch.randn(h.size(0), self.rni, device=h.device))
+ h = self.lin_in(torch.cat(parts, dim=1))
+ for conv, bn in zip(self.convs, self.bns):
+ h = bn(conv(h, edge_index)).relu()
+ return self.head(global_add_pool(h, batch))
+
+
+def to_loader(recs, bs, shuffle, drop_last=False):
+ data = [Data(x=r['x'], edge_index=r['edge_index'], rwse=r['rwse'],
+ y=r['y'].view(1, 2), num_nodes=r['x'].numel()) for r in recs]
+ return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+
+
+@torch.no_grad()
+def eval_mae(model, loader, dev, ymu, ysd, samples=1):
+ model.eval(); abs_err = torch.zeros(2); n = 0
+ for b in loader:
+ b = b.to(dev)
+ ps = torch.stack([model(b.x, b.edge_index, b.batch, b.rwse) for _ in range(samples)]).mean(0)
+ pr = ps * ysd.to(dev) + ymu.to(dev) # un-standardize -> raw ring units
+ yr = b.y * ysd.to(dev) + ymu.to(dev)
+ abs_err += (pr - yr).abs().sum(0).cpu(); n += b.num_graphs
+ return (abs_err / n).tolist()
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
+ ap.add_argument('--feat', choices=['none', 'rni', 'rwse'], default='none')
+ ap.add_argument('--layers', type=int, default=5)
+ ap.add_argument('--hidden', type=int, default=128)
+ ap.add_argument('--rni_dim', type=int, default=16)
+ ap.add_argument('--epochs', type=int, default=200)
+ ap.add_argument('--lr', type=float, default=1e-3)
+ ap.add_argument('--bs', type=int, default=128)
+ ap.add_argument('--seed', type=int, default=0)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ os.makedirs(OUT, exist_ok=True)
+
+ tr, va, te = prepare('train'), prepare('val'), prepare('test')
+ n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
+ Ytr = torch.stack([r['y'] for r in tr])
+ ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+ for recs in (tr, va, te):
+ for r in recs:
+ r['y'] = (r['y'] - ymu) / ysd
+
+ rni = args.rni_dim if args.feat == 'rni' else 0
+ use_rwse = args.feat == 'rwse'
+ samples = 8 if rni else 1
+ model = Net(n_atom, args.hidden, args.layers, args.conv, rni, use_rwse).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ lossf = nn.L1Loss()
+ trl = to_loader(tr, args.bs, True, drop_last=True)
+ trl_e, val, tel = to_loader(tr, 256, False), to_loader(va, 256, False), to_loader(te, 256, False)
+
+ t0 = time.time(); best_val = 9e9; best = {}
+ for ep in range(args.epochs):
+ model.train()
+ for b in trl:
+ b = b.to(dev); opt.zero_grad()
+ loss = lossf(model(b.x, b.edge_index, b.batch, b.rwse), b.y)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ vm = eval_mae(model, val, dev, ymu, ysd, samples)
+ if sum(vm) < best_val:
+ best_val = sum(vm)
+ best = {'ep': ep + 1, 'train_mae': eval_mae(model, trl_e, dev, ymu, ysd, samples),
+ 'val_mae': vm, 'test_mae': eval_mae(model, tel, dev, ymu, ysd, samples)}
+ print(f"ep{ep+1} val_mae(c5,c6)={[round(x,3) for x in vm]}", flush=True)
+
+ tag = f"{args.conv}_{args.feat}_L{args.layers}_s{args.seed}"
+ rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args),
+ 'y_std_raw': ysd.tolist(), 'sec': round(time.time() - t0, 1), 'dev': dev, **best}
+ tm = best.get('test_mae'); trm = best.get('train_mae')
+ print(f"[{tag}] train_mae(c5,c6)={[round(x,3) for x in trm]} test_mae={[round(x,3) for x in tm]} "
+ f"(raw rings; std={ [round(x,2) for x in ysd.tolist()] }) @ep{best.get('ep')} ({rep['sec']}s)")
+ with open(os.path.join(OUT, f"cyc_{tag}.json"), 'w') as f:
+ json.dump(rep, f, indent=2)
+ print(" wrote", os.path.join(OUT, f"cyc_{tag}.json"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/train_diag.py b/diag/train_diag.py
new file mode 100644
index 0000000..a9c4ab2
--- /dev/null
+++ b/diag/train_diag.py
@@ -0,0 +1,161 @@
+"""Train a backbone, collect failures, attribute them via the 1-WL instrument.
+
+Node features are CONSTANT (all-ones) so the GIN starts from anonymous nodes -> its
+expressivity ceiling is exactly the anonymous 1-WL partition the instrument computes
+(wl_refine init = all-zero). GIN depth L == L WL rounds. Regression targets are
+standardized (train stats) for stable training; all reported MSEs are in original units.
+Train AND test metrics are reported so non-H2 error can be split into optimization
+(can't even fit train) vs generalization (fits train, fails test).
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_diag.py --task csl --model gin
+"""
+import argparse, json, os, time
+from collections import Counter
+import numpy as np
+import torch
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from diag import wl, datasets as DS, models as M
+
+
+def to_pyg(raw, task, ymu=0.0, ysd=1.0):
+ out = []
+ for d in raw:
+ x = torch.ones(d['n'], 1)
+ ei = torch.tensor(d['edge_index'], dtype=torch.long)
+ if task == 'clf':
+ y = torch.tensor([d['y']], dtype=torch.long)
+ else:
+ y = torch.tensor([[(d['y'] - ymu) / ysd]], dtype=torch.float)
+ out.append(Data(x=x, edge_index=ei, y=y, num_nodes=d['n']))
+ return out
+
+
+def split(n, frac, seed, y=None, stratify=False):
+ rng = np.random.default_rng(seed)
+ idx = np.arange(n)
+ if stratify and y is not None:
+ y = np.asarray(y); test = []
+ for c in np.unique(y):
+ ci = idx[y == c]; rng.shuffle(ci)
+ test += ci[:max(1, int(round(frac * len(ci))))].tolist()
+ test = sorted(set(test)); train = [i for i in idx.tolist() if i not in set(test)]
+ else:
+ rng.shuffle(idx); k = int(frac * n)
+ test = sorted(idx[:k].tolist()); train = sorted(idx[k:].tolist())
+ return train, test
+
+
+@torch.no_grad()
+def predict(model, loader, task, dev):
+ model.eval(); outs = []
+ for b in loader:
+ b = b.to(dev)
+ o = model(b.x, b.edge_index, b.batch)
+ outs.append((o.argmax(1) if task == 'clf' else o.view(-1)).cpu())
+ return torch.cat(outs).numpy()
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--task', choices=['csl', 'tri'], required=True)
+ ap.add_argument('--model', choices=['gin', 'gcn'], default='gin')
+ ap.add_argument('--layers', type=int, default=4)
+ ap.add_argument('--hidden', type=int, default=64)
+ ap.add_argument('--epochs', type=int, default=300)
+ ap.add_argument('--lr', type=float, default=1e-3)
+ ap.add_argument('--seed', type=int, default=0)
+ ap.add_argument('--kind', default='er')
+ ap.add_argument('--out', default='/home/yurenh2/rrog/runs')
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ os.makedirs(args.out, exist_ok=True)
+
+ if args.task == 'csl':
+ raw = DS.build_csl(n_per_class=15, seed=args.seed); task, out_dim = 'clf', 10
+ y = [d['y'] for d in raw]; tr, te = split(len(raw), 0.34, args.seed, y, stratify=True)
+ else:
+ raw = DS.build_triangle_count(n_graphs=800, n_nodes=18, kind=args.kind, deg=3, seed=args.seed)
+ task, out_dim = 'reg', 1; tr, te = split(len(raw), 0.3, args.seed)
+
+ ymu, ysd = 0.0, 1.0
+ if task == 'reg':
+ ytr = np.array([raw[i]['y'] for i in tr], dtype=np.float64)
+ ymu, ysd = float(ytr.mean()), float(ytr.std() + 1e-8)
+
+ pyg = to_pyg(raw, task, ymu, ysd)
+ trl = DataLoader([pyg[i] for i in tr], batch_size=32, shuffle=True, drop_last=True)
+ alll = DataLoader(pyg, batch_size=64)
+
+ Model = M.GIN if args.model == 'gin' else M.GCN
+ model = Model(in_dim=1, hidden=args.hidden, layers=args.layers, out_dim=out_dim).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
+ lossf = torch.nn.CrossEntropyLoss() if task == 'clf' else torch.nn.MSELoss()
+
+ t0 = time.time()
+ for ep in range(args.epochs):
+ model.train()
+ for b in trl:
+ b = b.to(dev); opt.zero_grad()
+ o = model(b.x, b.edge_index, b.batch)
+ loss = lossf(o, b.y)
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+
+ pred = predict(model, alll, task, dev)
+ if task == 'reg':
+ pred = pred * ysd + ymu
+ yv = np.array([d['y'] for d in raw], dtype=(np.float64 if task == 'reg' else np.int64))
+ adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in raw]
+ _, ghist, conv = wl.wl_refine(adjs)
+
+ rep = {'task': args.task, 'model': args.model, 'layers': args.layers, 'seed': args.seed,
+ 'kind': (args.kind if args.task == 'tri' else None),
+ 'n': len(raw), 'n_train': len(tr), 'n_test': len(te), 'conv_round': conv,
+ 'sec': round(time.time() - t0, 1), 'dev': dev}
+
+ if task == 'clf':
+ test_pred, test_y = pred[te], yv[te]
+ acc = float((test_pred == test_y).mean())
+ train_acc = float((pred[tr] == yv[tr]).mean())
+ att = wl.attribute_classification(ghist, conv, args.layers, yv, tr, te)
+ fails = [te[k] for k in range(len(te)) if test_pred[k] != test_y[k]]
+ fb = Counter(att['buckets'][i] for i in fails)
+ rep.update({'train_acc': round(train_acc, 4), 'test_acc': round(acc, 4),
+ 'wl_ceiling_acc_converged': round(att['wl_optimal_acc_converged'], 4),
+ 'wl_ceiling_acc_Ldepth': round(att['wl_optimal_acc_Ldepth'], 4),
+ 'test_bucket_counts': att['counts'],
+ 'failure_bucket_counts': dict(fb), 'n_failures': len(fails)})
+ print(f"[{args.task}/{args.model}] train_acc={train_acc:.3f} test_acc={acc:.3f} | "
+ f"1-WL ceiling(conv)={att['wl_optimal_acc_converged']:.3f} "
+ f"L-depth={att['wl_optimal_acc_Ldepth']:.3f} | failures={len(fails)} -> {dict(fb)}")
+ else:
+ test_pred, test_y = pred[te], yv[te]
+ mse = float(((test_pred - test_y) ** 2).mean())
+ train_mse = float(((pred[tr] - yv[tr]) ** 2).mean())
+ dec = wl.decompose_regression(ghist, conv, args.layers, yv, tr, te)
+ h2 = dec['mse_floor_oracle_H2']
+ rep.update({'train_mse': round(train_mse, 4), 'test_mse_gin': round(mse, 4),
+ 'mse_floor_oracle_H2': round(h2, 4),
+ 'mse_floor_converged_train': round(dec['mse_floor_converged_train'], 4),
+ 'mse_floor_Ldepth_train': round(dec['mse_floor_Ldepth_train'], 4),
+ 'var_target_test': round(dec['var_target_eval'], 4),
+ 'frac_test_unseen_color': round(dec['frac_test_unseen_color'], 4),
+ 'frac_test_singleton_color': round(dec['frac_test_singleton_color'], 4),
+ 'learn_gap_test': round(max(0.0, mse - h2), 4)})
+ print(f"[{args.task}/{args.model}/{args.kind}] train_mse={train_mse:.3f} test_mse={mse:.3f} | "
+ f"1-WL oracle floor(H2)={h2:.3f} | unseen={dec['frac_test_unseen_color']:.2f} "
+ f"singleton={dec['frac_test_singleton_color']:.2f} | learn_gap={max(0.0, mse - h2):.3f} "
+ f"var_y={dec['var_target_eval']:.3f}")
+
+ tag = f"{args.task}_{args.kind}" if args.task == 'tri' else args.task
+ fn = os.path.join(args.out, f"diag_{tag}_{args.model}_L{args.layers}_s{args.seed}.json")
+ with open(fn, 'w') as f:
+ json.dump(rep, f, indent=2)
+ print(" wrote", fn)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/train_real.py b/diag/train_real.py
new file mode 100644
index 0000000..336d86f
--- /dev/null
+++ b/diag/train_real.py
@@ -0,0 +1,139 @@
+"""Training-based failure diagnosis on LRGB Peptides-struct (real, large, long-range).
+
+The WL partition instrument is vacuous here (graphs ~all distinguishable), so we diagnose
+by TRAINING and comparing:
+ GIN(L) : standard 1-WL backbone at depth L
+ GIN(L)+RNI : random node features = noise = beyond-1-WL symmetry breaker
+ GCN(L) : sub-1-WL reference
+Reads: deeper helps -> long-range/under-reaching; RNI helps -> a real >1-WL ceiling that
+noise breaks; train<<test -> generalization; train high -> compute/optimization ceiling.
+Targets z-scored per dim; metric = standardized MAE (lower better). 11 targets.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_real.py --conv gin --layers 5 --rni 0
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+from torch_geometric.datasets import LRGBDataset
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import GINConv, GCNConv, global_mean_pool
+
+ROOT = '/home/yurenh2/rrog/data/lrgb'
+OUT = '/home/yurenh2/rrog/runs'
+
+
+class Net(nn.Module):
+ def __init__(self, col_sizes, hidden, layers, out_dim, conv='gin', rni=0):
+ super().__init__()
+ self.embs = nn.ModuleList([nn.Embedding(int(s), hidden) for s in col_sizes])
+ self.rni = rni
+ self.lin_in = nn.Linear(hidden + rni, hidden)
+ self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
+ for _ in range(layers):
+ if conv == 'gin':
+ mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
+ self.convs.append(GINConv(mlp, train_eps=True))
+ else:
+ self.convs.append(GCNConv(hidden, hidden))
+ self.bns.append(nn.BatchNorm1d(hidden))
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))
+
+ def forward(self, x, edge_index, batch):
+ h = sum(emb(x[:, i]) for i, emb in enumerate(self.embs))
+ if self.rni:
+ h = torch.cat([h, torch.randn(h.size(0), self.rni, device=h.device)], dim=1)
+ h = self.lin_in(h)
+ for conv, bn in zip(self.convs, self.bns):
+ h = bn(conv(h, edge_index)).relu()
+ return self.head(global_mean_pool(h, batch))
+
+
+@torch.no_grad()
+def mae(model, loader, dev, ymu, ysd):
+ model.eval(); se = n = 0.0
+ for b in loader:
+ b = b.to(dev)
+ o = model(b.x, b.edge_index, b.batch)
+ se += (o - b.y).abs().sum().item(); n += b.y.numel()
+ return se / n
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
+ ap.add_argument('--layers', type=int, default=5)
+ ap.add_argument('--hidden', type=int, default=128)
+ ap.add_argument('--rni', type=int, default=0)
+ ap.add_argument('--epochs', type=int, default=150)
+ ap.add_argument('--lr', type=float, default=1e-3)
+ ap.add_argument('--bs', type=int, default=128)
+ ap.add_argument('--seed', type=int, default=0)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ os.makedirs(OUT, exist_ok=True)
+
+ tr = LRGBDataset(root=ROOT, name='Peptides-struct', split='train')
+ va = LRGBDataset(root=ROOT, name='Peptides-struct', split='val')
+ te = LRGBDataset(root=ROOT, name='Peptides-struct', split='test')
+
+ # per-column embedding sizes + target standardization (train stats)
+ col_max = None
+ Ytr = []
+ for g in tr:
+ m = g.x.max(0).values
+ col_max = m if col_max is None else torch.maximum(col_max, m)
+ Ytr.append(g.y.view(-1))
+ for ds in (va, te):
+ for g in ds:
+ col_max = torch.maximum(col_max, g.x.max(0).values)
+ col_sizes = (col_max + 2).tolist()
+ Ytr = torch.stack(Ytr)
+ ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+
+ def norm(ds):
+ out = []
+ for g in ds:
+ g = g.clone(); g.y = (g.y.view(1, -1) - ymu) / ysd
+ out.append(g)
+ return out
+ trl = DataLoader(norm(tr), batch_size=args.bs, shuffle=True, drop_last=True)
+ val = DataLoader(norm(va), batch_size=256)
+ tel = DataLoader(norm(te), batch_size=256)
+ trl_eval = DataLoader(norm(tr), batch_size=256)
+
+ model = Net(col_sizes, args.hidden, args.layers, out_dim=11, conv=args.conv, rni=args.rni).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ lossf = nn.L1Loss()
+
+ t0 = time.time(); best_val = 9e9; best = {}
+ for ep in range(args.epochs):
+ model.train()
+ for b in trl:
+ b = b.to(dev); opt.zero_grad()
+ loss = lossf(model(b.x, b.edge_index, b.batch), b.y)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ sched.step()
+ if (ep + 1) % 15 == 0 or ep == args.epochs - 1:
+ vm = mae(model, val, dev, ymu, ysd)
+ if vm < best_val:
+ best_val = vm
+ best = {'ep': ep + 1, 'train_mae': mae(model, trl_eval, dev, ymu, ysd),
+ 'val_mae': vm, 'test_mae': mae(model, tel, dev, ymu, ysd)}
+ print(f"ep{ep+1} val_mae={vm:.4f}", flush=True)
+
+ tag = f"{args.conv}_L{args.layers}_rni{args.rni}_s{args.seed}"
+ rep = {'dataset': 'Peptides-struct', 'tag': tag, **vars(args),
+ 'sec': round(time.time() - t0, 1), 'dev': dev, **best}
+ print(f"[{tag}] train_mae={best.get('train_mae'):.4f} val_mae={best.get('val_mae'):.4f} "
+ f"test_mae={best.get('test_mae'):.4f} @ep{best.get('ep')} ({rep['sec']}s)")
+ fn = os.path.join(OUT, f"real_{tag}.json")
+ with open(fn, 'w') as f:
+ json.dump(rep, f, indent=2)
+ print(" wrote", fn)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/train_rec.py b/diag/train_rec.py
new file mode 100644
index 0000000..9866f28
--- /dev/null
+++ b/diag/train_rec.py
@@ -0,0 +1,174 @@
+"""Step-2: Recursive (TRM-ish) GNN on ZINC ring-counting + optional PTRM noise/selection.
+
+Recurrent shared-weight GIN block, deep-supervised over n_sup steps (TRM-style: carry latent
+detached between steps). --grad_mode controls the LAST supervision step's recursion:
+ full : backprop through all T inner recursions (TRM)
+ 1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient)
+Optional per-step Gaussian noise (sigma) + K stochastic rollouts selected by a value head
+(best-Q@K) for the PTRM experiments. Saves a checkpoint for the LE diagnostic (diag/lyap.py).
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+from torch_geometric.loader import DataLoader
+from torch_geometric.data import Data
+from torch_geometric.nn import GINConv, global_add_pool
+from diag.train_cycle import prepare
+
+OUT = '/home/yurenh2/rrog/runs'
+
+
+def loader(recs, bs, shuffle, drop_last=False):
+ data = [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2),
+ num_nodes=r['x'].numel()) for r in recs]
+ return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+
+
+class RecGIN(nn.Module):
+ def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, grad_mode='full'):
+ super().__init__()
+ self.emb = nn.Embedding(n_atom, hidden)
+ self.convs = nn.ModuleList([GINConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True)
+ for _ in range(inner)])
+ self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)])
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
+ self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
+ self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode
+
+ def block(self, z, ei):
+ for conv, bn in zip(self.convs, self.bns):
+ z = bn(conv(z, ei)).relu()
+ return z
+
+ def _inner(self, z, h0, ei, noise):
+ z = self.block(z + h0, ei)
+ if noise and self.sigma > 0:
+ z = z + self.sigma * torch.randn_like(z)
+ return z
+
+ def recurse(self, z, h0, ei, noise, one_step=False):
+ if one_step: # HRM 1-step gradient
+ with torch.no_grad():
+ for _ in range(self.T - 1):
+ z = self._inner(z, h0, ei, noise)
+ z = z.detach()
+ return self._inner(z, h0, ei, noise) # only last inner carries grad
+ for _ in range(self.T): # TRM full recursion
+ z = self._inner(z, h0, ei, noise)
+ return z
+
+ def forward(self, x, ei, batch, noise=False):
+ h0 = self.emb(x)
+ z = torch.zeros_like(h0)
+ preds = []
+ for s in range(self.n_sup):
+ if s < self.n_sup - 1:
+ with torch.no_grad():
+ z = self.recurse(z, h0, ei, noise)
+ z = z.detach()
+ else:
+ z = self.recurse(z, h0, ei, noise, one_step=(self.grad_mode == '1step'))
+ preds.append(self.head(global_add_pool(z, batch)))
+ q = self.qhead(global_add_pool(z, batch)).view(-1)
+ return preds, q
+
+
+@torch.no_grad()
+def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'):
+ model.eval()
+ ysd_d, ymu_d = ysd.to(dev), ymu.to(dev)
+ ae = torch.zeros(2); ae_or = torch.zeros(2); n = 0
+ for b in ld:
+ b = b.to(dev)
+ if K == 1:
+ preds, _ = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
+ chosen = oracle = preds[-1]
+ else:
+ P, Q = [], []
+ for _ in range(K):
+ preds, q = model(b.x, b.edge_index, b.batch, noise=True)
+ P.append(preds[-1]); Q.append(q)
+ P = torch.stack(P); Q = torch.stack(Q)
+ ar = torch.arange(P.size(1), device=dev)
+ chosen = P[Q.argmax(0), ar] if select == 'bestq' else P.mean(0)
+ oracle = P[(P - b.y.unsqueeze(0)).abs().sum(-1).argmin(0), ar]
+ ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
+ ae_or += ((oracle * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
+ n += b.num_graphs
+ return (ae / n).tolist(), (ae_or / n).tolist()
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+ ap.add_argument('--sigma', type=float, default=0.0)
+ ap.add_argument('--K', type=int, default=1)
+ ap.add_argument('--select', choices=['none', 'bestq'], default='bestq')
+ ap.add_argument('--T', type=int, default=3)
+ ap.add_argument('--n_sup', type=int, default=3)
+ ap.add_argument('--hidden', type=int, default=128)
+ ap.add_argument('--epochs', type=int, default=200)
+ ap.add_argument('--lr', type=float, default=1e-3)
+ ap.add_argument('--bs', type=int, default=128)
+ ap.add_argument('--lam_q', type=float, default=1.0)
+ ap.add_argument('--seed', type=int, default=0)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ os.makedirs(OUT, exist_ok=True)
+
+ tr, va, te = prepare('train'), prepare('val'), prepare('test')
+ n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
+ Ytr = torch.stack([r['y'] for r in tr]); ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+ for recs in (tr, va, te):
+ for r in recs:
+ r['y'] = (r['y'] - ymu) / ysd
+ trl = loader(tr, args.bs, True, drop_last=True)
+ val, tel = loader(va, 256, False), loader(te, 256, False)
+
+ model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ l1 = nn.L1Loss()
+
+ t0 = time.time(); best_val = 9e9; best = {}; best_state = None
+ for ep in range(args.epochs):
+ model.train()
+ for b in trl:
+ b = b.to(dev); opt.zero_grad()
+ preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
+ loss = sum(l1(p, b.y) for p in preds) / len(preds)
+ with torch.no_grad():
+ tq = -(preds[-1] - b.y).abs().mean(1)
+ loss = loss + args.lam_q * nn.functional.mse_loss(q, tq)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select)
+ if sum(vm) < best_val:
+ best_val = sum(vm)
+ tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select)
+ best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo}
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
+ print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True)
+
+ tag = f"rec_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}_s{args.seed}"
+ rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1),
+ 'dev': dev, 'y_std_raw': ysd.tolist(), **best}
+ print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} "
+ f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)")
+ with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
+ json.dump(rep, f, indent=2)
+ torch.save({'state': best_state or model.state_dict(),
+ 'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup,
+ 'sigma': args.sigma, 'grad_mode': args.grad_mode},
+ 'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt"))
+ print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/wl.py b/diag/wl.py
new file mode 100644
index 0000000..26ab0a3
--- /dev/null
+++ b/diag/wl.py
@@ -0,0 +1,166 @@
+"""1-WL color-refinement instrument for diagnosing GNN failures (H1 vs H2).
+
+A GIN with L layers == L rounds of 1-WL refinement (injective sum aggregation).
+A failure on sample i is attributed by label purity of its WL color classes:
+
+ converged-WL class IMPURE (train labels conflict under same color)
+ -> H2 : 1-WL ceiling. No MPNN at ANY depth separates -> needs >1-WL (noise).
+ converged pure, but L-round class impure
+ -> H1a_depth : separable only with MORE rounds -> deterministic RR-on-graph / depth helps.
+ L-round class pure (info present at depth L) but model wrong
+ -> H1b_opt : optimization / capacity. Train better.
+
+Refinement is dataset-global (shared per-round signature->label map) so node colors and
+graph-color histograms are comparable across graphs.
+"""
+from collections import Counter, defaultdict
+import numpy as np
+
+
+def edges_to_adj(n, edge_index):
+ adj = [[] for _ in range(n)]
+ ei = np.asarray(edge_index)
+ for a, b in zip(ei[0].tolist(), ei[1].tolist()):
+ adj[a].append(b)
+ return adj
+
+
+def wl_refine(adjs, inits=None, max_rounds=None):
+ """Dataset-level 1-WL. Returns (node_rounds, ghist_rounds, conv_round).
+ node_rounds[r][g] = int color array (global labels) of graph g after r rounds.
+ ghist_rounds[r][g] = canonical color histogram (hashable) of graph g after r rounds.
+ conv_round = round index at which the global partition stabilized.
+ """
+ if inits is None:
+ inits = [np.zeros(len(a), dtype=np.int64) for a in adjs]
+ else:
+ inits = [np.asarray(x, dtype=np.int64) for x in inits]
+ if max_rounds is None:
+ max_rounds = max((len(a) for a in adjs), default=0) + 2
+
+ d = {}
+ def lab(s):
+ v = d.get(s)
+ if v is None:
+ v = len(d); d[s] = v
+ return v
+
+ cur = [np.array([lab(('i', int(c))) for c in init], dtype=np.int64) for init in inits]
+ node_rounds = [cur]
+ nclasses = [len(d)]
+
+ for _r in range(max_rounds):
+ d = {}
+ nxt = []
+ for adj in adjs:
+ c = cur_g = node_rounds[-1][len(nxt)]
+ arr = np.empty(len(adj), dtype=np.int64)
+ for v in range(len(adj)):
+ sig = (int(c[v]), tuple(sorted(int(c[u]) for u in adj[v])))
+ arr[v] = lab(sig)
+ nxt.append(arr)
+ node_rounds.append(nxt)
+ nclasses.append(len(d))
+ if nclasses[-1] == nclasses[-2]: # global #classes stopped growing -> converged
+ break
+
+ conv_round = len(node_rounds) - 1
+ ghist_rounds = [[_hist(c) for c in nr] for nr in node_rounds]
+ return node_rounds, ghist_rounds, conv_round
+
+
+def _hist(colors):
+ return tuple(sorted(Counter(colors.tolist()).items()))
+
+
+def graph_colors_at(ghist_rounds, conv_round, L):
+ return ghist_rounds[min(L, conv_round)]
+
+
+# ---------- classification attribution ----------
+def attribute_classification(ghist_rounds, conv_round, L, y, train_idx, eval_idx):
+ y = np.asarray(y)
+ conv = ghist_rounds[conv_round]
+ Lr = min(L, conv_round)
+ lr = ghist_rounds[Lr]
+ conv_train, lr_train = defaultdict(list), defaultdict(list)
+ for i in train_idx:
+ conv_train[conv[i]].append(int(y[i]))
+ lr_train[lr[i]].append(int(y[i]))
+
+ def pure(dct, key):
+ labs = dct.get(key)
+ return labs is not None and len(set(labs)) == 1
+
+ def majority(dct, key):
+ labs = dct.get(key)
+ return Counter(labs).most_common(1)[0][0] if labs else None
+
+ buckets = {}
+ wl_opt = lr_opt = 0
+ for i in eval_idx:
+ if conv[i] not in conv_train:
+ buckets[i] = 'novel'
+ elif not pure(conv_train, conv[i]):
+ buckets[i] = 'H2'
+ elif not pure(lr_train, lr[i]):
+ buckets[i] = 'H1a_depth'
+ else:
+ buckets[i] = 'H1b_opt'
+ if majority(conv_train, conv[i]) == int(y[i]):
+ wl_opt += 1
+ if majority(lr_train, lr[i]) == int(y[i]):
+ lr_opt += 1
+ n = len(eval_idx)
+ return {
+ 'buckets': buckets,
+ 'counts': dict(Counter(buckets.values())),
+ 'wl_optimal_acc_converged': wl_opt / n, # best ANY MPNN can do
+ 'wl_optimal_acc_Ldepth': lr_opt / n, # best L-layer MPNN can do
+ 'L_used': Lr, 'conv_round': conv_round,
+ }
+
+
+# ---------- regression decomposition ----------
+def decompose_regression(ghist_rounds, conv_round, L, y, train_idx, eval_idx):
+ """H2 floor = ORACLE within-color variance on FULL data (best possible function of the WL
+ color: do same-color graphs share the target?). This is the true information ceiling and is
+ NOT confounded by train/test coverage. The train-fitted floors are also reported to expose
+ how much apparent error is really novel-color generalization, plus coverage fractions."""
+ y = np.asarray(y, dtype=np.float64)
+ conv = ghist_rounds[conv_round]
+ Lr = min(L, conv_round)
+ lr = ghist_rounds[Lr]
+ full_idx = list(range(len(y)))
+
+ # oracle: best constant per converged color over ALL data -> irreducible by any MPNN
+ conv_mean_full = _group_mean(conv, y, full_idx)
+ e_oracle = np.array([conv_mean_full[conv[i]] - y[i] for i in eval_idx])
+
+ # train-fitted (achievable with this split); fallback to global mean on unseen colors
+ conv_mean_tr = _group_mean(conv, y, train_idx)
+ lr_mean_tr = _group_mean(lr, y, train_idx)
+ gmean = float(y[list(train_idx)].mean())
+ e_conv_tr = np.array([conv_mean_tr.get(conv[i], gmean) - y[i] for i in eval_idx])
+ e_lr_tr = np.array([lr_mean_tr.get(lr[i], gmean) - y[i] for i in eval_idx])
+
+ conv_count = Counter(conv[i] for i in full_idx)
+ train_colors = set(conv[i] for i in train_idx)
+ frac_unseen = float(np.mean([conv[i] not in train_colors for i in eval_idx]))
+ frac_singleton = float(np.mean([conv_count[conv[i]] == 1 for i in eval_idx]))
+ return {
+ 'mse_floor_oracle_H2': float((e_oracle ** 2).mean()), # TRUE 1-WL ceiling
+ 'mse_floor_converged_train': float((e_conv_tr ** 2).mean()),
+ 'mse_floor_Ldepth_train': float((e_lr_tr ** 2).mean()),
+ 'frac_test_unseen_color': frac_unseen,
+ 'frac_test_singleton_color': frac_singleton,
+ 'L_used': Lr, 'conv_round': conv_round,
+ 'var_target_eval': float(y[list(eval_idx)].var()),
+ }
+
+
+def _group_mean(colors, y, idx):
+ acc = defaultdict(list)
+ for i in idx:
+ acc[colors[i]].append(float(y[i]))
+ return {k: float(np.mean(v)) for k, v in acc.items()}
diff --git a/papers/gram_2605.19376.pdf b/papers/gram_2605.19376.pdf
new file mode 100644
index 0000000..b6637ef
--- /dev/null
+++ b/papers/gram_2605.19376.pdf
Binary files differ
diff --git a/papers/hrm_2506.21734.pdf b/papers/hrm_2506.21734.pdf
new file mode 100644
index 0000000..9c87a62
--- /dev/null
+++ b/papers/hrm_2506.21734.pdf
Binary files differ
diff --git a/papers/ptrm_2605.19943.pdf b/papers/ptrm_2605.19943.pdf
new file mode 100644
index 0000000..ce01e6c
--- /dev/null
+++ b/papers/ptrm_2605.19943.pdf
Binary files differ
diff --git a/papers/ptrm_2605.19943.txt b/papers/ptrm_2605.19943.txt
new file mode 100644
index 0000000..49213a0
--- /dev/null
+++ b/papers/ptrm_2605.19943.txt
@@ -0,0 +1,1418 @@
+Probabilistic Tiny Recursive Model
+
+Ali Parviz
+Mila – Quebec AI Institute
+
+Alexia Jolicoeur-Martineau
+Independent
+
+{amin.sghaier, ali.parviz}@mila.quebec
+alexia.jolicoeur-martineau@mail.mcgill.ca
+
+Abstract
+Tiny Recursive Models (TRM) solve complex reasoning tasks with a fraction of
+the parameters of modern large language models (LLMs) by iteratively refining a
+latent state and final answer. While powerful, their deterministic recursion can lead
+to convergence at suboptimal solutions, without escape mechanism. A common
+workaround relies on task-specific input perturbations at test time combined with
+answer aggregation via voting. We introduce Probabilistic TRM (PTRM), a taskagnostic framework for test-time compute scaling that addresses this limitation
+through stochastic exploration. PTRM injects Gaussian noise at each deep recursion
+step, enabling parallel trajectories to explore diverse solution basins, and selects
+among them using the model’s existing Q head (used for early stopping in the
+original TRM). Without requiring retraining or task-specific augmentations, PTRM
+enables substantial accuracy gains across benchmarks, including Sudoku-Extreme
+(87.4% to 98.75%) and on various puzzles from Pencil Puzzle Bench (62.6% to
+91.2%). On the latter, PTRM achieves nearly double the accuracy of frontier LLMs
+(91.2% vs. 55.1%) at less than 0.0001x the cost, using only 7M parameters.
+
+PPBench Puzzles
+
+sudoku, lightup, nurikabe, heyawake, and tapa
+
+91.2
+
+80
+55.1
+34.7
+
+Direct prediction
+Deterministic recursive prediction
+
+0
+
+0
+
+PTRM (ours)
+
+0
+
+TRM
+
+PTRM (ours)
+
+TRM
+
+LLM ensemble
+
+claude-opus-4-6
+
+Chain-of-thought, pretrained
+LLM ensemble
+
+0
+
+HRM
+
+20
+
+Direct pred
+
+40
+0
+
+gemini-3.1-pro
+
+Direct pred
+
+0
+
+24.5
+
+98.75
+
+55
+
+60
+
+2
+gpt-5.2@xhigh
+
+20
+
+24.5
+
+80
+
+62.6
+
+o3-mini-high
+
+40
+
+87.4
+
+Claude 3.7 8K
+
+60
+
+Sudoku-Extreme
+100
+
+Deepseek R1
+
+100
+
+Accuracy (%)
+
+arXiv:2605.19943v1 [cs.AI] 19 May 2026
+
+Amin Sghaier
+Mila – Quebec AI Institute
+ILLS & ETS Montreal
+
+Probabilistic recursive prediction (ours)
+
+Best of 7 strongest LLMs. Assumes access to a perfect verifier.
+
+Figure 1: PTRM performance comparison. On various PPBench puzzles, PTRM boosts TRM
+performance by 28.6 points without any retraining. It outperforms the strongest single frontier LLMs
+by 56.5 points and an ensemble of the seven strongest LLMs (assuming a perfect verifier) by 36
+points. On Sudoku-Extreme, PTRM reaches a state of the art 98.75%.
+
+ 1
+
+Introduction
+
+Tiny Recursive Models (TRM) [1] achieve strong performance on complex reasoning puzzles with
+orders of magnitude fewer parameters than the large language models (LLMs) they outperform on
+tasks like Sudoku-Extreme [2] and ARC-AGI [3, 4]. TRM and its predecessor Hierarchical Reasoning
+Model (HRM) [2] represent an emerging architectural alternative to standard autoregressive reasoning
+models. Rather than autoregressively generating chains of token-level reasoning, they recursively
+refine a latent state. This approach produces a single deterministic answer per input, fitting well with
+tasks where the answer is unique.
+Despite their strong performance, their deterministic inference does not make full use of their
+capabilities. We show that many of TRM’s incorrect answers are from rollouts trapped in bad latent
+space basins (i.e., regions of the latent space which decode to incorrect answers and from which the
+deterministic recursions cannot escape). This observation, which aligns with recent mechanistic work
+on related models [5], suggests that TRM has the capabilities to solve significantly more problems
+but is limited by its standard inference procedure.
+Although each puzzle has a unique correct answer, many distinct latent trajectories can reach it. This
+is analogous to reasoning LLMs, where many reasoning trajectories can lead to the same unique
+answer. However, being non-deterministic, LLMs can be randomly sampled in order to form different
+trajectories (including Chains of Thought and actual answer). By then selecting a trajectory using
+a voting mechanism or based on the answer’s projected value (via a verifier), LLMs can leverage
+test-time compute to achieve very high accuracy [6]. We propose a way to achieve similar test-time
+scaling performance gains by sampling stochastic latent trajectories, each producing a deterministic
+decoded answer, and selecting among the answers using the model’s own Q head.
+TRM’s Q head is trained jointly (as a correctness classifier) with the rest of the network and is
+conventionally used only at training time for adaptive computation (ACT) [7]. It carries valuable
+information that the standard inference procedure discards.
+We propose Probabilistic TRM (PTRM), a test-time compute scaling framework that introduces a
+new width scaling axis. At inference we run K parallel rollouts per puzzle, each receiving Gaussian
+noise injected into the latent at every deep recursion step. The noise causes rollouts to follow different
+latent trajectories and settle in different basins. Among the resulting candidate answers, the Q head
+is used to select the one most likely to be correct. PTRM requires no training changes and no
+task-specific test-time augmentation, yet, as illustrated in Figure 1, delivers substantial accuracy
+gains across diverse reasoning benchmarks.
+
+2
+
+Background: Tiny Recursive Model
+
+Tiny Recursive Model (TRM) is a single network that iteratively refines a predicted answer y to a
+question x through recursive updates of a reasoning latent z. Specifically, a single latent recursion
+consists of n updates to the latent state z followed by one update to the predicted answer y, all using
+the same two-layer network fθ : z ← fθ (x + y + z) n times, then y ← fθ (y + z).
+fθ distinguishes the two update types by whether the input includes x. A deep recursion runs T
+latent recursions in sequence, with only the final one retaining gradients, allowing the model to
+leverage a large effective depth while keeping training efficient.
+Rather than doing one optimization step per sample, TRM is trained via deep supervision, which
+consists in keeping the previous latent state z and answer y as initialization (after being detached from
+the computational graph) for the next supervision step. This is done for up to Nsup supervision steps.
+The loss at each step is calculated using cross entropy between the predicted answer logits fO (y)
+(where fO is a linear output head) and the ground truth ytrue . This trains the network to progressively
+refine its prediction across reasoning steps. At inference, the recurrence can be unrolled for more
+steps than during training, providing a depth axis for test-time compute scaling (additional steps may
+correct otherwise-incorrect answers).
+Without halting mechanism during training, each puzzle stays in the mini-batch for Nsup supervision
+steps rather than being replaced after each one. To avoid wasting compute on already-solved samples,
+an Adaptive Computational Time (ACT) halting mechanism is used. This is done by adding a binary
+cross entropy loss between a halting logit q̂ = fQ (y) (where fQ is a linear Q head) and the binary exact
+2
+
+ Correct answer
+
+Incorrect answer
+
+PC 1 (58% var)
+1.0
+
+5.0
+0.9 2.5
+0.0
+0.8 2.5
+5.0
+0.7
+
+2.5
+0.0
+2.5
+5
+
+10
+
+Supervision step
+
+15
+
+0.9
+
+1.0
+
+5
+
+Cell accuracy
+
+Q value
+
+PC 1 (85% var)
+
+1.0
+
+5.0
+
+0
+
+Failure
+
+End
+
+PC 2 (8% var)
+
+PC 2 (36% var)
+PC 1 (84% var)
+
+5.0
+
+Start
+
+Delayed success
+
+PC 2 (15% var)
+
+Quick success
+
+Cell accuracy
+
+0.8
+
+0
+
+0.8
+0.7
+0
+
+5
+
+10
+
+Supervision step
+
+15
+
+0.6
+
+5
+0
+
+5
+
+10
+
+Supervision step
+
+15
+
+Figure 2: TRM Trajectory Modes. PCA projection of y (top) and Q value (solid, left axis) with cell
+accuracy (dashed, right axis) across supervision steps (bottom) for three PPBench puzzles, illustrating
+three trajectory modes (left to right): quick success, delayed success, and failure (Sec. 3). Latents are
+projected into the principal plane per puzzle, so PC axes are not comparable across plots. Trajectories
+fade from light (early steps) to dark (later steps). Circle marks the start and square marks end.
+
+correctness of the predicted answer ŷ = arg max fO (y): Lstep = CE(fO (y), ytrue ) + BCE(q̂, 1[ŷ =
+ytrue ]). The Q head thus allows the supervision loop to halt early on samples where sigmoid(q̂) > 0.5,
+improving data efficiency. During inference, the Q head is not used, and the model performs Nsup
+supervision steps to maximize answer correctness.
+While TRM is powerful, it sometimes gets stuck into incorrect solutions. In the next section, we will
+investigate such failures cases in order to determine a way to remedy them.
+
+3
+
+Problem: When Does TRM Fail?
+
+3.1
+
+Analysis of failures and successes
+
+We present observations about TRM that motivate our method. In this section, we train a TRM on
+multiple Pencil Puzzle Bench (PPBench) [8] puzzles and inspect the latent dynamics and Q head
+behavior across supervision steps on a held-out validation set. For each puzzle, we record the latent
+yt and the Q logit q̂t = fQ (yt ) at every supervision step t = 1, . . . , Nsup , project the latents into
+the principal plane (PCA per puzzle), and jointly plot the Q value alongside cell accuracy (fraction
+of correct cells in the predicted answer) over supervision steps. Figure 2 shows paired PCA and
+Q/cell-accuracy plots for three representative puzzles, illustrating three trajectory modes we observe:
+Quick success: the trajectory transitions in a few steps from its starting location to a convergence
+region and remains there. Cell accuracy and the Q value rise together and saturate near their maxima
+within the same few steps.
+Delayed success: the trajectory initially oscillates around one region and remains there for multiple
+supervision steps before sharply escaping to a different region where it converges. During the initial
+3
+
+ phase, the Q value is negative, and at the step where the trajectory escapes, both Q value and cell
+accuracy spike together.
+Failure: the trajectory oscillates in a bounded region without converging. Cell accuracy never reaches
+near 100%, and the Q value stays negative for all supervision steps.
+We refer to latent space regions that trajectories remain in across multiple supervision steps and
+exhibit similar cell accuracy throughout as basins. Basins where cell accuracy is near-maximal are
+good basins and basins where it is not are bad basins. Initially, failures and delayed successes behave
+similarly (both are caught in bad basins with negative Q). They diverge only later in their trajectories,
+when delayed successes find an escape to a good basin while failures remain stuck.
+3.2
+
+The Q head tracks trajectory quality
+6
+4
+
+Cell accuracy
+
+Q value
+
+2
+
+1.00
+0.95
+0.90
+0.85
+Incorrect (28)
+Correct (69)
+0.80
+Cell accuracy (right axis)
+0.75
+0.70
+0.65
+0.60
+10
+12
+14
+
+0
+2
+4
+6
+0
+
+2
+
+4
+
+6
+8
+Supervision step
+
+Figure 3: Q value follows cell accuracy across reasoning. Mean
+Q value (solid, left axis) and mean
+cell accuracy (dashed, right axis)
+over supervision steps, aggregated
+over 100 PPBench validation puzzles, separated by final correctness
+(green: correct, red: incorrect).
+
+Across all three modes (failures, delayed successes, and quick successes), we find that the Q head’s
+value closely tracks cell accuracy at every supervision step. To further confirm this, Figure 3
+aggregates trajectories from 100 PPBench validation puzzles, separating them by final-answer
+correctness. The aggregate view corroborates the per-puzzle observation: mean Q and mean cell
+accuracy rise together on correct trajectories and remain mostly flat on incorrect ones. Moreover, at
+convergence, the Q logit sharply separates the two populations where q̂ ≈ +6 (sigmoid ≈ 1) for
+correct trajectories and q̂ ≈ −6 (sigmoid ≈ 0) for incorrect ones. The Q head is therefore a reliable
+learned indicator of whether a trajectory has reached a good basin.
+Given that the Q head’s ability to distinguish good from bad trajectories, a natural question follows:
+can we leverage the Q head to identify better trajectories? The main challenge is that the standard
+TRM is inherently deterministic, and thus cannot be used to sample different trajectories for a given
+problem. In the next section, we will show that by simply adding Gaussian noise to the latent state,
+we can sample different parallel trajectories and leverage the Q head to pick the best one.
+
+4
+
+Method: Test-Time Compute Scaling via Stochastic Rollouts
+
+We propose Probabilistic TRM (PTRM), an inference-time procedure that makes the TRM recursion
+stochastic and selects the best of K resulting trajectories. PTRM requires no special training and
+can be readily applied to any pretrained TRM model. Furthermore it requires no task-specific
+augmentations. PTRM works as follows: at each supervision step, we add Gaussian noise (scaled by
+σ) to the latent state input. The Q head fQ scores each candidate latent output, and the one with the
+highest Q value is selected and then decoded using the model’s output head fO . The algorithm in
+Figure 4 (left) states this formally. PTRM offers two complementary benefits: 1) it enables trajectories
+to escape bad basins where deterministic TRM remains stuck, and 2) it introduces width as a new
+axis for test-time scaling.
+4.1
+
+Escaping bad basins
+
+In Sec. 3, we found that some failed deterministic trajectories are caught in bad solution basins in
+latent space, with no way to escape. PTRM lets us test whether stochastic perturbations are enough
+for some of the rollouts of a previously failed puzzle to reach a good solution basin. Figure 5 shows
+K=100 independent rollouts, from the same failed puzzle used in Figure 2 (which fails at K=1),
+4
+
+ PTRM Inference
+
+(a) Standard TRM (deterministic)
+
+1: Input: puzzle x, rollouts K,
+2: supervision steps D, noise scale σ
+3: for k = 1, . . . , K in parallel do
+
+answer
+
+depth axis: D deep recursion steps
+
+(k)
+
+(b) PTRM (ours): K stochastic rollouts + Q-head selection
++ϵ
+
+width axis: K rollouts
+
+(k)
+
+Initialize z0 , y0
+for t = 1, . . . , D do
+(k)
+zt−1 += ϵ, ϵ ∼ N (0, σ 2 I)
+(k)
+(k)
+(k)
+(k)
+7:
+zt , yt ← rec(x, zt−1 , yt−1 )
+8:
+end for
+(k)
+9:
+ŷ (k) ← arg max fO (yD )
+(k)
+(k)
+10:
+q̂ ← fQ (yD )
+11: end for ∗
+12: return ŷ (k ) , k ∗ = arg maxk q̂ (k)
+4:
+5:
+6:
+
+···
+
+puzzle
+
+k=
+
+puzzle
+
+1
+
++ϵ
+
++ϵ
+
+···
++ϵ
+
++ϵ
+
++ϵ
+
+···
+k=2
+k=
+K
+
+
++ϵ
+
+
+
++ϵ
+
++ϵ
+
+···
+deep recursion step
+
++ϵ
+
+arg maxk Qk
+
+final answer
+
+Gaussian noise injection
+
+Figure 4: Left: PTRM inference procedure (the rec() function refers to a deep recursion step). Right:
+PTRM mechanism. (a) Standard TRM: a single deterministic rollout. (b) PTRM: K stochastic latent
+rollouts with Gaussian noise ϵ at each deep recursion step, with the Q head selecting the final answer.
+projected into the principal plane. Most rollouts (92%) remain stuck in the same bad basin, while
+a minority (8%) escape to a distinct region in latent space and produce correct answers. We also
+observe that recurrent noise creates a per-rollout probability of escape: at K = 5 no rollouts escape,
+at K = 25 one does, and at K = 100 eight do. This confirms that noise provides the stochasticity
+needed to occasionally find an escape trajectory.
+4.2
+
+Width scaling
+
+Since more rollouts per puzzle compound the chance that at least one reaches a good basin, the
+number of rollouts K is a natural quantity to scale. Given K independent rollouts, pass@K (any
+rollout correct) is the oracle upper bound and best-Q@K (the rollout with highest q̂ is correct) is a
+metric available at inference without a correctness oracle. The choice of Q as selector is motivated by
+Sec. 3’s observation that Q accurately separates correct from incorrect trajectories (Figure 3).
+Figure 6 shows pass@K and best-Q@K as K grows, averaged over 3 seeds on the held-out PPBench
+validation set (sudoku, nurikabe, tapa, lightup, and heyawake). Both metrics rise from 76.4% at
+K = 1 to 89.5% at K = 100, a gain of 13 percentage points. Across all tested K, the gap between
+pass@K and best-Q@K stays under 1pp, making the Q head a strong verifier on this validation set.
+By contrast, mode@K (most frequent answer across rollouts) rises by only 1.3pp over the same
+range, showing that the width-scaling gains come mostly from the Q head’s ability to identify correct
+solutions even when they are rare.
+Interaction with depth scaling. Depth is another scaling axis already supported by TRM, which
+consists of running more deep recursions (supervision steps) at inference than the Nsup the model
+was trained on. On the deterministic baseline (K=1), tripling the depth from 16 to 48 steps raises
+PPBench validation accuracy from 76.4% to 79.5% (+3.1pp). At higher K, depth scaling only
+provides additional gains on specific puzzle types such as sudoku (+4pp at K = 100). Both depth
+and width scaling can be seen as ways to explore the model’s solution space. Since rollouts are
+independent and parallelizable while extra depth is sequential, width is the more practical scaling
+axis.
+PTRM unlocks a simple and task-agnostic recipe for scaling TRM test-time compute. The next
+section evaluates the method across multiple benchmarks and against several baselines, including
+frontier LLMs.
+
+5
+
+Experiments
+
+This section evaluates PTRM’s performance on diverse reasoning benchmarks. We compare against
+the deterministic TRM baseline, a non-recursive direct-prediction baseline, and frontier LLMs.
+Across several PPBench puzzles [8], Sudoku-Extreme [2], Maze-Hard [2], and ARC-AGI 2 [4],
+PTRM substantially boosts the performance of each pretrained TRM using only inference compute.
+5
+
+ Correct (8)
+Incorrect (92)
+Start
+End
+
+10
+8
+
+92.5
+
+PPBench accuracy (%)
+
+PC 2 (34% var)
+
+6
+4
+2
+0
+
+85.0
+82.5
+80.0
+77.5
+72.5
+
+2.5
+
+0.0
+
+2.5
+5.0
+7.5
+PC 1 (53% var)
+
+10.0
+
+12.5
+
+Figure 5: Stochastic rollouts escape bad
+basins. Principal plane projection of K =
+100 independent rollouts of the same failed
+puzzle as in Figure 2 (right). 92 rollouts
+remain caught in the bad basin (red). 8
+escape to a good basin and produce correct
+answers (green).
+
+5.1
+
+87.5
+
+75.0
+
+2
+4
+
+pass@K
+best-Q@K
+mode@K
+
+90.0
+
+1
+
+5
+10
+25
+Rollouts per puzzle K (log scale)
+
+100
+
+Figure 6: Width scaling. pass@K, best-Q@K,
+and mode@K as K grows, averaged over 3
+seeds on a held-out PPBench validation set. The
+Q head is a strong verifier on the tested puzzles,
+consistently outperforming selection of the most
+frequent answer.
+
+Setup
+
+Datasets. Pencil Puzzle Bench (PPBench) [8] consists of 62,231 constraint-satisfaction pencil puzzles
+(from 94 puzzle types). From the full PPBench dataset, 300 puzzles (15 puzzles from 20 types)
+selected by Waugh [8] are held out to form the golden set. From the remainder we hold out a
+fixed-size validation set of 100 puzzles per puzzle type (50 for tapa, due to its smaller base size),
+and the rest forms the training set. We filter all three sets to puzzles of six types (sudoku, lightup,
+nurikabe, shakashaka, heyawake, and tapa) of grid size 9×9 for sudoku, and 10×10 for the rest.
+We use the validation set to track performance during training and select the final checkpoint. We
+report per-puzzle accuracy on five of these types on the golden set (TRM already reaches 100% on
+shakashaka, so we omit it from the reported results), with aggregate scores sample-weighted across
+types. We also report results on the Sudoku-Extreme, Maze-Hard, and ARC-AGI 2 datasets.
+Models and inference. For each benchmark we use a standard TRM checkpoint. For SudokuExtreme we use the TRM-MLP variant (which the TRM paper showed to be stronger on Sudoku),
+and for the other datasets, we use TRM-Att. PTRM inference uses K parallel rollouts each running
+D supervision steps with Gaussian noise of scale σ added to the latent state at each supervision step.
+The selected configuration (K, D, σ) varies by benchmark and is given alongside each result. Metrics
+are averaged across three seeds.
+Baselines. To isolate the contribution of PTRM’s stochastic rollouts from the underlying backbone,
+we report standard TRM performance (the same checkpoint as PTRM ran deterministically). For
+each dataset, we report the performance of frontier LLMs. For Sudoku-Extreme, Maze-Hard, and
+ARC2 we additionally report the published direct prediction and TRM baselines from [1].
+Cost estimation. PPBench provides the dollar cost per attempt for each LLM. We convert PTRM’s
+wall-clock to a comparable dollar figure using a single H100 at $2.50/hr (standard cloud pricing [9])
+so that cost = $2.50 · tpuzzle /3600, where tpuzzle is the time (in seconds) to complete a puzzle.
+5.2
+5.2.1
+
+Pencil Puzzle Bench
+Per-puzzle accuracy
+
+Table 1 reports per-puzzle accuracy on the PPBench golden set. PTRM at K=100, D=48, σ=0.2
+raises aggregate best-Q@K from 62.6% to 91.2%. Increasing supervision depth alone (K=1, D=48)
+gives a small boost over the standard TRM baseline (K=1, D=16). Most of the gain comes
+from scaling width (stochastic rollouts). The largest improvements are on puzzle types where
+6
+
+ the deterministic baseline performed the worst (most headroom): sudoku improves from 46.7% to
+97.8% and tapa from 40.0% to 80.0%.
+% accuracy
+Direct prediction
+TRM (K=1, D=16)
+TRM (K=1, D=48)
+PTRM, best-Q@K (K=100, D=16)
+PTRM, best-Q@K (K=100, D=48)
+
+# Params sudoku lightup nurikabe heyawake
+27M
+7M
+7M
+7M
+7M
+
+0.0
+46.7
+57.8
+93.3
+97.8
+
+0.0
+87.5
+87.5
+100
+100
+
+0.0
+74.1
+74.1
+88.9
+88.9
+
+14.3
+85.7
+85.7
+85.7
+85.7
+
+tapa
+
+agg.
+
+0.0
+2.0
+40.0 62.6
+40.0 66.0
+80.0 89.8
+80.0 91.2
+
+Table 1: PPBench per-puzzle accuracy on the golden set. PTRM uses the same backbone as
+the deterministic TRM. Scaling depth alone (K=1, D=48) lifts aggregate accuracy by 3.4 points
+over the standard D=16 baseline. Combining depth with K=100 stochastic (σ=0.2) rollouts raises
+accuracy by 28.6 percentage points overall. The direct-prediction baseline is a larger transformer
+trained on the same data.
+
+5.2.2
+
+Comparison with frontier LLMs on golden set
+
+PPBench reported per-puzzle results for several frontier LLMs using two strategies: 1) direct response
+from a single prompt, and 2) multi-turn agentic strategy with verification. We report results for direct
+and any (best of any strategy attempted, including agentic). The agentic strategy gives the LLM
+substantially more resources than PTRM has access to. It provides the LLM the ability to iteratively
+verify each move with a perfect verifier. The direct strategy is the fairer comparison since, while
+it may use the model provider’s reasoning harness, it does not have direct access to a multi-turn
+verifier (the LLM could still self-verify by writing verification code within the same response). We
+additionally observe that the agentic strategy was applied selectively in the published PPBench data:
+across the LLMs we compare against, only 9.6% of direct failures on the golden set were retried
+with agentic. We restrict the comparison to the 7 strongest LLMs that attempted every puzzle in our
+golden set: claude-opus-4-6@thinking, gpt-5.2@xhigh, gemini-3.1-pro, gpt-5.2@high,
+claude-sonnet-4-6@thinking, gpt-5.2@medium, and kimi-k2.5. Table 2 lists the top 3 in
+each strategy block.
+We additionally report an ensemble score formed from these 7 LLMs where a puzzle counts as solved
+if at least one of them solved it via any strategy. This ensemble setup is deliberately stacked against
+PTRM. It assumes a perfect verifier since, if any of the 7 LLMs produced a correct answer under
+any strategy, the ensemble counts it as solved, even though in practice we would not have access
+to an oracle verifier. Although it is not deployable, we include the ensemble to demonstrate that
+even under these heavily favorable conditions, frontier LLMs fall well short of PTRM. Ensemble
+cost-per-attempt averages over the attempts of all 7 models on each puzzle, and cost-per-correct
+divides total cost by the number of puzzles the ensemble solved.
+Table 2 reports the comparison. PTRM exceeds the strongest single LLM (direct strategy) by 57
+points aggregate (91.2% vs. 34.7%), and exceeds the LLM ensemble by 36 points (91.2% vs. 55.1%)
+despite the ensemble’s stacked advantages. Cost per attempt is several orders of magnitude higher for
+LLMs than PTRM.
+5.3
+
+Sudoku-Extreme, Maze-Hard, and ARC-AGI-2
+
+For each benchmark we use the standard TRM checkpoint trained as described in [1] without
+modification (TRM-MLP for Sudoku-Extreme and TRM-Att for Maze-Hard and ARC-AGI-2).
+Table 3 summarizes results on all three.
+On Sudoku-Extreme, PTRM at K=100, D=64, σ=0.3 raises the deterministic baseline of 87.3% to
+99.06% pass@K and 98.75% best-Q@K, achieving state of the art.
+On Maze-Hard, PTRM at K=100, D=16, σ=1.0 reaches 95.63% pass@K, an 11.83 point gain
+over the 83.8% deterministic baseline. mode@K gives the best PTRM accuracy here at 86.73%
+(+2.93 points), with best-Q@K slightly behind at 85.17% (+1.37 points). While pass@K shows
+that PTRM is able to unlock several correct answers, the Q head identifies them less reliably than on
+the previous benchmarks.
+7
+
+ % accuracy
+
+tapa
+
+agg.
+
+$/att.
+
+$/corr.
+
+30.0
+50.0
+60.0
+
+24.5
+24.5
+34.7
+
+$0.40
+$1.79
+$2.91
+
+$1.62
+$7.29
+$8.40
+
+0.0
+0.0
+0.0
+
+40.0
+60.0
+70.0
+
+30.6
+34.7
+36.7
+
+$10.38
+$3.09
+$4.38
+
+$33.91
+$8.90
+$11.92
+
+0.0
+
+80.0
+
+55.1
+
+$2.66
+
+$38.51
+
+sudoku lightup nurikabe heyawake
+Direct
+
+gemini-3.1-pro
+gpt-5.2@xhigh
+claude-opus-4-6@thinking
+
+6.7
+20.0
+0.0
+
+75.0
+50.0
+87.5
+
+22.2
+0.0
+44.4
+
+0.0
+0.0
+0.0
+
+Any strategy (direct or agentic)†
+gemini-3.1-pro
+gpt-5.2@xhigh
+claude-opus-4-6@thinking
+
+6.7
+33.3
+0.0
+
+87.5
+75.0
+87.5
+
+33.3
+0.0
+44.4
+
+LLM ensemble†
+Any strategy (direct or agentic)
+
+46.7
+
+100
+
+44.4
+
+Ours, trained from scratch, 7M parameters
+PTRM, best-Q@K
+
+97.8
+
+100
+
+88.9
+
+85.7
+
+80.0 91.2 $0.001 $0.001
+
+Table 2: PTRM vs. frontier LLMs on PPBench golden. Per-puzzle accuracy and per-attempt /
+per-correct cost on the golden set. LLM costs are from PPBench. PTRM cost is estimated from H100
+wall-clock (Sec. 5.1). The direct and agentic blocks list the 3 highest scoring LLMs on aggregate,
+and the ensemble row uses all 7 listed in Sec. 5.2.2. † Assumes access to a perfect verifier.
+
+On ARC-AGI-2, the standard inference pipeline applies data augmentations and votes across them.
+PTRM adds K stochastic rollouts per augmentation. For selection, we pick the rollout with the
+highest Q value within each augmentation, then vote across augmentations as in the standard pipeline.
+With K=25 and σ=0.2, PTRM lifts pass@1 from 7.36% to 8.47% and pass@100 from 14.31% to
+15.97% over our deterministic TRM baseline, while matching it at pass@2.
+
+Sudoku-Extreme Maze-Hard
+ARC-AGI-2
+Acc. (%)
+Acc. (%) pass@1 pass@2 pass@100
+
+Method
+
+# Params
+
+HRM
+TRM
+
+27M
+5M / 7M†
+
+55.0
+87.4
+
+74.5
+85.3
+
+–
+–
+
+5.0
+7.8
+
+–
+–
+
+Ours
+Standard TRM, our reproduction 5M / 7M†
+PTRM
+5M / 7M†
+
+87.28
+98.75
+
+83.80
+86.73
+
+7.36
+8.47
+
+9.72
+9.72
+
+14.31
+15.97
+
+Table 3: Sudoku-Extreme, Maze-Hard, and ARC-AGI-2 results. For Sudoku-Extreme, K=100,
+D=64, σ=0.3. For Maze-Hard, K=100, D=16, σ=1.0. For ARC-AGI-2, K=25, D=16, σ=0.2.
+pass@k for ARC-AGI-2 reports the top-k predictions from the augmentation-voting pipeline. PTRM
+shows an accuracy improvement over standard TRM across all 3 benchmarks. † Following [1], 5M
+for Sudoku-Extreme (TRM-MLP), 7M for Maze-Hard and ARC-AGI-2 (TRM-Att).
+
+5.4
+
+Q head selection as σ grows
+
+With a higher σ value, PTRM finds many correct solutions that the deterministic inference misses.
+For instance, on Maze-Hard, the deterministic model solves 83.8% of puzzles, but PTRM raises
+pass@K to nearly 96%. The extent to which PTRM helps depends on the task, but on every dataset
+we tested, it unlocks correct solutions well beyond the deterministic model’s reach.
+TRM’s jointly trained Q head serves as a strong verifier on most tasks. On PPBench and SudokuExtreme, best-Q@K reaches values within a point of the saturated pass@K, so PTRM’s exploration
+translates directly into accuracy gains. On Maze-Hard, more exploration (higher σ) produces
+significantly more correct rollouts, but the existing Q head is not able to identify them, leaving
+performance on the table. The gap between best-Q@K and pass@K represents headroom for a
+stronger verifier which is left for future work. Appendix B reports the full σ sweep.
+8
+
+ 6
+
+Related Work
+
+A long line of work explores recursive computation for iterative reasoning and representation refinement. Early examples include Universal Transformers [10], Mixture-of-Recursions [11], Deep
+Thinking models [12, 13, 14], and HRM [2], all of which investigate the use of repeated computation
+steps to improve reasoning performance. More recent work has introduced methods to substantially
+accelerate TRM training [15], while TRM-style recursive architectures have also been extended to
+language modeling tasks [16].
+Building on this broader perspective of recursive computation, a growing body of work studies
+latent-space reasoning through the reuse of hidden states. Hao et al. [17] propose continuous
+“thinking tokens” derived from Chain-of-Thought (CoT) traces [18], which are autoregressively
+generated and appended to the model context, enabling reasoning directly in latent space without
+producing intermediate textual outputs. Similarly, Zhu et al. [19] formalize learning by superposition
+and demonstrate improvements on tasks such as graph reachability. By avoiding explicit token
+sampling and implicitly representing multiple reasoning trajectories, these approaches may mitigate
+the unfaithfulness and backtracking often observed in standard autoregressive reasoning [20, 21].
+Related to our work, Baek et al. [22] propose a generative version of TRM where the hidden state
+z is sampled instead of deterministic. This improves performance on multiple tasks, but requires
+retraining. Efstathiou and Balwani [23] (concurrent work) propose a similar test-time compute
+method where they only apply noise in the initial hidden state z, while we apply noise at every
+supervision step. Furthermore, they test their method on a small subset of the Sudoku-Extreme
+dataset, and treat it as a proof-of-concept that needs to be developed and tested further. Note that
+Baek et al. [22] also tested applying noise to the initial z with TRM and obtained negative results (no
+improvement in accuracy on two datasets).
+Our observations in Sec. 3 are consistent with the mechanistic analysis of Ren and Liu [5], who
+identify spurious fixed points in HRM’s latent dynamics on Sudoku-Extreme. Their method mitigates
+these attractors through a combination of task-specific training data augmentation, inference-time
+input perturbations, and model bootstrapping across training checkpoints, thereby effectively increasing test-time compute. However, these interventions are comparatively less general and less
+computationally efficient. In contrast, we observe analogous basin structure in TRM across multiple
+puzzle types and achieve attractor escape using a substantially simpler, task-agnostic mechanism:
+injecting Gaussian noise into the latent state at each supervision step while using a single deterministic
+checkpoint.
+
+7
+
+Conclusion
+
+In this work, we introduced Probabilistic TRM (PTRM), a novel test-time scaling paradigm for
+Tiny Recursive Models (TRM) through parallel exploration and selection. This approach scales
+test-time compute using width (K parallel rollouts), yielding substantially larger gains than depth
+scaling (increasing deep recursion steps) alone. PTRM requires no retraining and does not rely on
+task-specific data augmentations making it extremely easy to use and versatile.
+By scaling both width and depth, PTRM obtains significant gains in accuracy when tested on a wide
+selection of puzzles. On PPBench (Sudoku, Lightup, Nurikabe, Heyawake, Tapa puzzles), PTRM
+nearly obtains twice the accuracy (91.2%; $0.001 cost) of ensemble of SOTA LLMs (55.1%; $38.51
+cost) at less than 0.0001x the cost. Furthermore, PTRM improves accuracy on Sudoku (from 87.4%
+to 98.75%), Maze-Hard (from 83.80% to 86.73%), and ARC-AGI (from 7.8% to 8.47% pass@1).
+Limitations. Our experiments focus on reasoning puzzles rather than general tasks. We only test
+on a subset of PPBench puzzles. We are limited to puzzles with a small grid-size due to limited
+computational resources. It is not guaranteed that the method works as well for all types of problems
+(e.g., accuracy gains on ARC-AGI-2 and Heyawake are smaller).
+Future work. It would be interesting to understand why some puzzles benefit from test-time scaling
+more than others. We suspect that problems that are harder to verify (e.g., ARC-AGI-2) benefit less
+from PTRM because the Q head may struggle to distinguish correct solutions from incorrect ones.
+Developing stronger verifiers than the existing Q head is an interesting direction for future work.
+9
+
+ References
+[1] Alexia Jolicoeur-Martineau. Less is more: Recursive reasoning with tiny networks. arXiv
+preprint arXiv:2510.04871, 2025.
+[2] Guan Wang, Jin Li, Yuhao Sun, Xing Chen, Changling Liu, Yue Wu, Meng Lu, Sen Song, and
+Yasin Abbasi Yadkori. Hierarchical reasoning model. arXiv preprint arXiv:2506.21734, 2025.
+[3] François Chollet. On the measure of intelligence. arXiv preprint arXiv:1911.01547, 2019.
+[4] Francois Chollet, Mike Knoop, Gregory Kamradt, Bryan Landers, and Henry Pinkard. Arcagi-2: A new challenge for frontier ai reasoning systems. arXiv preprint arXiv:2505.11831,
+2025.
+[5] Zirui Ren and Ziming Liu. Are your reasoning models reasoning or guessing? a mechanistic
+analysis of hierarchical reasoning models. arXiv preprint arXiv:2601.10679, 2026.
+[6] Charlie Snell, Jaehoon Lee, Kelvin Xu, and Aviral Kumar. Scaling llm test-time compute optimally can be more effective than scaling model parameters. arXiv preprint arXiv:2408.03314,
+2024.
+[7] Alex Graves. Adaptive computation time for recurrent neural networks. arXiv preprint
+arXiv:1603.08983, 2016.
+[8] Justin Waugh. Pencil puzzle bench: A benchmark for multi-step verifiable reasoning. arXiv
+preprint arXiv:2603.02119, 2026.
+[9] Vast.ai. Rent h100 pcie gpus on vast.ai. https://vast.ai/pricing/gpu/H100-PCIE, 2026.
+Accessed: 2026-05-01.
+[10] Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Łukasz Kaiser. Universal transformers. arXiv preprint arXiv:1807.03819, 2018.
+[11] Sangmin Bae, Yujin Kim, Reza Bayat, Sungnyun Kim, Jiyoun Ha, Tal Schuster, Adam Fisch,
+Hrayr Harutyunyan, Ziwei Ji, Aaron Courville, et al. Mixture-of-recursions: Learning dynamic
+recursive depths for adaptive token-level computation. arXiv preprint arXiv:2507.10524, 2025.
+[12] Avi Schwarzschild, Eitan Borgnia, Arjun Gupta, Furong Huang, Uzi Vishkin, Micah Goldblum,
+and Tom Goldstein. Can you learn an algorithm? generalizing from easy to hard problems with
+recurrent networks. Advances in Neural Information Processing Systems, 34:6695–6706, 2021.
+[13] Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang, Micah Goldblum,
+and Tom Goldstein. End-to-end algorithm synthesis with recurrent networks: Extrapolation
+without overthinking. Advances in Neural Information Processing Systems, 35:20232–20242,
+2022.
+[14] Jay Bear, Adam Prugel-Bennett, and Jonathon Hare. Rethinking deep thinking: Stable learning
+of algorithms using lipschitz constraints. Advances in Neural Information Processing Systems,
+37:97027–97052, 2024.
+[15] Navid Hakimi. Form follows function: Recursive stem model. arXiv preprint arXiv:2603.15641,
+2026.
+[16] Yinxi Li, Jiaao Chen, Fang Wu, Jiakai Yu, Heli Qi, Weihao Xuan, Haokai Zhao, Pengyu Nie,
+Di Jin, and Xiangru Tang. Learning multi-step reasoning via persistent latent state propagation.
+In Workshop on Latent {\&} Implicit Thinking {\textendash} Going Beyond CoT Reasoning,
+2026.
+[17] Shibo Hao, Sainbayar Sukhbaatar, DiJia Su, Xian Li, Zhiting Hu, Jason Weston, and Yuandong
+Tian. Training large language models to reason in a continuous latent space. arXiv preprint
+arXiv:2412.06769, 2024.
+[18] Jason Wei, Xuezhi Wang, Dale Schuurmans, Maarten Bosma, Fei Xia, Ed Chi, Quoc V Le,
+Denny Zhou, et al. Chain-of-thought prompting elicits reasoning in large language models.
+Advances in neural information processing systems, 35:24824–24837, 2022.
+10
+
+ [19] Hanlin Zhu, Shibo Hao, Zhiting Hu, Jiantao Jiao, Stuart Russell, and Yuandong Tian. Reasoning
+by superposition: A theoretical perspective on chain of continuous thought. arXiv preprint
+arXiv:2505.12514, 2025.
+[20] Tamera Lanham, Anna Chen, Ansh Radhakrishnan, Benoit Steiner, Carson Denison, Danny
+Hernandez, Dustin Li, Esin Durmus, Evan Hubinger, Jackson Kernion, et al. Measuring
+faithfulness in chain-of-thought reasoning. arXiv preprint arXiv:2307.13702, 2023.
+[21] Yanda Chen, Joe Benton, Ansh Radhakrishnan, Jonathan Uesato, Carson Denison, John Schulman, Arushi Somani, Peter Hase, Misha Wagner, Fabien Roger, et al. Reasoning models don’t
+always say what they think. arXiv preprint arXiv:2505.05410, 2025.
+[22] Junyeob Baek, Mingyu Jo, Minsu Kim, Yoshua Bengio, and Sungjin Ahn. Generative recursive
+reasoning models. ICLR 2026 Workshop on AI with Recursive Self-Improvement, 2026.
+[23] Andreas Efstathiou and Aishwarya Balwani. Recursive reasoning as attractor landscape search:
+Mechanistic dynamics of the tiny recursive model. Workshop on Latent & Implicit Thinking – Going Beyond CoT Reasoning, 2026. URL https://openreview.net/forum?id=
+kKps9W1K7n.
+
+11
+
+ A
+
+Implementation Details
+
+A.1
+
+Compute
+
+We train and evaluate all models on a single NVIDIA H100 80GB GPU. PTRM introduces no
+additional training cost over standard TRM since it operates entirely at inference time.
+A.2
+
+Models
+
+All experiments use the standard TRM backbone [1] with the released architecture and training recipes.
+Following the TRM paper, we use the MLP variant (TRM-MLP, 5M parameters) for Sudoku-Extreme
+and the attention variant (TRM-Att, 7M parameters) for Maze-Hard, ARC-AGI-2, and PPBench.
+Layout and hyperparameters are unchanged from TRM.
+A.3
+
+PPBench dataset construction
+
+Sudoku-Extreme, Maze-Hard, and ARC-AGI-2 use the same checkpoints and data splits as TRM.
+The PPBench dataset is more recent and has previously been used only with frontier LLMs, so we
+detail how we built our training, validation, and golden splits.
+Source. PPBench contains 62,231 constraint-satisfaction pencil puzzles spanning 94 puzzle types.
+Of these, 300 puzzles (15 puzzles × 20 types) are held out as the golden benchmark set by Waugh [8].
+Filtering. From the remaining 61,931 puzzles we hold out a validation set by sampling 100 puzzles
+from each puzzle type (50 for tapa, due to its smaller base size), and the rest forms the training
+set. We then filter all three sets (training, validation, golden) to retain only puzzles of six types
+(sudoku, lightup, nurikabe, shakashaka, heyawake, tapa) at fixed grid sizes: 9×9 for sudoku
+and 10×10 for the others. Sudoku grids are padded with a pad token to 10×10, giving a uniform
+sequence length of seq_len = 100 across all six puzzle types. The deterministic TRM baseline
+reaches 100% accuracy on shakashaka, so we exclude it from per-puzzle accuracy reporting (no
+headroom to compare against PTRM).
+Augmentation. Each training puzzle is expanded into 10 examples using two augmentations: 1)
+trajectory sampling, where the input is set to a random intermediate solve state along the puzzle’s
+solution trajectory rather than always the empty initial grid, while the label is always the fully solved
+grid; and 2) dihedral transformation, where a random dihedral transformation of a square grid, among
+the 8 possibilities given by 4 rotations × 2 {identity, reflection}, is applied to both the input and the
+label. For each puzzle, the first example is the unaugmented (initial state, solved) pair. The remaining
+9 are randomly sampled (trajectory and dihedral transform). Validation and golden splits are not
+augmented.
+Resulting splits. The merged multi-type splits use a unified vocabulary of 294 tokens and seq_len =
+100. Per-type sample counts are reported in Table 4.
+puzzle type
+
+train
+
+val
+
+golden
+
+sudoku
+lightup
+nurikabe
+heyawake
+tapa
+shakashaka∗
+
+7,810
+9,504
+15,180
+42,108
+3,663
+20,702
+
+97
+65
+55
+70
+26
+62
+
+15
+8
+9
+7
+10
+12
+
+total
+
+98,967
+
+375
+
+61
+
+Table 4: Per-puzzle-type sample counts in the PPBench splits used in training and evaluation.
+∗
+Shakashaka is included in training but excluded from per-puzzle accuracy reporting because deterministic TRM already solves all evaluated shakashaka puzzles.
+
+12
+
+ B
+
+Noise Ablation
+
+We ablate the inference noise level σ on three benchmarks at K=25 (K=100 for Maze-Hard) and
+D=16 to keep the sweep tractable. For Sudoku-Extreme we randomly sample 1000 puzzles from the
+test set for the same reason. Figure 7 shows pass@K, best-Q@K, and mode@K as a function of σ,
+averaged over three random seeds.
+pass@K
+
+Sudoku-Extreme
+
+100
+
+mode@K
+
+K = 1 baseline
+
+Maze-Hard
+
+ARC-AGI-2 (within-aug)
+5.5
+
+96
+
+90
+
+accuracy (%)
+
+best-Q@K
+
+94
+
+80
+
+5.0
+
+92
+
+70
+
+90
+
+60
+
+88
+
+50
+
+86
+
+40
+
+84
+
+30
+0.0
+
+0.2
+
+0.4
+
+0.6
+
+0.8
+
+1.0
+
+82
+0.0
+
+4.5
+4.0
+3.5
+0.2
+
+0.4
+
+0.6
+
+0.8
+
+1.0
+
+0.0
+
+0.2
+
+0.4
+
+0.6
+
+0.8
+
+1.0
+
+Figure 7: pass@K, best-Q@K, and mode@K across σ per rollout batch. On every task,
+increasing the inference noise consistently produces more correct rollouts (pass@K, blue) up to
+a task-dependent σ value. The Q head (best-Q@K, orange) tracks the pass@K ceiling closely
+on Sudoku-Extreme and leaves a larger gap on Maze-Hard and ARC-AGI-2. The shaded region
+represents the verifier headroom (accuracy that a better verifier could extract). mode@K (green) has
+the edge over the Q head only on Maze-Hard. For ARC-AGI-2, metrics are per puzzle/augmentation
+to isolate the Q head’s verification abilities from the augmentation pipeline.
+On Maze-Hard pass@K climbs from 83.8% (deterministic) to nearly 96% by σ≈1.0 and then
+plateaus. On Sudoku-Extreme it is already near its ceiling at σ=0.1 and stays roughly flat across the
+sweep. On ARC-AGI-2 it peaks near σ=0.6 before declining. Q head selection nearly matches the
+ceiling (maximum pass@K) on Sudoku-Extreme while best-Q@K peaks at 98.5% (within a point of
+pass@K’s peak of 99.3%). On the other hand, the gap between best-Q@K and maximum pass@K
+is more pronounced on Maze-Hard and ARC-AGI-2 (headroom a stronger verifier could close).
+
+C
+
+Q-guided Langevin sampling
+
+We initially explored Langevin sampling (using the Q head gradient) as a more principled exploration
+mechanism than the Gaussian noise injection used in PTRM. The idea is to better guide the stochastic
+search by additionally steering each rollout (using the Q head gradient) toward regions of high Q
+value. We ultimately found that the gain from this approach was entirely attributable to the Langevin
+noise term, with the gradient component contributing nothing measurable on top of the equivalent
+recurrent noise of Sec. 4. We document the approach here as a negative result.
+Motivation. The Q head is trained as a correctness predictor over latent states. Let fQ (z) denote
+the head’s scalar output. We treated E(z) = − log sigmoid(fQ (z)) as an energy function over latent
+space. Empirical observations during early experiments suggested that regions of low E correspond
+to good basins from which the decoded answer is likely correct. PCA visualizations of the latent
+dynamics showed that ∇z fQ points toward the good-basin region from both good-basin (correct) and
+bad-basin (incorrect) latents (Figure 8). This made ∇z fQ look like a valuable direction along which
+to push latents.
+Method. We sample from the target distribution p(z) ∝ e−E(z) = sigmoid(fQ (z)) via Langevin
+dynamics where at the end of each deep recursion step t = 1, . . . , D we apply N Langevin steps to
+the latent,
+p
+z ← z − η ∇z E(z) + 2η ξ, ξ ∼ N (0, I),
+The number of Langevin steps N is the additional scaling axis under this scheme.
+13
+
+ t=0
+
+t=5
+
+t = 10
+
+t = 15
+
+Correct (21)
+Incorrect (4)
+Q
+
+Figure 8: y latents and their ∇z fQ gradients projected into the principal plane at several recursive/supervision steps, for multiple rollouts (using recurrent noise) of a single puzzle (correct rollouts
+in green, incorrect in red). Arrows are drawn at each latent in the direction of ∇z fQ . From both
+good-basin and bad-basin latents, gradients point toward the good-basin region. This visualization
+motivated the Langevin sampling experiment described below.
+Tractable gradient computation. TRM’s original Q head is a linear projection on a single token,
+fQ (y) = w⊤ y[:, 0]+b, so its gradient with respect to this head’s input is a constant vector independent
+of z. For ∇z fQ to be input-dependent, the gradient must flow back through the last latent recursion.
+This works but requires backpropagating through a full latent recursion at every Langevin step, which
+scales poorly with N . To make guidance tractable for large N , we replaced the linear Q head with
+an attention-pooled variant that reads the full latent and produces a scalar through a small nonlinear
+network. With this head, ∇z fQ can be computed by backpropagating through the head alone, which
+is ∼8× faster per step and does not sacrifice accuracy.
+The gain came from the noise,
+√ not the gradient. Comparing Langevin sampling against a noiseonly ablation (with the same 2η ξ, but with the −η ∇z E(z) term zeroed out) produced essentially
+identical accuracy at matched N . The gradient component contributed nothing measurable on
+top of the equivalent recurrent noise. This prompted us to focus on the noise-only formulation in
+Sec. 4, which is much more impactful since it is: 1) significantly simpler (no retraining, no test-time
+backpropagation), 2) applicable to any TRM checkpoint out of the box, and 3) equally effective.
+
+D
+
+Per-puzzle accuracy on the PPBench validation set
+
+The main paper reports per-puzzle accuracy on the PPBench golden set (Table 1) for direct comparability with the LLM evaluations from Waugh [8] who used that set. For a lower-variance complement,
+Table 5 reports results on our validation set (313 puzzles across the five reported types vs. 49 for
+golden). Trends match the golden-set results: depth scaling alone (K=1, D=48) provides a small lift,
+and combining depth with stochastic rollouts (K=100, D=48, σ=0.2) raises aggregate best-Q@K
+from 76.4% to 90.4%, a 14.0 percentage-point improvement. The biggest gains again are on puzzles
+where the deterministic baseline has the most headroom (tapa ∼ 40% to 71.8%, sudoku ∼ 69%
+to 93.3%). Types where the baseline is already near ceiling (heyawake at 96.7%) increase only
+marginally.
+% accuracy
+Direct prediction
+TRM (K=1, D=16)
+TRM (K=1, D=48)
+PTRM, best-Q@K (K=100, D=48)
+
+# Params sudoku lightup nurikabe heyawake
+27M
+7M
+7M
+7M
+
+0.0
+68.7
+74.0
+93.3
+
+10.0
+83.3
+84.0
+93.3
+
+4.0
+76.0
+76.7
+84.7
+
+14.0
+96.7
+98.0
+100
+
+tapa
+
+agg.
+
+0.0
+6.2
+39.7 76.4
+41.0 78.3
+71.8 90.4
+
+Table 5: PPBench per-puzzle accuracy on the validation set. PTRM uses the same backbone as the
+deterministic TRM. Results on the larger validation set follow the same trends as on the golden set.
+
+14
+
+ \ No newline at end of file
diff --git a/papers/trm_2510.04871.pdf b/papers/trm_2510.04871.pdf
new file mode 100644
index 0000000..4307790
--- /dev/null
+++ b/papers/trm_2510.04871.pdf
Binary files differ