summaryrefslogtreecommitdiff
path: root/diag
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-17 11:19:27 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-17 11:19:27 -0500
commitd12722525fc010a3910b5152c72654a2ade5eac4 (patch)
treeadbfd19bef487e1c959856d9bf7f881709684a7e /diag
Initial import
Diffstat (limited to 'diag')
-rw-r--r--diag/__init__.py0
-rw-r--r--diag/aggregate.py54
-rw-r--r--diag/cin_color.py144
-rw-r--r--diag/datasets.py70
-rw-r--r--diag/esan_color.py145
-rw-r--r--diag/ff_color.py75
-rw-r--r--diag/lyap.py83
-rw-r--r--diag/models.py42
-rw-r--r--diag/peptides_depth.py66
-rw-r--r--diag/ppgn_color.py209
-rw-r--r--diag/ptrm_color.py85
-rw-r--r--diag/run_archA.sh16
-rw-r--r--diag/run_archB.sh21
-rw-r--r--diag/run_cin.sh8
-rw-r--r--diag/run_color.sh15
-rw-r--r--diag/run_cycle.sh13
-rw-r--r--diag/run_diag.sh16
-rw-r--r--diag/run_esan.sh8
-rw-r--r--diag/run_ff.sh8
-rw-r--r--diag/run_le.sh12
-rw-r--r--diag/run_pe.sh24
-rw-r--r--diag/run_pe2.sh18
-rw-r--r--diag/run_pe3.sh23
-rw-r--r--diag/run_pna.sh13
-rw-r--r--diag/run_ppgn.sh8
-rw-r--r--diag/run_real.sh12
-rw-r--r--diag/run_rec.sh13
-rw-r--r--diag/run_seeds.sh22
-rw-r--r--diag/selftest_wl.py53
-rw-r--r--diag/train_color.py347
-rw-r--r--diag/train_cycle.py183
-rw-r--r--diag/train_diag.py161
-rw-r--r--diag/train_real.py139
-rw-r--r--diag/train_rec.py174
-rw-r--r--diag/wl.py166
35 files changed, 2446 insertions, 0 deletions
diff --git a/diag/__init__.py b/diag/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/diag/__init__.py
diff --git a/diag/aggregate.py b/diag/aggregate.py
new file mode 100644
index 0000000..b0f737a
--- /dev/null
+++ b/diag/aggregate.py
@@ -0,0 +1,54 @@
+"""Aggregate multi-seed coloring results -> mean+/-std per (grad_mode, pe, contract)."""
+import glob, json
+import numpy as np
+from collections import defaultdict
+
+R = '/home/yurenh2/rrog/runs'
+
+
def ms(xs):
    """Format the mean and population std of *xs* as ``"mean±std (n=count)"``.

    ``None`` entries (metric missing from a run's JSON) and NaN entries (a
    metric that was written but undefined for that run) are dropped before
    averaging, so a single undefined run does not turn the whole cell into
    ``nan±nan``. Returns ``"—"`` when no usable value remains.
    """
    vals = np.array([x for x in xs if x is not None], dtype=float)
    vals = vals[~np.isnan(vals)]
    if not len(vals):
        return "—"
    return f"{vals.mean():.3f}±{vals.std():.3f} (n={len(vals)})"
+
+
def load(pat):
    """Return the parsed contents of every JSON file matching glob pattern *pat*.

    Best-effort: a file that cannot be opened or does not hold valid JSON is
    skipped rather than aborting the aggregation. Files are visited in sorted
    path order so the result does not depend on directory listing order.
    """
    out = []
    for path in sorted(glob.glob(pat)):
        try:
            # Context manager closes the handle even when parsing fails.
            with open(path) as fh:
                out.append(json.load(fh))
        except (OSError, ValueError):
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            continue
    return out
+
+
def key(d):
    """Return the grouping tuple ``(conv, pe, grad_mode, contract_tag)`` for result dict *d*.

    ``conv`` defaults to ``'gin'`` when absent; ``contract_tag`` is ``'ctr'`` when
    ``d['contract']`` is truthy and ``'-'`` otherwise.
    """
    contract_tag = '-'
    if d.get('contract'):
        contract_tag = 'ctr'
    conv = d.get('conv', 'gin')
    return (conv, d.get('pe'), d.get('grad_mode'), contract_tag)
+
+
# Per-configuration accumulators; every key is the tuple returned by key(d).
solve, le, ml = defaultdict(list), defaultdict(list), defaultdict(list)
pk, ls, au, det = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)

# Deterministic solve rate: only runs that recorded both a solve_rate and a pe setting.
for d in load(f"{R}/color_*.json"):
    if d.get('pe') is None or 'solve_rate' not in d:
        continue
    solve[key(d)].append(d['solve_rate'])

# Lyapunov diagnostic: AUROC and mean lambda1 of each run.
for d in load(f"{R}/le_color_*.json"):
    le[key(d)].append(d.get('auroc'))
    ml[key(d)].append(d.get('mean_lam'))

# PTRM: the deterministic rate is always recorded; the sigma=0.2 entry only when present.
for d in load(f"{R}/ptrm_color_*.json"):
    k = key(d)
    det[k].append(d.get('det'))
    s2 = d.get('sigmas', {}).get('0.2')
    if not s2:
        continue
    pk[k].append(s2.get('passk'))
    ls[k].append(s2.get('lamsel'))
    au[k].append(s2.get('auroc'))

# The first three sections share one layout: a header, then one line per configuration.
for title, table in (
    ("=== best solve_rate (deterministic, EMA) ===", solve),
    ("=== LE AUROC(fail|lambda1) ===", le),
    ("=== LE mean_lambda1 (forced-contraction dose) ===", ml),
):
    print(title)
    for k in sorted(table, key=str):
        print(f" {k}: {ms(table[k])}")

print("=== PTRM sigma=0.2 ===")
for k in sorted(pk, key=str):
    print(f" {k}: det {ms(det[k])} | pass@K {ms(pk[k])} | lambda-sel {ms(ls[k])} | AUROC {ms(au[k])}")
diff --git a/diag/cin_color.py b/diag/cin_color.py
new file mode 100644
index 0000000..324215f
--- /dev/null
+++ b/diag/cin_color.py
@@ -0,0 +1,144 @@
+"""H4: Recursive CIN-lite (topological / cell-complex) backbone on graph 3-coloring.
+
+Augment each graph with ring 2-cells from the cycle basis: add one hypernode per basis cycle,
+connected to its member nodes. Messages flow node->ring->node (topological message passing over
+rings). Run the shared recursive GIN on the augmented (nodes + ring-cells) graph; decode colors
+on the ORIGINAL nodes only. Self-contained: train (EMA, best solve) + LE + PTRM; writes
+color_/le_/ptrm_ JSON (conv='cin') for diag/aggregate.py.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/cin_color.py --grad_mode full --seed 0
+"""
+import argparse, json, os, time
+import numpy as np
+import networkx as nx
+import torch
+import torch.nn.functional as F
+from torch_geometric.data import Data, Batch
+from diag.train_color import make_split, featurize, RecGINColor, lyap1, OUT
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+N = 50
+
+
def dense_A(edge_index, n):
    """Return the dense, symmetrized ``n x n`` 0/1 adjacency matrix for *edge_index*.

    *edge_index* is a ``[2, E]`` index tensor; an edge listed in only one
    direction still produces both matrix entries.
    """
    adj = torch.zeros(n, n)
    has_edges = edge_index.shape[1] > 0
    if has_edges:
        rows, cols = edge_index[0], edge_index[1]
        adj[rows, cols] = 1.0
    return torch.maximum(adj, adj.t())
+
+
def augment(g):
    """Return a ``Data`` object for graph *g* extended with one hypernode per basis cycle.

    Ring hypernode ``r`` gets index ``n + r`` and is linked in both directions to
    every node of its cycle. Node features are ``g['rfeat']`` plus a ``[1, 0]``
    type tag; ring features are zeros plus a ``[0, 1]`` type tag.
    """
    n = g['n']
    ei = g['edge_index']

    # Rebuild the graph in networkx to obtain its cycle basis.
    G = nx.Graph()
    G.add_nodes_from(range(n))
    if ei.shape[1]:
        G.add_edges_from(ei.t().tolist())
    rings = nx.cycle_basis(G)
    R = len(rings)

    # Original edges first, then node<->ring edges appended in ring order.
    src = ei[0].tolist()
    dst = ei[1].tolist()
    for r, cyc in enumerate(rings):
        ring_id = n + r
        for node in cyc:
            src.extend((node, ring_id))
            dst.extend((ring_id, node))
    if src:
        aug_ei = torch.tensor([src, dst], dtype=torch.long)
    else:
        aug_ei = torch.zeros((2, 0), dtype=torch.long)

    d = g['rfeat'].shape[1]
    node_tag = torch.tensor([[1.0, 0.0]]).repeat(n, 1)
    nf = torch.cat([g['rfeat'], node_tag], dim=1)
    if R:
        ring_tag = torch.tensor([[0.0, 1.0]]).repeat(R, 1)
        rf = torch.cat([torch.zeros(R, d), ring_tag], dim=1)
    else:
        rf = torch.zeros(0, d + 2)
    return Data(x=torch.cat([nf, rf], dim=0), edge_index=aug_ei, num_nodes=n + R)
+
+
def conf_of(logits, A):
    """Count conflicting edges under the argmax coloring of *logits*.

    Sums the entries of *A* that are > 0 and join two equally colored nodes, then
    halves the total (each undirected edge is expected to appear twice in *A*).
    """
    colors = logits.argmax(-1)
    same_color = colors.unsqueeze(0) == colors.unsqueeze(1)
    clashes = (same_color & (A > 0)).sum().item()
    return int(clashes // 2)
+
+
@torch.no_grad()
def solve_rate(model, graphs, dev, sample=300):
    """Return the fraction of the first *sample* graphs colored with zero conflicts.

    Each graph's augmented ``Data`` (``g['aug']``) is run through *model*; only the
    first ``g['n']`` output rows (the original nodes, which ``augment`` places
    before the ring hypernodes) are scored against ``g['A']``. Using the graph's
    own node count instead of the module-level ``N`` keeps the slice correct for
    graphs of any size. Returns 0.0 for an empty graph list.
    """
    model.eval()
    subset = graphs[:sample]
    if not subset:
        return 0.0
    solved = 0
    for g in subset:
        d = g['aug'].to(dev)
        lg = model(d.x, d.edge_index)[-1][:g['n']]
        solved += int(conf_of(lg, g['A'].to(dev)) == 0)
    return solved / len(subset)
+
+
def main():
    """Train the recursive CIN-lite colorer, then run the LE and PTRM diagnostics.

    Trains with the soft-conflict loss on the ring-augmented graphs, keeps the EMA
    weights that reached the best test solve rate, evaluates lambda1 / noisy
    restarts on the first ``--n_graphs`` test graphs, and writes
    ``color_``/``le_color_``/``ptrm_color_`` JSON files under ``OUT``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--M', type=int, default=16)
    ap.add_argument('--seed', type=int, default=0)
    # NOTE(review): --sigma is parsed but never read below; the PTRM results are
    # always stored under the literal key '0.2'.
    ap.add_argument('--sigma', type=float, default=0.2)
    ap.add_argument('--K', type=int, default=16)
    ap.add_argument('--n_graphs', type=int, default=150)
    args = ap.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    rng = np.random.default_rng(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
    te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
    for g in tr + te:
        g['A'] = dense_A(g['edge_index'], g['n'])
        g['aug'] = augment(g)
    in_dim = tr[0]['rfeat'].shape[1] + 2  # +2 for the node/ring type tag added by augment
    model = RecGINColor(in_dim, 128, 3, grad_mode=args.grad_mode, conv='gin').to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}

    t0 = time.time()
    best = -1
    best_state = None
    for ep in range(args.epochs):
        model.train()
        order = rng.permutation(len(tr))
        # NOTE(review): the range stops at len(order) - M, so when len(order) is a
        # multiple of M the last full batch is skipped too — confirm this is intended.
        for s0 in range(0, len(order) - args.M, args.M):
            sel = order[s0:s0 + args.M]
            b = Batch.from_data_list([tr[i]['aug'] for i in sel]).to(dev)
            opt.zero_grad()
            logits = model(b.x, b.edge_index, b.batch)[-1]
            # Keep only the first N rows of each graph: the original (non-ring) nodes.
            orig = torch.stack([logits[b.ptr[gi]:b.ptr[gi] + N] for gi in range(len(sel))])  # [M,N,k]
            A = torch.stack([tr[i]['A'] for i in sel]).to(dev)
            p = F.softmax(orig, -1)
            # Soft conflict: probability that both endpoints of an edge pick the same color.
            loss = (torch.einsum('bik,bjk->bij', p, p) * A).sum() / (A.sum() + 1e-9)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            with torch.no_grad():
                for kk, v in model.state_dict().items():
                    if torch.is_floating_point(v):
                        ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001)
                    else:
                        # Integer buffers cannot be averaged; mirror the live value.
                        ema[kk].copy_(v)
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            # Evaluate with the EMA weights, then restore the live weights.
            bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
            model.load_state_dict(ema)
            sr = solve_rate(model, te, dev)
            if sr > best:
                best = sr
                best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
            model.load_state_dict(bk)
            print(f"ep{ep+1} solve={sr:.3f}", flush=True)
    model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state})
    model.eval()

    nstep = model.n_sup * model.T
    lams, fails = [], []
    passk = lamsel = rand = 0
    for gi, g in enumerate(te[:args.n_graphs]):
        d = g['aug'].to(dev)
        A = g['A'].to(dev)
        lam0 = lyap1(model, d.x, d.edge_index, nstep, dev, seed=gi)
        c0 = conf_of(model(d.x, d.edge_index)[-1][:N], A)
        lams.append(lam0)
        fails.append(int(c0 > 0))
        confs, rl = [], []
        for j in range(args.K):
            confs.append(conf_of(model(d.x, d.edge_index, noise=True)[-1][:N], A))
            rl.append(lyap1(model, d.x, d.edge_index, nstep, dev, seed=1000 * gi + j))
        confs, rl = np.array(confs), np.array(rl)
        sv = confs == 0
        passk += int(sv.any())          # any of the K noisy runs solved
        lamsel += int(sv[rl.argmin()])  # the run with the smallest lambda1 solved
        rand += int(sv[0])              # the first run solved
    lams, fails = np.array(lams), np.array(fails)
    s, f = lams[fails == 0], lams[fails == 1]
    # AUROC needs both buckets non-empty (and sklearn available).
    auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
    n = len(te[:args.n_graphs])
    print(f"[cin/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} "
          f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")
    base = f"cin_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
    com = {'conv': 'cin', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}

    def write_json(prefix, payload):
        # Context manager guarantees the result file is flushed and closed.
        with open(os.path.join(OUT, f"{prefix}_{base}.json"), 'w') as fh:
            json.dump(payload, fh)

    write_json('color', {**com, 'solve_rate': best})
    write_json('le_color', {**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())})
    write_json('ptrm_color', {**com, 'det': 1 - float(fails.mean()),
                              'sigmas': {'0.2': {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n}}})
    print(" wrote", base)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/datasets.py b/diag/datasets.py
new file mode 100644
index 0000000..cbd236e
--- /dev/null
+++ b/diag/datasets.py
@@ -0,0 +1,70 @@
+"""Synthetic graph datasets for the 1-WL diagnosis."""
+import numpy as np
+import networkx as nx
+
+
def _nx_to_edge_index(G):
    """Return ``(n, edge_index)`` for networkx graph *G* with integer node labels.

    ``edge_index`` is an int64 array of shape ``[2, 2E]``: every edge followed by
    its reverse, so the undirected graph is stored in both directions. A graph
    without edges yields a ``[2, 0]`` array.
    """
    relabeled = nx.convert_node_labels_to_integers(G)
    num_nodes = relabeled.number_of_nodes()
    if relabeled.number_of_edges() == 0:
        return num_nodes, np.zeros((2, 0), dtype=np.int64)
    forward = np.array(list(relabeled.edges()), dtype=np.int64).T
    # Reversing the two rows swaps source and target.
    both = np.concatenate([forward, forward[::-1]], axis=1)
    return num_nodes, both
+
+
def circulant(N, offsets):
    """Return the circulant graph on nodes ``0..N-1`` linking ``i`` to ``(i + s) % N`` for each offset ``s``."""
    graph = nx.Graph()
    graph.add_nodes_from(range(N))
    graph.add_edges_from((i, (i + s) % N) for i in range(N) for s in offsets)
    return graph
+
+
+CSL_SKIPS = [2, 3, 4, 5, 6, 9, 11, 12, 13, 16] # 10 classes, N=41 (Murphy et al. 2019)
+
+
def build_csl(n_per_class=15, N=41, seed=0):
    """Circular Skip Links: all graphs 4-regular -> 1-WL collapses to one color (pure H2 anchor).

    Returns a list of ``{'n', 'edge_index', 'y'}`` dicts: ``n_per_class`` randomly
    node-permuted copies of the circulant graph with offsets ``[1, s]`` for every
    skip ``s`` in ``CSL_SKIPS``; ``y`` is the index of the skip (the class label).

    The permutation is applied directly to the edge array. Relabelling the
    networkx graph and converting it afterwards does not work: the conversion
    renumbers nodes by iteration order, which relabelling preserves, so the
    original numbering comes back and every copy gets the same ``edge_index``.
    """
    rng = np.random.default_rng(seed)
    data = []
    for cls, s in enumerate(CSL_SKIPS):
        # Circulant nodes are labelled 0..N-1, so the edge endpoints index perm directly.
        base = np.array(list(circulant(N, [1, s]).edges()), dtype=np.int64).reshape(-1, 2).T
        for _ in range(n_per_class):
            perm = rng.permutation(N)
            e = perm[base]
            ei = np.concatenate([e, e[::-1]], axis=1)  # undirected -> both directions
            data.append({'n': N, 'edge_index': ei, 'y': cls})
    return data
+
+
def build_triangle_count(n_graphs=600, n_nodes=20, kind='regular', deg=3, p=0.2, seed=0):
    """Graph-level triangle-count regression. 1-WL cannot count triangles -> measurable H2 floor.

    Samples random graphs (``deg``-regular when ``kind == 'regular'``, otherwise
    G(n, p)) and labels each with its triangle count as a float. Generation
    attempts that raise are skipped; at most ``n_graphs * 30`` attempts are made,
    so fewer than ``n_graphs`` items may be returned.
    """
    rng = np.random.default_rng(seed)
    data = []
    for _ in range(n_graphs * 30):
        if len(data) >= n_graphs:
            break
        # One fresh seed per attempt, drawn whether or not the attempt succeeds.
        sd = int(rng.integers(1 << 30))
        try:
            if kind == 'regular':
                G = nx.random_regular_graph(deg, n_nodes, seed=sd)
            else:
                G = nx.gnp_random_graph(n_nodes, p, seed=sd)
        except Exception:
            continue
        # nx.triangles counts a triangle once per member node.
        tri = sum(nx.triangles(G).values()) // 3
        n, ei = _nx_to_edge_index(G)
        data.append({'n': n, 'edge_index': ei, 'y': float(tri)})
    return data
+
+
def canonical_pairs():
    """Graphs for instrument self-test (known 1-WL outcomes).

    Returns a dict mapping a short name to ``{'n', 'edge_index', 'tri'}``, where
    ``tri`` is the graph's triangle count.
    """
    named = {
        'C6': nx.cycle_graph(6),
        '2C3': nx.disjoint_union(nx.cycle_graph(3), nx.cycle_graph(3)),
        'P4': nx.path_graph(4),
        'K1,3': nx.star_graph(3),
    }
    result = {}
    for label, graph in named.items():
        n, ei = _nx_to_edge_index(graph)
        tri = sum(nx.triangles(graph).values()) // 3
        result[label] = {'n': n, 'edge_index': ei, 'tri': tri}
    return result
diff --git a/diag/esan_color.py b/diag/esan_color.py
new file mode 100644
index 0000000..7be050c
--- /dev/null
+++ b/diag/esan_color.py
@@ -0,0 +1,145 @@
+"""H3: Recursive ESAN (subgraph GNN, DS-GNN node-marking bag) on graph 3-coloring.
+
+Per graph, pick S anchor nodes; each view = graph + a 1-hot mark on the anchor. Run the SHARED
+recursive GIN on all views, average node-logits over views (DeepSets). Marking breaks node
+symmetry -> >1-WL. Self-contained: train (EMA, best solve) + LE (lambda on a marked view,
+bucket by aggregate solve) + PTRM (K noisy aggregate forwards); writes color_/le_/ptrm_ JSON
+(conv='esan') for diag/aggregate.py.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/esan_color.py --grad_mode full --seed 0
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch_geometric.data import Data, Batch
+from diag.train_color import make_split, featurize, RecGINColor, lyap1, OUT
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+S = 4
+
+
def dense_A(edge_index, n):
    """Build the symmetric dense ``n x n`` adjacency (entries 0.0 or 1.0) from a ``[2, E]`` edge index."""
    A = torch.zeros(n, n)
    if edge_index.shape[1] != 0:
        src, dst = edge_index[0], edge_index[1]
        A[src, dst] = 1.0
    # Element-wise max with the transpose fills in any missing reverse direction.
    A_sym = torch.maximum(A, A.t())
    return A_sym
+
+
def views_of(g, anchors):
    """Return one ``Data`` view of graph *g* per anchor node.

    Each view shares the graph's edges and appends a single extra feature column
    to ``g['rfeat']`` that is 1.0 on the anchor node and 0.0 elsewhere.
    """
    n = g['n']
    views = []
    for anchor in anchors:
        flag = torch.zeros(n, 1)
        flag[anchor] = 1.0
        feats = torch.cat([g['rfeat'], flag], dim=1)
        views.append(Data(x=feats, edge_index=g['edge_index'], num_nodes=n))
    return views
+
+
def anchors_for(g, rng):
    """Draw ``min(S, n)`` distinct anchor node indices for graph *g* from *rng*."""
    n = g['n']
    count = min(S, n)
    return rng.choice(n, size=count, replace=False)
+
+
def esan_logits(model, g, dev, anchors, noise=False):
    """Run *model* on every marked view of *g* and average the node logits over views.

    Returns a ``[n, k]`` tensor taken from the last element of the model's output.
    """
    bag = Batch.from_data_list(views_of(g, anchors)).to(dev)
    per_view = model(bag.x, bag.edge_index, bag.batch, noise=noise)[-1]  # [S*n, k]
    n_views = len(anchors)
    stacked = per_view.view(n_views, g['n'], -1)
    return stacked.mean(0)  # [n, k]
+
+
def conf_of(logits, A):
    """Return the number of conflicting edges for the argmax coloring of *logits*.

    Entries of *A* that are > 0 and connect two nodes of the same color are summed
    and the total is halved (each undirected edge is expected to appear twice in *A*).
    """
    picked = logits.argmax(-1)
    equal_pairs = picked[None, :] == picked[:, None]
    edge_mask = A > 0
    total = (equal_pairs & edge_mask).sum().item()
    return int(total // 2)
+
+
@torch.no_grad()
def solve_rate(model, graphs, dev, rng, sample=300):
    """Return the fraction of the first *sample* graphs colored with zero conflicts.

    Fresh anchors are drawn from *rng* for every graph, so each call advances the
    generator.
    """
    model.eval()
    subset = graphs[:sample]
    hits = 0
    for g in subset:
        anchors = anchors_for(g, rng)
        logits = esan_logits(model, g, dev, anchors)
        if conf_of(logits, g['A'].to(dev)) == 0:
            hits += 1
    return hits / len(subset)
+
+
def main():
    """Train the recursive ESAN colorer, then run the LE and PTRM diagnostics.

    Trains with the soft-conflict loss on view-averaged logits, keeps the EMA
    weights that reached the best test solve rate, evaluates lambda1 / noisy
    restarts on the first ``--n_graphs`` test graphs, and writes
    ``color_``/``le_color_``/``ptrm_color_`` JSON files under ``OUT``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--M', type=int, default=16)
    ap.add_argument('--seed', type=int, default=0)
    # NOTE(review): --sigma is parsed but never read below; the PTRM results are
    # always stored under the literal key '0.2'.
    ap.add_argument('--sigma', type=float, default=0.2)
    ap.add_argument('--K', type=int, default=16)
    ap.add_argument('--n_graphs', type=int, default=150)
    args = ap.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    rng = np.random.default_rng(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
    te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
    for g in tr + te:
        g['A'] = dense_A(g['edge_index'], g['n'])
    in_dim = tr[0]['rfeat'].shape[1] + 1  # +1 for the anchor-mark column added by views_of
    model = RecGINColor(in_dim, 128, 3, grad_mode=args.grad_mode, conv='gin').to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}

    t0 = time.time()
    best = -1
    best_state = None
    for ep in range(args.epochs):
        model.train()
        order = rng.permutation(len(tr))
        # NOTE(review): the range stops at len(order) - M, so when len(order) is a
        # multiple of M the last full batch is skipped too — confirm this is intended.
        for s0 in range(0, len(order) - args.M, args.M):
            sel = order[s0:s0 + args.M]
            views, As = [], []
            for i in sel:
                g = tr[i]
                views += views_of(g, anchors_for(g, rng))
                As.append(g['A'])
            b = Batch.from_data_list(views).to(dev)
            opt.zero_grad()
            # Reshape to [batch, views, nodes, colors] and average over the views.
            logits = model(b.x, b.edge_index, b.batch, noise=False)[-1].view(args.M, S, 50, 3).mean(1)
            Ab = torch.stack(As).to(dev)
            p = F.softmax(logits, -1)
            # Soft conflict: probability that both endpoints of an edge pick the same color.
            loss = (torch.einsum('bik,bjk->bij', p, p) * Ab).sum() / (Ab.sum() + 1e-9)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            with torch.no_grad():
                for kk, v in model.state_dict().items():
                    if torch.is_floating_point(v):
                        ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001)
                    else:
                        # Integer buffers cannot be averaged; mirror the live value.
                        ema[kk].copy_(v)
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            # Evaluate with the EMA weights, then restore the live weights.
            bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
            model.load_state_dict(ema)
            sr = solve_rate(model, te, dev, rng)
            if sr > best:
                best = sr
                best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
            model.load_state_dict(bk)
            print(f"ep{ep+1} solve={sr:.3f}", flush=True)
    model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state})
    model.eval()

    lams, fails = [], []
    passk = lamsel = rand = 0
    nstep = model.n_sup * model.T
    for gi, g in enumerate(te[:args.n_graphs]):
        anc = anchors_for(g, rng)
        # lambda1 is measured on a single view: the one marked at the first anchor.
        mark = torch.zeros(g['n'], 1)
        mark[anc[0]] = 1.0
        xin = torch.cat([g['rfeat'], mark], dim=1).to(dev)
        ei = g['edge_index'].to(dev)
        lam0 = lyap1(model, xin, ei, nstep, dev, seed=gi)
        c0 = conf_of(esan_logits(model, g, dev, anc), g['A'].to(dev))
        lams.append(lam0)
        fails.append(int(c0 > 0))
        confs, rl = [], []
        for j in range(args.K):
            confs.append(conf_of(esan_logits(model, g, dev, anc, noise=True), g['A'].to(dev)))
            rl.append(lyap1(model, xin, ei, nstep, dev, seed=1000 * gi + j))
        confs, rl = np.array(confs), np.array(rl)
        sv = confs == 0
        passk += int(sv.any())          # any of the K noisy runs solved
        lamsel += int(sv[rl.argmin()])  # the run with the smallest lambda1 solved
        rand += int(sv[0])              # the first run solved
    lams, fails = np.array(lams), np.array(fails)
    s, f = lams[fails == 0], lams[fails == 1]
    # AUROC needs both buckets non-empty (and sklearn available).
    auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
    n = len(te[:args.n_graphs])
    print(f"[esan/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} "
          f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")
    base = f"esan_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
    com = {'conv': 'esan', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}

    def write_json(prefix, payload):
        # Context manager guarantees the result file is flushed and closed.
        with open(os.path.join(OUT, f"{prefix}_{base}.json"), 'w') as fh:
            json.dump(payload, fh)

    write_json('color', {**com, 'solve_rate': best})
    write_json('le_color', {**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())})
    write_json('ptrm_color', {**com, 'det': 1 - float(fails.mean()),
                              'sigmas': {'0.2': {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n}}})
    print(" wrote", base)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/ff_color.py b/diag/ff_color.py
new file mode 100644
index 0000000..29c3b45
--- /dev/null
+++ b/diag/ff_color.py
@@ -0,0 +1,75 @@
+"""Control: plain FEEDFORWARD (independent-weight, NO recursion, NO deep-supervision) GNN on
+graph 3-coloring, to test whether the recursive gcn/gat collapse (.003/.04) is caused by the
+RECURSION (shared-weight repeated -> oversmoothing) or is intrinsic to gcn/gat on coloring.
+
+Sweep conv in {gcn,gat} x depth L in {4,8,16}. L=4 ~ normal usage; L=16 tests whether deep
+feedforward also oversmooths. If shallow FF colors well but recursive collapses -> recursion's
+fault. If FF is also ~0 at all L -> intrinsic.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ff_color.py --conv gcn --L 4 --seed 0
+"""
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from diag.train_color import make_split, featurize, make_conv, conflict_loss, solve_stats
+
+
class FF(nn.Module):
    """Feed-forward GNN: input linear layer, ``L`` separately parameterized
    (conv, BatchNorm, ReLU) layers, and a linear head producing ``k`` logits per node."""

    def __init__(self, in_dim, hidden, k, L, conv):
        super().__init__()
        self.conv_type = conv
        self.lin_in = nn.Linear(in_dim, hidden)
        conv_layers = []
        for _ in range(L):
            conv_layers.append(make_conv(conv, hidden))
        self.convs = nn.ModuleList(conv_layers)
        norm_layers = []
        for _ in range(L):
            norm_layers.append(nn.BatchNorm1d(hidden))
        self.bns = nn.ModuleList(norm_layers)
        self.head = nn.Linear(hidden, k)

    def forward(self, xin, ei, batch=None, noise=False):
        """Return a one-element list holding the node logits.

        ``batch`` and ``noise`` are accepted but not used by this model.
        """
        h = self.lin_in(xin)
        for conv, bn in zip(self.convs, self.bns):
            h = conv(h, ei)
            h = bn(h)
            h = h.relu()
        logits = self.head(h)
        return [logits]
+
+
def main():
    """Train one feed-forward GNN on 3-coloring and print its best EMA solve rate.

    The EMA weights are evaluated every 30 epochs and after the last epoch; the
    best solve rate seen is printed once training ends.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--conv', choices=['gcn', 'gat', 'gin', 'sage'], default='gcn')
    parser.add_argument('--L', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=150)
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
    te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
    in_dim = tr[0]['xin'].shape[1]
    train_items = [Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr]
    trl = DataLoader(train_items, batch_size=32, shuffle=True, drop_last=True)
    model = FF(in_dim, 128, 3, args.L, args.conv).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    ema = {name: val.detach().clone() for name, val in model.state_dict().items()}
    best = -1
    last_epoch = args.epochs - 1
    for ep in range(args.epochs):
        model.train()
        for b in trl:
            b = b.to(dev)
            opt.zero_grad()
            logits = model(b.x, b.edge_index)[-1]
            loss = conflict_loss(logits, b.edge_index)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            with torch.no_grad():
                for name, val in model.state_dict().items():
                    if torch.is_floating_point(val):
                        ema[name].mul_(0.999).add_(val.detach(), alpha=0.001)
                    else:
                        # Integer buffers cannot be averaged; mirror the live value.
                        ema[name].copy_(val)
        sched.step()
        if (ep + 1) % 30 == 0 or ep == last_epoch:
            # Swap in the EMA weights for evaluation, then restore the live ones.
            live = {name: val.detach().clone() for name, val in model.state_dict().items()}
            model.load_state_dict(ema)
            sr, _ = solve_stats(model, te, dev, sample=300)
            if sr > best:
                best = sr
            model.load_state_dict(live)
    print(f"[ff/{args.conv}/L{args.L}/s{args.seed}] solve={best:.3f}", flush=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/lyap.py b/diag/lyap.py
new file mode 100644
index 0000000..93b90bf
--- /dev/null
+++ b/diag/lyap.py
@@ -0,0 +1,83 @@
+"""LE diagnostic for the recursive (TRM-ish) GNN — ports the flossing finding to graphs.
+
+Per-graph top Lyapunov exponent lambda1 of the recursion z <- block(z+h0), via Benettin
+power-iteration on a single tangent vector (JVP + renormalize, accumulate log-growth) over
+the model's n_sup*T recursion steps. Bucket graphs by success/failure (rounded ring counts
+exact) and compare lambda1 distributions + AUROC(fail | lambda1) — mirroring
+plot_trm_lyap_hist.py. Hypothesis: failed graphs are MORE chaotic (higher lambda1).
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/lyap.py --ckpt runs/ckpt_rec_full_..._s0.pt
+"""
+import argparse
+import numpy as np
+import torch
+from diag.train_rec import RecGIN
+from diag.train_cycle import prepare
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+
def build(ck, dev):
    """Rebuild the ``RecGIN`` model stored in checkpoint dict *ck* and put it in eval mode.

    Returns ``(model, cfg)`` where ``cfg`` is ``ck['cfg']``.
    """
    cfg = ck['cfg']
    # The fifth positional argument is fixed to 0.0 — presumably a noise or dropout
    # level; confirm against RecGIN's signature.
    model = RecGIN(cfg['n_atom'], cfg['hidden'], cfg['T'], cfg['n_sup'], 0.0, grad_mode=cfg['grad_mode'])
    model = model.to(dev)
    model.load_state_dict(ck['state'])
    model.eval()
    return model, cfg
+
+
def lyap1(model, x, ei, n_steps, dev, seed=0):
    """Estimate the top Lyapunov exponent of the recursion ``z <- block(z + h0)``.

    Starts from ``z = 0`` with a seeded random unit tangent vector, pushes the
    tangent through the step's Jacobian with a JVP at every step, renormalizes it,
    and returns the mean log growth over *n_steps* steps.
    """
    gen = torch.Generator(device=dev).manual_seed(seed)
    inject = model.emb(x).detach()
    state = torch.zeros_like(inject)
    tangent = torch.randn(inject.shape, generator=gen, device=dev)
    tangent = tangent / (tangent.norm() + 1e-12)

    def recur(cur):
        return model.block(cur + inject, ei)

    log_growth = 0.0
    step = 0
    while step < n_steps:
        nxt, pushed = torch.autograd.functional.jvp(recur, state, tangent)
        growth = pushed.norm()
        log_growth += torch.log(growth + 1e-12).item()
        # Renormalize so the tangent neither overflows nor vanishes.
        tangent = (pushed / (growth + 1e-12)).detach()
        state = nxt.detach()
        step += 1
    return log_growth / n_steps
+
+
@torch.no_grad()
def predict(model, x, ei, dev):
    """Run *model* on one graph without noise and return its last prediction, flattened.

    All nodes are assigned to batch index 0.
    """
    n_nodes = x.size(0)
    batch = torch.zeros(n_nodes, dtype=torch.long, device=dev)
    outputs = model(x, ei, batch, noise=False)
    preds, _ = outputs
    final = preds[-1]
    return final.view(-1)
+
+
def main():
    """Compute per-graph lambda1 on the test split and report how it separates failures.

    A graph counts as a failure when any rounded, de-normalized prediction differs
    from the rounded target. Prints bucket statistics and AUROC(fail | lambda1).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--n_graphs', type=int, default=300)
    args = ap.parse_args()
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    # map_location lets a checkpoint saved on GPU be analysed on a CPU-only host.
    # NOTE: weights_only=False unpickles arbitrary objects — only load checkpoints
    # from a trusted source.
    ck = torch.load(args.ckpt, map_location=dev, weights_only=False)
    model, cfg = build(ck, dev)
    ymu, ysd = ck['ymu'].to(dev), ck['ysd'].to(dev)
    te = prepare('test')
    n_steps = cfg['n_sup'] * cfg['T']

    lams, fails = [], []
    for i, r in enumerate(te[:args.n_graphs]):
        x = r['x'].to(dev)
        ei = r['edge_index'].to(dev)
        p = predict(model, x, ei, dev) * ysd + ymu  # undo target normalization
        y = r['y'].to(dev)
        fails.append(int(not torch.all(p.round() == y.round()).item()))
        lams.append(lyap1(model, x, ei, n_steps, dev, seed=i))
    lams, fails = np.array(lams), np.array(fails)
    s, f = lams[fails == 0], lams[fails == 1]
    # AUROC needs both buckets non-empty (and sklearn available).
    auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
    print(f"[{cfg['grad_mode']}] n={len(lams)} fail_rate={fails.mean():.2f} | "
          f"lambda1 SUCC mean {s.mean():+.4f} std {s.std():.4f} (n={len(s)}) | "
          f"FAIL mean {f.mean():+.4f} std {f.std():.4f} (n={len(f)}) | "
          f"sep(fail-succ)={f.mean()-s.mean() if len(s) and len(f) else float('nan'):+.4f} | "
          f"AUROC(fail|lambda1)={auc:.3f} | mean_lambda1={lams.mean():+.4f}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/models.py b/diag/models.py
new file mode 100644
index 0000000..0aa3106
--- /dev/null
+++ b/diag/models.py
@@ -0,0 +1,42 @@
+"""GIN (1-WL-tight) and GCN (<1-WL) backbones for the diagnosis."""
+import torch.nn as nn
+from torch_geometric.nn import GINConv, GCNConv, global_add_pool, global_mean_pool
+
+
def _mlp(d_in, d_hid, d_out):
    """Return a two-layer MLP: Linear -> BatchNorm1d -> ReLU -> Linear."""
    layers = [
        nn.Linear(d_in, d_hid),
        nn.BatchNorm1d(d_hid),
        nn.ReLU(),
        nn.Linear(d_hid, d_out),
    ]
    return nn.Sequential(*layers)
+
+
class GIN(nn.Module):
    """Sum aggregation + MLP update -> injective on multisets -> matches 1-WL.

    Stacks ``layers`` GINConv layers (learnable eps), sum-pools the node states of
    each graph, and maps the pooled vector to ``out_dim`` outputs.
    """

    def __init__(self, in_dim, hidden=64, layers=4, out_dim=10):
        super().__init__()
        self.convs = nn.ModuleList()
        for depth in range(layers):
            # Only the first layer sees the raw input width.
            width_in = in_dim if depth == 0 else hidden
            self.convs.append(GINConv(_mlp(width_in, hidden, hidden), train_eps=True))
        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))

    def forward(self, x, edge_index, batch):
        h = x
        for conv in self.convs:
            h = conv(h, edge_index)
            h = h.relu()
        pooled = global_add_pool(h, batch)
        return self.head(pooled)
+
+
class GCN(nn.Module):
    """Mean (normalized) aggregation -> non-injective -> strictly below 1-WL (reference baseline).

    Stacks ``layers`` GCNConv layers, mean-pools the node states of each graph,
    and maps the pooled vector to ``out_dim`` outputs.
    """

    def __init__(self, in_dim, hidden=64, layers=4, out_dim=10):
        super().__init__()
        self.convs = nn.ModuleList()
        for depth in range(layers):
            # Only the first layer sees the raw input width.
            width_in = in_dim if depth == 0 else hidden
            self.convs.append(GCNConv(width_in, hidden))
        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))

    def forward(self, x, edge_index, batch):
        h = x
        for conv in self.convs:
            h = conv(h, edge_index)
            h = h.relu()
        pooled = global_mean_pool(h, batch)
        return self.head(pooled)
diff --git a/diag/peptides_depth.py b/diag/peptides_depth.py
new file mode 100644
index 0000000..d751b31
--- /dev/null
+++ b/diag/peptides_depth.py
@@ -0,0 +1,66 @@
+"""Depth-resolution analysis on LRGB Peptides-struct (long-range, real, large graphs).
+
+Best-achievable MSE for a depth-L MPNN == within-(L-round-WL-color) variance of the target.
+The curve over L localizes failure cause WITHOUT training:
+ floor(L=small) high -> under-reached signal still hidden
+ floor(L) - floor(converged) -> H1a: depth-recoverable (more iteration/depth helps)
+ floor(converged) -> H2 : 1-WL ceiling (irreducible by any MPNN)
+Targets are z-scored per dim, so floor is the fraction of variance unexplained (var=1 baseline).
+"""
+import numpy as np
+from collections import defaultdict
+from torch_geometric.datasets import LRGBDataset
+from diag import wl
+
+S = 4000 # graph subsample (for tractable pure-python WL)
+MAX_ROUNDS = 40 # cap (>> typical GIN depth; chains converge near their diameter)
+
+
def floor_at(ghist, Y):
    """Return ``(within-group mean squared deviation of Y, number of groups)``.

    Rows of *Y* are grouped by the matching entry of *ghist*; each group's squared
    deviation from its own per-column mean is summed and the total is divided by
    ``len(ghist) * Y.shape[1]``.
    """
    members = defaultdict(list)
    for idx, color in enumerate(ghist):
        members[color].append(idx)
    total = 0.0
    for rows in members.values():
        chunk = Y[rows]
        centered = chunk - chunk.mean(0)
        total += (centered ** 2).sum()
    denom = len(ghist) * Y.shape[1]
    return total / denom, len(members)
+
+
def main():
    """Print the WL-color variance floor of the Peptides-struct targets per refinement depth.

    Loads up to ``S`` training graphs, z-scores the targets per dimension, runs WL
    refinement (capped at ``MAX_ROUNDS``), and reports ``floor_at`` for a fixed set
    of depths plus the round returned by the refinement.
    """
    ds = LRGBDataset(root='/home/yurenh2/rrog/data/lrgb', name='Peptides-struct', split='train')
    n_take = min(S, len(ds))
    graphs = [ds[i] for i in range(n_take)]
    fmap = {}

    def fid(row):
        # Distinct feature rows get consecutive integer ids in first-seen order.
        return fmap.setdefault(tuple(row), len(fmap))

    adjs, inits, targets = [], [], []
    for g in graphs:
        adjs.append(wl.edges_to_adj(g.num_nodes, g.edge_index.numpy()))
        inits.append(np.array([fid(r) for r in g.x.tolist()], dtype=np.int64))
        targets.append(g.y.numpy().reshape(-1))
    Y = np.stack(targets).astype(np.float64)
    Y = (Y - Y.mean(0)) / (Y.std(0) + 1e-8)  # z-score per target
    print(f"subsample={len(graphs)} graphs, {len(fmap)} distinct node-feature ids, targets={Y.shape[1]}")

    import time
    t0 = time.time()
    _, ghist_rounds, conv = wl.wl_refine(adjs, inits=inits, max_rounds=MAX_ROUNDS)
    print(f"WL refined to round {conv} (cap {MAX_ROUNDS}) in {time.time()-t0:.1f}s")

    print(f"{'L':>4} {'floor_MSE(std)':>14} {'%var_unexpl':>12} {'#graph_colors':>14}")
    floors = {}
    depths = [0, 1, 2, 3, 4, 5, 8, 16, 32, conv]
    for L in depths:
        # Depths beyond the last refinement round reuse that round's coloring.
        r = min(L, conv)
        f, nc = floor_at(ghist_rounds[r], Y)
        floors[L] = f
        print(f"{L:>4} {f:>14.4f} {100*f:>11.1f}% {nc:>14}")
    h2 = floors[conv]
    for Lg in (4, 5):
        print(f"\nAt GIN depth L={Lg}: H2 ceiling={h2:.3f} | depth-recoverable H1a (floor[{Lg}]-H2)"
              f"={floors[Lg]-h2:.3f} | already-reachable={1-floors[Lg]:.3f} of var")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/ppgn_color.py b/diag/ppgn_color.py
new file mode 100644
index 0000000..25db01b
--- /dev/null
+++ b/diag/ppgn_color.py
@@ -0,0 +1,209 @@
+"""H2: Recursive PPGN (higher-order, 3-WL) backbone on graph 3-coloring.
+
+State is a pair-tensor X[i,j,:] (not node features). The powerful block multiplies two
+channel-wise n x n matrices (the 3-WL operation): P = M1 @ M2; out = m3([X, P]). Recurse the
+pair-tensor (TRM-style deep supervision), pool to nodes (mean over j), decode colors.
+Self-contained: trains (EMA, best solve), then LE diagnostic + PTRM noise/lambda-select;
+writes color_/le_/ptrm_ JSONs (conv='ppgn') so diag/aggregate.py folds it into the big table.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ppgn_color.py --grad_mode full --seed 0
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diag.train_color import make_split, featurize, OUT
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+
+def dense_A(edge_index, n):
+    """Build a symmetric dense n x n 0/1 adjacency matrix from a 2 x E index tensor."""
+    adj = torch.zeros(n, n)
+    if edge_index.shape[1] > 0:
+        src, dst = edge_index[0], edge_index[1]
+        adj[src, dst] = 1.0
+    # Symmetrise so a one-directional edge list still yields an undirected graph.
+    return torch.maximum(adj, adj.t())
+
+
+def mlp(di, dh, do):
+    """Two-layer perceptron: Linear(di -> dh), ReLU, Linear(dh -> do)."""
+    layers = [nn.Linear(di, dh), nn.ReLU(), nn.Linear(dh, do)]
+    return nn.Sequential(*layers)
+
+
+class RecPPGN(nn.Module):
+    """Recursive PPGN-style network whose state is a pair tensor X[b, i, j, :].
+
+    One shared ``block`` is applied ``T`` times per supervision round, for ``n_sup`` rounds;
+    each round emits per-node logits over ``k`` classes.
+    """
+
+    def __init__(self, in_dim, hidden=64, k=3, T=3, n_sup=3, grad_mode='full', sigma=0.0):
+        # in_dim: node-feature width; hidden: channel width; k: logit width of the head.
+        # T: block applications per round; n_sup: number of rounds (one output each).
+        # grad_mode: '1step' selects the truncated-gradient path in forward().
+        # sigma: std of the state noise added in _inner when noise=True.
+        super().__init__()
+        self.node_in = nn.Linear(in_dim, hidden)
+        self.adj_emb = nn.Embedding(2, hidden)  # one embedding for A=0, one for A=1
+        self.m1 = mlp(hidden, hidden, hidden); self.m2 = mlp(hidden, hidden, hidden)
+        self.m3 = mlp(2 * hidden, hidden, hidden)
+        self.ln = nn.LayerNorm(hidden)
+        self.head = nn.Linear(hidden, k)
+        self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma
+
+    def init_pair(self, xin, A):  # xin [B,n,in], A [B,n,n]
+        """Initial pair tensor: node_in(x_i) + node_in(x_j) + adj_emb(A_ij)."""
+        h = self.node_in(xin)
+        X = h.unsqueeze(2) + h.unsqueeze(1)  # [B,n,n,hidden]
+        return X + self.adj_emb(A.long())
+
+    def block(self, X):  # X [B,n,n,h]
+        """One update: per-channel matrix product of m1(X) and m2(X), then m3 + LayerNorm."""
+        M1, M2 = self.m1(X), self.m2(X)
+        # Per-channel n x n matmul over the shared index k, scaled by 1/n (X.shape[1]).
+        P = torch.einsum('bikc,bkjc->bijc', M1, M2) / X.shape[1]
+        return self.ln(self.m3(torch.cat([X, P], dim=-1)))
+
+    def _inner(self, X, X0, noise):
+        """Single recursion step with input injection X0; optional Gaussian state noise."""
+        X = self.block(X + X0)
+        if noise and self.sigma > 0:
+            X = X + self.sigma * torch.randn_like(X)
+        return X
+
+    def recurse(self, X, X0, noise, one_step=False):
+        """Apply _inner T times; with one_step only the last application is differentiated."""
+        if one_step:
+            # First T-1 steps run without building a graph; gradient flows through step T only.
+            with torch.no_grad():
+                for _ in range(self.T - 1):
+                    X = self._inner(X, X0, noise)
+            X = X.detach()
+            return self._inner(X, X0, noise)
+        for _ in range(self.T):
+            X = self._inner(X, X0, noise)
+        return X
+
+    def forward(self, xin, A, noise=False):
+        """Return a list of n_sup logit tensors [B,n,k], one per supervision round."""
+        X0 = self.init_pair(xin, A)
+        X = torch.zeros_like(X0)
+        outs = []
+        for s in range(self.n_sup):
+            X = self.recurse(X, X0, noise, one_step=(self.grad_mode == '1step'))
+            outs.append(self.head(X.mean(2)))  # pool over j -> [B,n,k]
+            X = X.detach()  # no gradient across supervision rounds
+        return outs
+
+
+def conflict_loss(logits, A):  # logits [B,n,k], A [B,n,n]
+    """Adjacency-weighted mean of the softmax inner product between node pairs.
+
+    For each pair (i, j) the inner product of the two softmax distributions is weighted by
+    ``A[b, i, j]``; the total is divided by ``A.sum() + 1e-9``.
+    """
+    probs = F.softmax(logits, dim=-1)
+    same_color = torch.einsum('bik,bjk->bij', probs, probs)
+    weighted = (same_color * A).sum()
+    return weighted / (A.sum() + 1e-9)
+
+
+def batches(graphs, bs, shuffle, dev):
+    """Yield ``(xin, A)`` mini-batches stacked from ``graphs`` and moved to ``dev``.
+
+    With ``shuffle`` the order is permuted via ``np.random.shuffle`` and the ragged final
+    batch is dropped; without it every graph is yielded in order, including a short tail.
+    """
+    order = np.arange(len(graphs))
+    stop = len(order)
+    if shuffle:
+        np.random.shuffle(order)
+        stop -= stop % bs  # keep only full batches
+    for start in range(0, stop, bs):
+        chosen = order[start:start + bs]
+        xin = torch.stack([graphs[i]['xin'] for i in chosen]).to(dev)
+        adj = torch.stack([graphs[i]['A'] for i in chosen]).to(dev)
+        yield xin, adj
+
+
+@torch.no_grad()
+def solve_stats(model, graphs, dev, sample=300):
+    """Fraction of the first ``sample`` graphs whose argmax coloring has zero conflicts.
+
+    Puts ``model`` in eval mode and uses the last supervision output of each forward pass.
+    """
+    model.eval()
+    subset = graphs[:sample]
+    n_solved = 0
+    for g in subset:
+        adj = g['A'].unsqueeze(0).to(dev)
+        logits = model(g['xin'].unsqueeze(0).to(dev), adj)[-1][0]
+        colors = logits.argmax(-1)
+        # The dense adjacency is symmetric, so each clashing edge is counted twice.
+        clash = (colors.unsqueeze(0) == colors.unsqueeze(1)) & (adj[0] > 0)
+        n_solved += int(clash.sum().item() // 2 == 0)
+    return n_solved / len(subset)
+
+
+def lyap1_and_solved(model, g, dev, seed, sigma=0.0):
+    """Roll the recursion out on one graph, tracking tangent-vector growth and final conflicts.
+
+    A unit tangent vector ``v`` is pushed through the Jacobian of ``X -> block(X + X0)`` at
+    every one of the ``n_sup * T`` steps and renormalised; the mean log growth factor is
+    returned as the exponent estimate (the file calls it lambda / LE). With ``sigma > 0``
+    Gaussian noise is added to the state after each step. ``seed`` drives both the initial
+    tangent vector and the noise.
+
+    Returns ``(mean_log_growth, n_conflicts)``, where ``n_conflicts`` counts edges whose
+    endpoints share the argmax color of the final state.
+    """
+    gen = torch.Generator(device=dev).manual_seed(seed)
+    xin = g['xin'].unsqueeze(0).to(dev); A = g['A'].unsqueeze(0).to(dev)
+    X0 = model.init_pair(xin, A)
+    X = torch.zeros_like(X0)
+    v = torch.randn(X.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12)
+    lam = 0.0
+    for _ in range(model.n_sup * model.T):
+        # jvp returns both the next state and J @ v at the current state.
+        Xd, Jv = torch.autograd.functional.jvp(lambda XX: model.block(XX + X0), X, v)
+        nv = Jv.norm(); lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach()
+        X = Xd.detach()
+        if sigma > 0:
+            X = X + sigma * torch.randn(X.shape, generator=gen, device=dev)
+    col = model.head(X.mean(2))[0].argmax(-1)
+    # Symmetric dense adjacency counts every conflicting edge twice, hence // 2.
+    conf = ((col.unsqueeze(0) == col.unsqueeze(1)) & (A[0] > 0)).sum().item() // 2
+    return lam / (model.n_sup * model.T), conf
+
+
+def main():
+    """Train a recursive PPGN on 3-coloring, run the LE and PTRM diagnostics, write three JSONs.
+
+    Training minimises the mean ``conflict_loss`` over all supervision outputs, keeps an EMA
+    of the weights, and every 20 epochs (and at the last one) scores the EMA weights with
+    ``solve_stats``; the best-scoring EMA snapshot is reloaded for the diagnostics. Outputs
+    are ``color_*.json``, ``le_color_*.json`` and ``ptrm_color_*.json`` under ``OUT``.
+    """
+    ap = argparse.ArgumentParser()
+    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+    ap.add_argument('--epochs', type=int, default=150)
+    ap.add_argument('--bs', type=int, default=16)
+    ap.add_argument('--hidden', type=int, default=64)
+    ap.add_argument('--seed', type=int, default=0)
+    ap.add_argument('--sigma', type=float, default=0.2)
+    ap.add_argument('--K', type=int, default=16)
+    ap.add_argument('--n_graphs', type=int, default=150)
+    args = ap.parse_args()
+    if args.K < 1 or args.n_graphs < 1:
+        # Fail early with a clear message instead of an argmin/ZeroDivision error later.
+        ap.error('--K and --n_graphs must be >= 1')
+    torch.manual_seed(args.seed); np.random.seed(args.seed)
+    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
+    te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
+    for g in tr + te:
+        g['A'] = dense_A(g['edge_index'], g['n'])
+    in_dim = tr[0]['xin'].shape[1]
+
+    model = RecPPGN(in_dim, args.hidden, 3, grad_mode=args.grad_mode).to(dev)
+    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+    ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+    t0 = time.time(); best = -1; best_state = None
+    for ep in range(args.epochs):
+        model.train()
+        for X, A in batches(tr, args.bs, True, dev):
+            opt.zero_grad()
+            outs = model(X, A, noise=False)
+            loss = sum(conflict_loss(o, A) for o in outs) / len(outs)
+            loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+            with torch.no_grad():
+                for kk, v in model.state_dict().items():
+                    if torch.is_floating_point(v):
+                        ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001)
+                    else:
+                        ema[kk].copy_(v)  # integer buffers are copied, not averaged
+        sched.step()
+        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+            # Evaluate the EMA weights, then restore the live training weights.
+            # NOTE(review): the snapshot is selected on `te`, the same split the
+            # diagnostics below report on.
+            bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+            model.load_state_dict(ema); sr = solve_stats(model, te, dev)
+            if sr > best:
+                best = sr; best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
+            model.load_state_dict(bk)
+            print(f"ep{ep+1} solve={sr:.3f}", flush=True)
+    if best_state is None:
+        # --epochs 0 never evaluates; score the initial weights instead of crashing on None.
+        model.load_state_dict(ema); best = solve_stats(model, te, dev)
+        best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
+    model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state}); model.eval()
+
+    # LE + PTRM (noise + lambda-select) on test
+    lams, fails = [], []
+    passk = lamsel = rand = 0
+    Lr, Sr = [], []
+    for gi, g in enumerate(te[:args.n_graphs]):
+        lam0, c0 = lyap1_and_solved(model, g, dev, seed=gi, sigma=0.0)
+        lams.append(lam0); fails.append(int(c0 > 0))
+        res = [lyap1_and_solved(model, g, dev, seed=1000 * gi + j, sigma=args.sigma) for j in range(args.K)]
+        confs = np.array([c for _, c in res]); rl = np.array([l for l, _ in res])
+        solved = confs == 0
+        # pass@K: any rollout solved; lamsel: the min-lambda rollout solved; rand: rollout 0 solved.
+        passk += int(solved.any()); lamsel += int(solved[rl.argmin()]); rand += int(solved[0])
+        Lr += rl.tolist(); Sr += solved.tolist()
+    lams, fails = np.array(lams), np.array(fails)
+    s, f = lams[fails == 0], lams[fails == 1]
+    auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
+    Lr, Sr = np.array(Lr), np.array(Sr)
+    pauc = (roc_auc_score(Sr.astype(int), -Lr) if roc_auc_score and Sr.any() and (~Sr).any() else float('nan'))
+    n = len(te[:args.n_graphs])
+    print(f"[ppgn/{args.grad_mode}] solve={best:.3f} | LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} | "
+          f"PTRM det={1 - fails.mean():.3f} passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")
+
+    base = f"ppgn_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
+    com = {'conv': 'ppgn', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}
+    os.makedirs(OUT, exist_ok=True)
+
+    def dump(prefix, payload):
+        # Context manager so each result file is flushed and closed deterministically.
+        with open(os.path.join(OUT, f"{prefix}_{base}.json"), 'w') as fh:
+            json.dump(payload, fh, indent=2)
+
+    dump('color', {**com, 'solve_rate': best})
+    dump('le_color', {**com, 'auroc': float(auc), 'mean_lam': float(lams.mean()),
+                      'lam_solved': (float(s.mean()) if len(s) else None),
+                      'lam_unsolved': (float(f.mean()) if len(f) else None)})
+    # Key the PTRM entry by the sigma actually used (str(0.2) == '0.2', so the default
+    # output is unchanged); a hard-coded '0.2' mislabelled runs with another --sigma.
+    dump('ptrm_color', {**com, 'det': 1 - float(fails.mean()),
+                        'sigmas': {str(args.sigma): {'passk': passk / n, 'lamsel': lamsel / n,
+                                                     'random': rand / n, 'perRoll': float(Sr.mean()),
+                                                     'auroc': float(pauc)}}})
+    print(" wrote color_/le_color_/ptrm_color_", base)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/ptrm_color.py b/diag/ptrm_color.py
new file mode 100644
index 0000000..4004297
--- /dev/null
+++ b/diag/ptrm_color.py
@@ -0,0 +1,85 @@
+"""Step-2(a): PTRM test-time noise + lambda-based selection on a trained coloring model
+(any backbone feature set via cfg pe). Writes a JSON per ckpt for multi-seed aggregation.
+
+deterministic / pass@K (conflict-min, ground truth) / lambda-select (min lambda1) / random.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_...pt
+"""
+import argparse, json, os
+import numpy as np
+import torch
+from diag.train_color import RecGINColor, make_split, featurize
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+OUT = '/home/yurenh2/rrog/runs'
+
+
+def rollout(model, xin, ei, sigma, n_sup, T, dev, seed):
+    """Run one (optionally noisy) rollout of the recurrent block; return (conflicts, lambda).
+
+    The state ``z`` starts at zero and is updated ``n_sup * T`` times by
+    ``block(z + lin_in(xin), ei)``. A unit tangent vector is pushed through the step
+    Jacobian and renormalised each step; ``lambda`` is the mean log growth factor. With
+    ``sigma > 0`` Gaussian noise is added to ``z`` after each step. ``seed`` drives both
+    the initial tangent vector and the noise.
+    """
+    gen = torch.Generator(device=dev).manual_seed(seed)
+    h0 = model.lin_in(xin)
+    z = torch.zeros_like(h0)
+    v = torch.randn(h0.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12)
+
+    def step(zz):
+        return model.block(zz + h0, ei)
+
+    lam = 0.0
+    for _ in range(n_sup * T):
+        # jvp yields the next state and J @ v evaluated at the current state.
+        z_det, Jv = torch.autograd.functional.jvp(step, z, v)
+        nv = Jv.norm(); lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach()
+        z = z_det.detach()
+        if sigma > 0:
+            z = z + sigma * torch.randn(z.shape, generator=gen, device=dev)
+    lam /= (n_sup * T)
+    col = model.head(z).argmax(-1)
+    # // 2 presumes ei lists every undirected edge in both directions -- TODO confirm against featurize.
+    conf = (col[ei[0]] == col[ei[1]]).sum().item() // 2
+    return conf, lam
+
+
+def main():
+    """Evaluate test-time noise + lambda-based selection for one trained coloring checkpoint.
+
+    Rebuilds the model from the checkpoint's ``cfg``, measures the deterministic solve rate,
+    then for every sigma in ``--sigmas`` runs ``K`` noisy rollouts per graph and reports
+    pass@K, lambda-select (rollout with minimum lambda), random (rollout 0), per-rollout
+    solve rate and the AUROC of ``-lambda`` for predicting a solved rollout. Results go to
+    ``ptrm_<ckpt-stem>.json`` under ``OUT``.
+    """
+    ap = argparse.ArgumentParser()
+    ap.add_argument('--ckpt', required=True)
+    ap.add_argument('--K', type=int, default=16)
+    ap.add_argument('--n_graphs', type=int, default=150)
+    ap.add_argument('--sigmas', type=float, nargs='+', default=[0.05, 0.1, 0.2, 0.4])
+    args = ap.parse_args()
+    if args.K < 1 or args.n_graphs < 1:
+        # Fail early with a clear message instead of an argmin/ZeroDivision error later.
+        ap.error('--K and --n_graphs must be >= 1')
+    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
+    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
+    ck = torch.load(args.ckpt, map_location=dev, weights_only=False); c = ck['cfg']
+    deg = torch.tensor(c['deg']) if c.get('deg') else None
+    model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
+                        grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev)
+    model.load_state_dict(ck['state']); model.eval()
+    nsup, T = c['n_sup'], c['T']
+    te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), c.get('pe', 'none'), c.get('rwse_k', 16))
+    te = te[:args.n_graphs]; n = len(te)
+
+    det = sum(rollout(model, r['xin'].to(dev), r['edge_index'].to(dev), 0.0, nsup, T, dev, 0)[0] == 0
+              for r in te) / n
+    out = {'conv': c.get('conv', 'gin'), 'pe': c.get('pe', 'none'), 'seed': c.get('seed'),
+           'grad_mode': c['grad_mode'], 'contract': c.get('contract', False), 'det': det, 'sigmas': {}}
+    print(f"[pe={out['pe']} s{out['seed']}] deterministic solve_rate = {det:.3f} (n={n}, K={args.K})")
+    print(f"{'sigma':>6} {'pass@K':>8} {'lam-sel':>8} {'random':>8} {'perRoll':>8} {'AUROC(s|-lam)':>14}")
+    for sigma in args.sigmas:
+        passk = lamsel = rand = 0
+        L, S = [], []
+        for gi, r in enumerate(te):
+            xin = r['xin'].to(dev); ei = r['edge_index'].to(dev)
+            res = [rollout(model, xin, ei, sigma, nsup, T, dev, 1000 * gi + j) for j in range(args.K)]
+            confs = np.array([c0 for c0, _ in res]); lams = np.array([l for _, l in res])
+            solved = confs == 0
+            passk += int(solved.any()); lamsel += int(solved[lams.argmin()]); rand += int(solved[0])
+            L += lams.tolist(); S += solved.tolist()
+        L, S = np.array(L), np.array(S)
+        # AUROC is only defined when both solved and unsolved rollouts exist.
+        auc = (roc_auc_score(S.astype(int), -L) if roc_auc_score and S.any() and (~S).any() else float('nan'))
+        out['sigmas'][str(sigma)] = {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n,
+                                     'perRoll': float(S.mean()), 'auroc': float(auc)}
+        print(f"{sigma:>6} {passk/n:>8.3f} {lamsel/n:>8.3f} {rand/n:>8.3f} {S.mean():>8.3f} {auc:>14.3f}")
+
+    base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
+    os.makedirs(OUT, exist_ok=True)  # do not lose a finished evaluation to a missing directory
+    with open(os.path.join(OUT, f"ptrm_{base}.json"), 'w') as f:
+        json.dump(out, f, indent=2)
+    print(" wrote", os.path.join(OUT, f"ptrm_{base}.json"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/run_archA.sh b/diag/run_archA.sh
new file mode 100644
index 0000000..5385152
--- /dev/null
+++ b/diag/run_archA.sh
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+# Conv axis (pe=none): gcn, sage, gat. 5 seeds, train+LE+PTRM. (pin GPU at launch.)
+# No `set -e`: each step reports its own failure via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run the whole sweep from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "A start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+  for conv in gcn sage gat; do
+    ck=runs/ckpt_color_${conv}_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
+    echo "== A s$s conv=$conv =="
+    python3 diag/train_color.py --mode train --conv "$conv" --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train $conv s$s"
+    python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le $conv s$s"
+    python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $conv s$s"
+  done
+done
+echo "doneA=$(date -Is)"
diff --git a/diag/run_archB.sh b/diag/run_archB.sh
new file mode 100644
index 0000000..317638f
--- /dev/null
+++ b/diag/run_archB.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+# GPS transformer backbone + feature axis (gin + lappe / all). 5 seeds. (pin GPU at launch.)
+# No `set -e`: each step reports its own failure via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run the whole sweep from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "B start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+  ck=runs/ckpt_color_gps_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
+  echo "== B s$s conv=gps =="
+  python3 diag/train_color.py --mode train --conv gps --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train gps s$s"
+  python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le gps s$s"
+  python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm gps s$s"
+  for pe in lappe all; do
+    ck2=runs/ckpt_color_gin_full_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt
+    echo "== B s$s gin pe=$pe =="
+    python3 diag/train_color.py --mode train --conv gin --pe "$pe" --p 0.2 --epochs 150 --seed "$s" || echo "!! train $pe s$s"
+    python3 diag/train_color.py --mode le --ckpt "$ck2" || echo "!! le $pe s$s"
+    python3 diag/ptrm_color.py --ckpt "$ck2" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $pe s$s"
+  done
+done
+echo "doneB=$(date -Is)"
diff --git a/diag/run_cin.sh b/diag/run_cin.sh
new file mode 100644
index 0000000..6be362d
--- /dev/null
+++ b/diag/run_cin.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+# Runs diag/cin_color.py (--grad_mode full) for seeds 0-4.
+# No `set -e`: a failed seed is reported via `|| echo` and the loop continues.
+set -uo pipefail
+# Without -e a failed cd would silently run every seed from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "CIN start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+  echo "== cin s$s =="
+  python3 diag/cin_color.py --grad_mode full --seed "$s" || echo "!! cin s$s"
+done
+echo "doneCIN=$(date -Is)"
diff --git a/diag/run_color.sh b/diag/run_color.sh
new file mode 100644
index 0000000..ad3c406
--- /dev/null
+++ b/diag/run_color.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+# Step-2 (TRM regime, large output): graph 3-coloring, TRM full vs HRM 1-step + LE diagnostic.
+# No `set -e`: each step reports its own failure via `|| echo` and the script continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+for gm in full 1step; do
+  echo "===== train $gm ====="
+  python3 diag/train_color.py --mode train --grad_mode "$gm" --p 0.2 --epochs 150 --seed 0 \
+    || echo "!! train $gm failed"
+done
+echo "===== LE diagnostic (lambda1: solved vs unsolved) ====="
+# NOTE(review): run_pe.sh refers to checkpoints as ckpt_color_<gm>_<pe>_n50_...; the two
+# paths below carry no pe field -- confirm they match what train_color.py writes.
+python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_full_n50_k3_p0.2_T3_ns3_s0.pt || echo "!! le full failed"
+python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_1step_n50_k3_p0.2_T3_ns3_s0.pt || echo "!! le 1step failed"
+echo "done=$(date -Is)"
diff --git a/diag/run_cycle.sh b/diag/run_cycle.sh
new file mode 100644
index 0000000..f85a6fd
--- /dev/null
+++ b/diag/run_cycle.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+# Ring-counting (ZINC, [#5-cycles,#6-cycles]) diagnosis: does a real 1-WL ceiling exist,
+# and does noise (RNI) vs structured >1-WL (RWSE) break it?
+# No `set -e`: run() reports a failed configuration via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+# run ARGS...: train_cycle.py with the given args plus the shared depth/epochs/seed settings.
+run() { echo "===== $* ====="; python3 diag/train_cycle.py "$@" --layers 5 --epochs 200 --seed 0 || echo "!! FAILED $*"; }
+run --conv gin --feat none # 1-WL baseline (should fail to count)
+run --conv gcn --feat none # sub-1-WL reference
+run --conv gin --feat rni # noise / PTRM-style crude symmetry break
+run --conv gin --feat rwse # structured >1-WL positive control
+echo "done=$(date -Is)"
diff --git a/diag/run_diag.sh b/diag/run_diag.sh
new file mode 100644
index 0000000..4a6f31a
--- /dev/null
+++ b/diag/run_diag.sh
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+# Step-1 diagnosis sweep. CSL = pure-H2 anchor (classification).
+# Triangle counting on ER (graphs 1-WL-distinguishable -> H2~0, the H1 end) and
+# regular (graphs collapse to one WL color -> H2~var, the ceiling end).
+# No `set -e`: run() reports a failed configuration via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+# run ARGS...: train_diag.py with the given args plus the shared depth/epochs/seed settings.
+run() { echo "===== $* ====="; python3 diag/train_diag.py "$@" --layers 4 --epochs 300 --seed 0 \
+  || echo "!! FAILED $*"; }
+for model in gin gcn; do
+  run --task csl --model "$model"
+  run --task tri --model "$model" --kind er
+  run --task tri --model "$model" --kind regular
+done
+echo "done=$(date -Is)"
diff --git a/diag/run_esan.sh b/diag/run_esan.sh
new file mode 100644
index 0000000..f87b379
--- /dev/null
+++ b/diag/run_esan.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+# Runs diag/esan_color.py (--grad_mode full) for seeds 0-4.
+# No `set -e`: a failed seed is reported via `|| echo` and the loop continues.
+set -uo pipefail
+# Without -e a failed cd would silently run every seed from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "ESAN start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+  echo "== esan s$s =="
+  python3 diag/esan_color.py --grad_mode full --seed "$s" || echo "!! esan s$s"
+done
+echo "doneESAN=$(date -Is)"
diff --git a/diag/run_ff.sh b/diag/run_ff.sh
new file mode 100644
index 0000000..e2b7b0a
--- /dev/null
+++ b/diag/run_ff.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+# Runs diag/ff_color.py over conv in {gcn, gat} x L in {4, 8, 16} x seeds 0-2.
+# No `set -e`: a failed run is reported via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run the whole sweep from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "FF start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for conv in gcn gat; do
+  for L in 4 8 16; do
+    for s in 0 1 2; do
+      python3 diag/ff_color.py --conv "$conv" --L "$L" --seed "$s" || echo "!! ff $conv L$L s$s"
+    done
+  done
+done
+echo "doneFF=$(date -Is)"
diff --git a/diag/run_le.sh b/diag/run_le.sh
new file mode 100644
index 0000000..8fc31ea
--- /dev/null
+++ b/diag/run_le.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+# Step-2(iii): TRM-ish GIN full-recursion vs 1-step-gradient, then LE diagnostic on each.
+# No `set -e`: each step reports its own failure via `|| echo` and the script continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1 --epochs 200 --seed 0 || echo "!! train full failed"
+python3 diag/train_rec.py --grad_mode 1step --sigma 0 --K 1 --epochs 200 --seed 0 || echo "!! train 1step failed"
+echo "===== LE diagnostic (lambda1: success vs failure) ====="
+python3 diag/lyap.py --ckpt runs/ckpt_rec_full_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap full failed"
+python3 diag/lyap.py --ckpt runs/ckpt_rec_1step_sig0.0_K1_bestq_T3_ns3_s0.pt --n_graphs 300 || echo "!! lyap 1step failed"
+echo "done=$(date -Is)"
diff --git a/diag/run_pe.sh b/diag/run_pe.sh
new file mode 100644
index 0000000..8350160
--- /dev/null
+++ b/diag/run_pe.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+# Roadmap #1: RRoG on PE-augmented backbone. GIN vs GIN+RWSE on coloring:
+# does RRoG-noise add headroom on top of static structural encoding, or is it redundant?
+# No `set -e`: each step reports its own failure via `|| echo` and the script continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+for pe in none rwse; do
+  echo "===== train full pe=$pe ====="
+  python3 diag/train_color.py --mode train --grad_mode full --pe "$pe" --p 0.2 --epochs 150 --seed 0 \
+    || echo "!! train $pe failed"
+done
+echo "===== LE (full, both pe) ====="
+for pe in none rwse; do
+  python3 diag/train_color.py --mode le --ckpt runs/ckpt_color_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \
+    || echo "!! le $pe failed"
+done
+echo "===== PTRM noise + lambda-select (both pe) ====="
+for pe in none rwse; do
+  echo "--- pe=$pe ---"
+  python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_${pe}_n50_k3_p0.2_T3_ns3_s0.pt \
+    --K 16 --n_graphs 150 --sigmas 0.1 0.2 0.4 || echo "!! ptrm $pe failed"
+done
+echo "done=$(date -Is)"
diff --git a/diag/run_pe2.sh b/diag/run_pe2.sh
new file mode 100644
index 0000000..db7bd8a
--- /dev/null
+++ b/diag/run_pe2.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+# Roadmap #2: RRoG on a GSN-style motif backbone (per-node K3/wedge substructure counts).
+# full-recursion, 5 seeds; train + LE + PTRM(sigma=0.2). Aggregate vs none/rwse.
+# No `set -e`: each step reports its own failure via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+for s in 0 1 2 3 4; do
+  ck=runs/ckpt_color_full_gsn_n50_k3_p0.2_T3_ns3_s${s}.pt
+  echo "===== seed=$s full gsn ====="
+  python3 diag/train_color.py --mode train --grad_mode full --pe gsn --p 0.2 --epochs 150 --seed "$s" \
+    || echo "!! train gsn s$s failed"
+  python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le gsn s$s failed"
+  python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm gsn s$s failed"
+done
+echo "===== AGGREGATE (all pe: none/rwse/gsn) ====="
+python3 diag/aggregate.py
+echo "done=$(date -Is)"
diff --git a/diag/run_pe3.sh b/diag/run_pe3.sh
new file mode 100644
index 0000000..a978393
--- /dev/null
+++ b/diag/run_pe3.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+# Roadmap #3 (subgraph/ego features) + #4 (IGNN-style forced contraction), 5 seeds.
+# #3: full --pe sub. #4: full --pe none --contract (vs free full/none baseline).
+# No `set -e`: each step reports its own failure via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+for s in 0 1 2 3 4; do
+  ck=runs/ckpt_color_full_sub_n50_k3_p0.2_T3_ns3_s${s}.pt
+  echo "===== seed=$s #3 full sub ====="
+  python3 diag/train_color.py --mode train --grad_mode full --pe sub --p 0.2 --epochs 150 --seed "$s" || echo "!! train sub s$s failed"
+  python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le sub s$s failed"
+  python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm sub s$s failed"
+
+  ck2=runs/ckpt_color_full_none_ctr_n50_k3_p0.2_T3_ns3_s${s}.pt
+  echo "===== seed=$s #4 full none --contract ====="
+  python3 diag/train_color.py --mode train --grad_mode full --pe none --contract --p 0.2 --epochs 150 --seed "$s" || echo "!! train ctr s$s failed"
+  python3 diag/train_color.py --mode le --ckpt "$ck2" || echo "!! le ctr s$s failed"
+  python3 diag/ptrm_color.py --ckpt "$ck2" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm ctr s$s failed"
+done
+echo "===== AGGREGATE ====="
+python3 diag/aggregate.py
+echo "done=$(date -Is)"
diff --git a/diag/run_pna.sh b/diag/run_pna.sh
new file mode 100644
index 0000000..897315d
--- /dev/null
+++ b/diag/run_pna.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+# PNA backbone (pe=none): train + LE + PTRM(sigma=0.2) for seeds 0-4.
+# No `set -e`: each step reports its own failure via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "PNA start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+  ck=runs/ckpt_color_pna_full_none_n50_k3_p0.2_T3_ns3_s${s}.pt
+  echo "== pna s$s =="
+  python3 diag/train_color.py --mode train --conv pna --pe none --p 0.2 --epochs 150 --seed "$s" || echo "!! train pna s$s"
+  python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le pna s$s"
+  python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm pna s$s"
+done
+echo "donePNA=$(date -Is)"
diff --git a/diag/run_ppgn.sh b/diag/run_ppgn.sh
new file mode 100644
index 0000000..2cc27e9
--- /dev/null
+++ b/diag/run_ppgn.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+# Runs diag/ppgn_color.py (--grad_mode full) for seeds 0-4.
+# No `set -e`: a failed seed is reported via `|| echo` and the loop continues.
+set -uo pipefail
+# Without -e a failed cd would silently run every seed from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "PPGN start gpu=${CUDA_VISIBLE_DEVICES:-?} $(date -Is)"
+for s in 0 1 2 3 4; do
+  echo "== ppgn s$s =="
+  python3 diag/ppgn_color.py --grad_mode full --seed "$s" || echo "!! ppgn s$s"
+done
+echo "donePPGN=$(date -Is)"
diff --git a/diag/run_real.sh b/diag/run_real.sh
new file mode 100644
index 0000000..c29426a
--- /dev/null
+++ b/diag/run_real.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+# Peptides-struct training diagnosis: depth sweep + RNI(noise/>1-WL) + GCN(<1-WL) reference.
+# No `set -e`: run() reports a failed configuration via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+# run ARGS...: train_real.py with the given args plus the shared epochs/seed settings.
+run() { echo "===== $* ====="; python3 diag/train_real.py "$@" --epochs 150 --seed 0 || echo "!! FAILED $*"; }
+run --conv gin --layers 5 --rni 0 # baseline 1-WL
+run --conv gin --layers 10 --rni 0 # depth / long-range
+run --conv gin --layers 5 --rni 16 # noise / beyond-1-WL ceiling test
+run --conv gcn --layers 5 --rni 0 # sub-1-WL reference
+echo "done=$(date -Is)"
diff --git a/diag/run_rec.sh b/diag/run_rec.sh
new file mode 100644
index 0000000..632498d
--- /dev/null
+++ b/diag/run_rec.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+# Step-2: recursive GNN + full PTRM on ZINC ring-counting. Does per-step noise + best-Q@K
+# selection break the 1-WL counting ceiling that input-RNI couldn't? Averaging vs selection.
+# No `set -e`: run() reports a failed configuration via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+# run ARGS...: train_rec.py with the given args plus the shared epochs/seed settings.
+run() { echo "===== $* ====="; python3 diag/train_rec.py "$@" --epochs 200 --seed 0 || echo "!! FAILED $*"; }
+run --sigma 0 --K 1 --select bestq # deterministic recursive baseline (~1-WL ceiling)
+run --sigma 0.1 --K 8 --select none # averaging over rollouts (RNI-style null)
+run --sigma 0.1 --K 8 --select bestq # PTRM-proper: per-step noise + best-Q@K selection
+run --sigma 0.2 --K 16 --select bestq # scaled noise + rollouts
+echo "done=$(date -Is)"
diff --git a/diag/run_seeds.sh b/diag/run_seeds.sh
new file mode 100644
index 0000000..77f22c8
--- /dev/null
+++ b/diag/run_seeds.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+# Harden coloring results with 5 seeds: full-vs-1step solve, LE AUROC, PTRM(sigma=0.2) for none/rwse.
+# No `set -e`: each step reports its own failure via `|| echo` and the sweep continues.
+set -uo pipefail
+# Without -e a failed cd would silently run everything from the wrong directory.
+cd /home/yurenh2/rrog || exit 1
+export PYTHONPATH=/home/yurenh2/rrog
+echo "host=$(hostname) gpu=${CUDA_VISIBLE_DEVICES:-?} start=$(date -Is)"
+for s in 0 1 2 3 4; do
+  for cfg in "full none" "full rwse" "1step none"; do
+    # Split "<grad_mode> <pe>" without clobbering the script's positional parameters.
+    read -r gm pe <<< "$cfg"
+    ck=runs/ckpt_color_${gm}_${pe}_n50_k3_p0.2_T3_ns3_s${s}.pt
+    echo "===== seed=$s $gm $pe ====="
+    python3 diag/train_color.py --mode train --grad_mode "$gm" --pe "$pe" --p 0.2 --epochs 150 --seed "$s" \
+      || echo "!! train $gm $pe s$s failed"
+    python3 diag/train_color.py --mode le --ckpt "$ck" || echo "!! le $gm $pe s$s failed"
+    # PTRM is only run for the full-recursion checkpoints.
+    if [ "$gm" = "full" ]; then
+      python3 diag/ptrm_color.py --ckpt "$ck" --K 16 --n_graphs 150 --sigmas 0.2 || echo "!! ptrm $pe s$s failed"
+    fi
+  done
+done
+echo "===== AGGREGATE ====="
+python3 diag/aggregate.py
+echo "done=$(date -Is)"
diff --git a/diag/selftest_wl.py b/diag/selftest_wl.py
new file mode 100644
index 0000000..6c08310
--- /dev/null
+++ b/diag/selftest_wl.py
@@ -0,0 +1,53 @@
+"""Validate the 1-WL instrument on canonical graphs BEFORE trusting any decomposition.
+Run: PYTHONPATH=/home/yurenh2/rrog python3 /home/yurenh2/rrog/diag/selftest_wl.py
+"""
+import numpy as np
+from diag import wl, datasets
+
+
+def run():
+    """Self-test of the WL instrument on canonical pairs and CSL; raises AssertionError on failure.
+
+    Checks, in order: (1) C6 and 2C3 get the same converged graph color although their
+    triangle counts differ, (2) P4 and K1,3 get different colors, (3) the regression H2
+    floor on {C6, 2C3} equals the target variance, (4) CSL collapses to a single node color
+    and a single graph color, with WL-optimal accuracy 0.1 and every graph in bucket H2.
+    """
+    pairs = datasets.canonical_pairs()
+    names = list(pairs.keys())
+
+    def adj_of(name):
+        return wl.edges_to_adj(pairs[name]['n'], pairs[name]['edge_index'])
+
+    _, ghist_rounds, conv = wl.wl_refine([adj_of(nm) for nm in names])
+    color = {nm: ghist_rounds[conv][i] for i, nm in enumerate(names)}
+    print("converged round:", conv)
+    for nm in names:
+        print(f" {nm:5s} tri={pairs[nm]['tri']} wlcolor={color[nm]}")
+
+    # (1) C6 == 2C3 under 1-WL (both 2-regular) yet differ in triangles -> counting is H2
+    assert color['C6'] == color['2C3'], "C6 vs 2C3 should be 1-WL-equal"
+    assert pairs['C6']['tri'] != pairs['2C3']['tri']
+    # (2) P4 != K1,3 (different degree multiset)
+    assert color['P4'] != color['K1,3'], "P4 vs K1,3 should be 1-WL-distinct"
+    print("OK canonical: C6==2C3 (WL blind to triangles), P4!=K1,3")
+
+    # (3) regression H2 floor on {C6,2C3} == target variance
+    tri_targets = [pairs['C6']['tri'], pairs['2C3']['tri']]
+    _, pair_hist, pair_conv = wl.wl_refine([adj_of('C6'), adj_of('2C3')])
+    dec = wl.decompose_regression(pair_hist, pair_conv, L=10, y=tri_targets,
+                                  train_idx=[0, 1], eval_idx=[0, 1])
+    floor = dec['mse_floor_oracle_H2']
+    print(f" triangle-count H2 floor MSE on {{C6,2C3}} = {floor:.4f} "
+          f"(target var = {np.var(tri_targets):.4f})")
+    assert abs(floor - np.var(tri_targets)) < 1e-9
+
+    # (4) CSL: 4-regular -> 1 node color, 1 graph color, WL-optimal acc = chance (0.1) -> 100% H2
+    csl = datasets.build_csl(n_per_class=15, seed=0)
+    csl_nodes, csl_hist, csl_conv = wl.wl_refine([wl.edges_to_adj(d['n'], d['edge_index']) for d in csl])
+    n_node_colors = len(set(csl_nodes[csl_conv][0].tolist()))
+    n_graph_colors = len(set(csl_hist[csl_conv]))
+    labels = [d['y'] for d in csl]
+    everyone = list(range(len(csl)))
+    att = wl.attribute_classification(csl_hist, csl_conv, L=4, y=labels,
+                                      train_idx=everyone, eval_idx=everyone)
+    print(f" CSL: node-colors={n_node_colors}, distinct graph-colors={n_graph_colors}, "
+          f"WL-optimal acc={att['wl_optimal_acc_converged']:.3f} (chance 0.1), buckets={att['counts']}")
+    assert n_node_colors == 1 and n_graph_colors == 1
+    assert abs(att['wl_optimal_acc_converged'] - 0.1) < 1e-6
+    assert att['counts'].get('H2', 0) == len(csl)
+    print("OK CSL: fully 1-WL-collapsed -> 100% of failures are H2. Instrument VALIDATED.")
+
+
+if __name__ == "__main__":
+ run()
diff --git a/diag/train_color.py b/diag/train_color.py
new file mode 100644
index 0000000..36f8496
--- /dev/null
+++ b/diag/train_color.py
@@ -0,0 +1,347 @@
+"""Recursive (TRM-ish) GNN graph 3-coloring with swappable BACKBONE for the RRoG roadmap.
+
+--conv gin|gcn|sage|gat|pna|gps : message-passing operator (gps = GraphGPS local MPNN + global
+ attention = TRM's original transformer backbone, on the graph).
+--pe none|rwse|gsn|sub|lappe|all : input structural features (random sym-break [+ encoding]).
+--contract : reverse-flossing lambda-penalty during training (force contraction; roadmap #4).
+--grad_mode full|1step : TRM full recursion vs HRM 1-step gradient.
+Self-supervised conflict/Potts loss; success = zero-conflict; EMA; deep supervision.
+Modes: --mode train (saves ckpt + JSON) / --mode le.
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import GINConv, GCNConv, SAGEConv, GATConv, GPSConv, PNAConv
+from torch_geometric.utils import degree as _pyg_degree
+
+# GPS uses scaled_dot_product_attention; force the MATH kernel so torch.autograd.functional.jvp
+# (LE diagnostic / PTRM rollouts) has a double-backward-able implementation.
+for _f in ('enable_flash_sdp', 'enable_mem_efficient_sdp', 'enable_math_sdp'):
+ try:
+ getattr(torch.backends.cuda, _f)(_f == 'enable_math_sdp')
+ except Exception:
+ pass
+
+OUT = '/home/yurenh2/rrog/runs'
+CACHE = '/home/yurenh2/rrog/data/color_cache'
+
+
+def gen(n, k, p, r, seed):
+    """Sample one random k-partite graph together with random node features.
+
+    Each node is assigned a uniformly random part in ``[0, k)``. Every unordered pair of
+    nodes lying in different parts is connected with probability ``p``; edges are stored
+    in both directions. ``rfeat`` is an ``(n, r)`` standard-normal matrix. All randomness
+    comes from ``np.random.default_rng(seed)``, so the result is deterministic in ``seed``.
+
+    Returns a dict with keys ``'n'``, ``'edge_index'`` (``2 x E`` long) and ``'rfeat'``.
+    """
+    rng = np.random.default_rng(seed)
+    part = rng.integers(0, k, n)
+    rows, cols = [], []
+    for u in range(n):
+        for v in range(u + 1, n):
+            # Short-circuit on purpose: a uniform draw is consumed only for cross-part pairs.
+            if part[u] != part[v] and rng.random() < p:
+                rows.extend((u, v))
+                cols.extend((v, u))
+    if rows:
+        edge_index = torch.tensor([rows, cols], dtype=torch.long)
+    else:
+        edge_index = torch.zeros((2, 0), dtype=torch.long)
+    node_noise = torch.tensor(rng.standard_normal((n, r)), dtype=torch.float)
+    return {'n': n, 'edge_index': edge_index, 'rfeat': node_noise}
+
+
+def make_split(split, n, k, p, r, count, seed0):
+    """Return ``count`` graphs for ``split``, loading them from the on-disk cache if present.
+
+    Graph ``i`` is ``gen(n, k, p, r, seed0 + i)``. The list is cached under ``CACHE`` with
+    ``torch.save`` and reloaded with ``torch.load(weights_only=False)``.
+
+    ``count`` and ``seed0`` are part of the cache key. Without them a call with a different
+    split size or seed offset would silently receive whatever was cached first for the same
+    ``(split, n, k, p, r)``.
+    """
+    os.makedirs(CACHE, exist_ok=True)
+    fp = os.path.join(CACHE, f"{split}_n{n}_k{k}_p{p}_r{r}_c{count}_s{seed0}.pt")
+    if os.path.exists(fp):
+        return torch.load(fp, weights_only=False)
+    data = [gen(n, k, p, r, seed0 + i) for i in range(count)]
+    torch.save(data, fp)
+    return data
+
+
+def _adj(edge_index, n):
+    """Return the dense symmetric 0/1 adjacency matrix (float64, ``n x n``) of ``edge_index``."""
+    dense = np.zeros((n, n), dtype=np.float64)
+    pairs = edge_index.numpy()
+    if pairs.shape[1] > 0:
+        rows, cols = pairs[0], pairs[1]
+        dense[rows, cols] = 1.0
+    # Symmetrise so that an edge listed in one direction only still counts as undirected.
+    return np.maximum(dense, dense.T)
+
+
+def rwse(edge_index, n, K):
+    """Random-walk structural encoding.
+
+    Column ``j`` of the returned ``(n, K)`` float32 tensor is the diagonal of ``P**(j+1)``,
+    where ``P`` is the row-normalised adjacency, i.e. the (j+1)-step return probability.
+    """
+    adj = _adj(edge_index, n)
+    degree = adj.sum(1)
+    # Isolated nodes divide by 1 and therefore keep an all-zero transition row.
+    trans = adj / np.where(degree > 0, degree, 1.0)[:, None]
+    enc = np.zeros((n, K), dtype=np.float32)
+    walk = np.eye(n)
+    for step in range(K):
+        walk = walk @ trans
+        enc[:, step] = np.diag(walk)
+    return torch.from_numpy(enc)
+
+
+def gsn_feats(edge_index, n):
+    """Per-node substructure features: ``log1p`` of the triangle count and of the wedge count.
+
+    Triangles through a node are ``diag(A^3) / 2``; wedges centred on it are ``deg*(deg-1)/2``.
+    Returns an ``(n, 2)`` float tensor.
+    """
+    adj = _adj(edge_index, n)
+    degree = adj.sum(1)
+    triangles = (adj @ adj @ adj).diagonal() / 2.0
+    wedges = degree * (degree - 1) / 2.0
+    feats = np.stack([np.log1p(triangles), np.log1p(wedges)], axis=1)
+    return torch.tensor(feats, dtype=torch.float)
+
+
+def sub_feats(edge_index, n):
+    """Per-node structural features: ``log1p(degree)``, ``log1p(#2-hop nodes)``, clustering.
+
+    The 2-hop count is the number of nodes reachable by a length-2 walk that are neither
+    the node itself nor a direct neighbour. The clustering coefficient is
+    ``triangles / (deg*(deg-1)/2)`` and is defined as 0 for nodes with degree <= 1.
+    Returns an ``(n, 3)`` float tensor.
+    """
+    A = _adj(edge_index, n)
+    deg = A.sum(1)
+    A2 = A @ A
+    two_hop = (A2 > 0).astype(np.float64) * (A == 0).astype(np.float64)
+    np.fill_diagonal(two_hop, 0.0)
+    d2 = two_hop.sum(1)
+    tri = (A @ A2).diagonal() / 2.0
+    pairs = deg * (deg - 1) / 2.0
+    # Divide only where the denominator is non-zero. np.where would evaluate 0/0 for
+    # low-degree nodes first and emit a RuntimeWarning for every such graph.
+    clus = np.divide(tri, pairs, out=np.zeros_like(tri), where=deg > 1)
+    return torch.tensor(np.stack([np.log1p(deg), np.log1p(d2), clus], axis=1), dtype=torch.float)
+
+
+def lappe_feats(edge_index, n, kpe=8):
+    """Laplacian positional encoding from the symmetric normalised Laplacian.
+
+    Builds ``L = I - D^{-1/2} A D^{-1/2}`` (isolated nodes get a zero scaling factor), takes
+    the eigenvectors returned by ``np.linalg.eigh`` and keeps columns ``1 .. kpe``, i.e. it
+    skips the first one. If fewer than ``kpe`` columns exist the result is zero-padded on
+    the right. Returns an ``(n, kpe)`` float tensor.
+
+    NOTE(review): eigenvector signs are whatever ``eigh`` returns; no sign canonicalisation
+    or random flipping is applied here.
+    """
+    A = _adj(edge_index, n)
+    deg = A.sum(1)
+    # Compute 1/sqrt(deg) only for deg > 0. The unguarded expression divides by zero for
+    # isolated nodes and emits a RuntimeWarning before np.where discards the value.
+    di = np.divide(1.0, np.sqrt(deg), out=np.zeros_like(deg), where=deg > 0)
+    L = np.eye(n) - di[:, None] * A * di[None, :]
+    _, V = np.linalg.eigh(L)
+    pe = V[:, 1:kpe + 1]
+    if pe.shape[1] < kpe:
+        pe = np.pad(pe, ((0, 0), (0, kpe - pe.shape[1])))
+    return torch.tensor(pe, dtype=torch.float)
+
+
+def featurize(graphs, pe, rwse_k):
+    """Attach the model input ``g['xin']`` to every graph dict in ``graphs`` (in place).
+
+    ``xin`` is ``g['rfeat']`` concatenated with the encoding selected by ``pe``
+    (``'rwse'``, ``'gsn'``, ``'sub'``, ``'lappe'`` or ``'all'`` = rwse + gsn + sub).
+    For any other ``pe`` value ``xin`` is ``g['rfeat']`` itself. Returns ``graphs``.
+    """
+    # Lambdas defer the helper lookups until an encoding is actually requested.
+    encoders = {
+        'rwse': lambda ei, n: rwse(ei, n, rwse_k),
+        'gsn': lambda ei, n: gsn_feats(ei, n),
+        'sub': lambda ei, n: sub_feats(ei, n),
+        'lappe': lambda ei, n: lappe_feats(ei, n),
+        'all': lambda ei, n: torch.cat(
+            [rwse(ei, n, rwse_k), gsn_feats(ei, n), sub_feats(ei, n)], dim=1),
+    }
+    encode = encoders.get(pe)
+    for g in graphs:
+        extra = encode(g['edge_index'], g['n']) if encode is not None else None
+        if extra is None:
+            g['xin'] = g['rfeat']
+        else:
+            g['xin'] = torch.cat([g['rfeat'], extra], dim=1)
+    return graphs
+
+
+def deg_hist(graphs):
+    """Histogram of node degrees over all ``graphs``.
+
+    Degrees are counted from the target row of each ``edge_index``. Entry ``d`` of the
+    returned long tensor is the number of nodes with degree ``d``; its length is
+    ``max_degree + 1``. ``main`` passes it to ``PNAConv`` as ``deg``.
+    """
+    per_graph = [_pyg_degree(g['edge_index'][1], g['n'], dtype=torch.long) for g in graphs]
+    max_deg = 0
+    for d in per_graph:
+        if d.numel():
+            max_deg = max(max_deg, int(d.max()))
+    hist = torch.zeros(max_deg + 1, dtype=torch.long)
+    for d in per_graph:
+        hist += torch.bincount(d, minlength=max_deg + 1)
+    return hist
+
+
+def make_conv(conv, hidden, deg=None):
+    """Build one ``hidden -> hidden`` message-passing layer of the requested kind.
+
+    ``conv`` is one of ``'gin'``, ``'gcn'``, ``'sage'``, ``'gat'``, ``'pna'``, ``'gps'``.
+    ``deg`` (degree histogram) is only used by ``'pna'``. ``'gps'`` wraps a GIN layer as the
+    local operator of ``GPSConv``. Raises ``ValueError(conv)`` for any other name.
+    """
+    def gin_layer():
+        mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
+        return GINConv(mlp, train_eps=True)
+
+    if conv == 'gin':
+        return gin_layer()
+    elif conv == 'gcn':
+        return GCNConv(hidden, hidden)
+    elif conv == 'sage':
+        return SAGEConv(hidden, hidden)
+    elif conv == 'gat':
+        return GATConv(hidden, hidden, heads=4, concat=False, add_self_loops=True)
+    elif conv == 'pna':
+        return PNAConv(hidden, hidden,
+                       aggregators=['mean', 'min', 'max', 'std'],
+                       scalers=['identity', 'amplification', 'attenuation'],
+                       deg=deg, towers=1)
+    elif conv == 'gps':
+        # The local GIN layer is built first, then handed to GPSConv.
+        return GPSConv(hidden, gin_layer(), heads=4)
+    raise ValueError(conv)
+
+
+class RecGINColor(nn.Module):
+    """Weight-shared recurrent GNN that emits colour logits at every supervision step.
+
+    A block of ``inner`` conv + BatchNorm + ReLU layers is applied ``T`` times per
+    supervision step to ``z + h0`` (``h0`` = linear embedding of the input features).
+    ``forward`` runs ``n_sup`` supervision steps, detaching the latent between them, and
+    returns one logits tensor per step. With ``grad_mode='1step'`` only the last of the
+    ``T`` inner recursions of each step is differentiated.
+    """
+
+    def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full', sigma=0.0, conv='gin', deg=None):
+        super().__init__()
+        self.conv_type = conv
+        self.lin_in = nn.Linear(in_dim, hidden)
+        self.convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(inner)])
+        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)])
+        self.head = nn.Linear(hidden, k)
+        self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma
+
+    def block(self, z, ei, batch=None):
+        """Apply the shared stack of conv -> BatchNorm -> ReLU layers once."""
+        is_gps = self.conv_type == 'gps'
+        if is_gps and batch is None:
+            # GPS is always called with a batch vector; default to a single graph.
+            batch = z.new_zeros(z.size(0), dtype=torch.long)
+        for conv, bn in zip(self.convs, self.bns):
+            if is_gps:
+                z = conv(z, ei, batch)
+            else:
+                z = conv(z, ei)
+            z = bn(z).relu()
+        return z
+
+    def _inner(self, z, h0, ei, noise, batch):
+        """One recursion: re-inject ``h0``, run the block, optionally add Gaussian noise."""
+        z = self.block(z + h0, ei, batch)
+        if noise and self.sigma > 0:
+            z = z + self.sigma * torch.randn_like(z)
+        return z
+
+    def recurse(self, z, h0, ei, noise, batch, one_step=False):
+        """Run ``T`` recursions; with ``one_step`` only the final one carries gradient."""
+        if not one_step:
+            for _ in range(self.T):
+                z = self._inner(z, h0, ei, noise, batch)
+            return z
+        with torch.no_grad():
+            for _ in range(self.T - 1):
+                z = self._inner(z, h0, ei, noise, batch)
+        z = z.detach()
+        return self._inner(z, h0, ei, noise, batch)
+
+    def forward(self, xin, ei, batch=None, noise=False):
+        """Return a list of ``n_sup`` logits tensors, one per supervision step."""
+        h0 = self.lin_in(xin)
+        z = torch.zeros_like(h0)
+        one_step = self.grad_mode == '1step'
+        outs = []
+        for _ in range(self.n_sup):
+            z = self.recurse(z, h0, ei, noise, batch, one_step=one_step)
+            outs.append(self.head(z))
+            # The latent is carried to the next supervision step without gradient.
+            z = z.detach()
+        return outs
+
+
+def conflict_loss(logits, ei):
+    """Soft conflict (Potts) loss for graph colouring.
+
+    For every directed edge in ``ei`` this is the probability, under the per-node softmax
+    of ``logits``, that both endpoints take the same colour; the result is the mean over
+    edges.
+
+    An edge-less input returns a zero that is still connected to ``logits``. The plain
+    mean over zero edges would be NaN and would poison the whole loss.
+    """
+    p = F.softmax(logits, dim=-1)
+    if ei.size(1) == 0:
+        return p.sum() * 0.0
+    return (p[ei[0]] * p[ei[1]]).sum(-1).mean()
+
+
+@torch.no_grad()
+def solve_stats(model, recs, dev, sample=None):
+    """Evaluate hard colourings on ``recs`` (only the first ``sample`` if given).
+
+    Puts ``model`` in eval mode, colours each graph with the argmax of the last
+    supervision output, and returns ``(solve_rate, mean_conflicts)``. A graph counts as
+    solved when it has zero conflicts.
+    """
+    model.eval()
+    subset = recs[:sample] if sample else recs
+    n_solved = 0
+    total_conflicts = 0.0
+    n_graphs = 0
+    for rec in subset:
+        ei = rec['edge_index'].to(dev)
+        colors = model(rec['xin'].to(dev), ei)[-1].argmax(-1)
+        # Halved on the assumption that edge_index stores both directions of every edge.
+        n_conf = (colors[ei[0]] == colors[ei[1]]).sum().item() // 2
+        if n_conf == 0:
+            n_solved += 1
+        total_conflicts += n_conf
+        n_graphs += 1
+    return n_solved / n_graphs, total_conflicts / n_graphs
+
+
+def lyap1(model, xin, ei, n_steps, dev, seed=0):
+    """Estimate the leading finite-time Lyapunov exponent of the recurrent block on one graph.
+
+    Starting from ``z = 0`` and a random unit tangent vector (seeded by ``seed``), the map
+    ``z -> model.block(z + h0, ei)`` is iterated ``n_steps`` times. At each step the
+    Jacobian-vector product is taken, its log-norm accumulated, and the tangent
+    renormalised. Returns the mean log growth per step.
+    """
+    rng = torch.Generator(device=dev).manual_seed(seed)
+    h0 = model.lin_in(xin).detach()
+    state = torch.zeros_like(h0)
+    tangent = torch.randn(h0.shape, generator=rng, device=dev)
+    tangent = tangent / (tangent.norm() + 1e-12)
+
+    def step_fn(zz):
+        return model.block(zz + h0, ei)
+
+    log_growth = 0.0
+    for _ in range(n_steps):
+        nxt, jv = torch.autograd.functional.jvp(step_fn, state, tangent)
+        state = nxt.detach()
+        growth = jv.norm()
+        log_growth += torch.log(growth + 1e-12).item()
+        tangent = (jv / (growth + 1e-12)).detach()
+    return log_growth / n_steps
+
+
+def run_le(model, recs, dev, n_steps, n_graphs=300):
+    """Relate per-graph Lyapunov exponents to colouring failure on the first ``n_graphs`` graphs.
+
+    For each graph the hard colouring from the last supervision output decides
+    failure (any conflicting edge), and ``lyap1`` gives the exponent. Prints a one-line
+    summary and returns a dict with the count, failure rate, AUROC of the exponent as a
+    failure score, the unsolved-minus-solved mean gap, and the group means.
+
+    AUROC needs scikit-learn and both classes present; otherwise it is NaN.
+    NOTE(review): ``auroc`` / ``sep`` may therefore be NaN in the returned dict, whereas
+    ``lam_solved`` / ``lam_unsolved`` use ``None`` for "undefined".
+    """
+    try:
+        from sklearn.metrics import roc_auc_score
+    except Exception:
+        roc_auc_score = None
+    model.eval()
+    lam_list, fail_list = [], []
+    for i, r in enumerate(recs[:n_graphs]):
+        ei = r['edge_index'].to(dev)
+        xin = r['xin'].to(dev)
+        with torch.no_grad():
+            col = model(xin, ei)[-1].argmax(-1)
+        n_conf = (col[ei[0]] == col[ei[1]]).sum().item()
+        fail_list.append(int(n_conf > 0))
+        # lyap1 needs autograd, so it runs outside the no_grad block.
+        lam_list.append(lyap1(model, xin, ei, n_steps, dev, seed=i))
+    lams, fails = np.array(lam_list), np.array(fail_list)
+    s, f = lams[fails == 0], lams[fails == 1]
+    both = len(s) and len(f)
+    auc = (roc_auc_score(fails, lams) if roc_auc_score and both else float('nan'))
+    sep = (f.mean() - s.mean()) if both else float('nan')
+    print(f"[{model.conv_type}/{model.grad_mode}] LE n={len(lams)} fail={fails.mean():.2f} | "
+          f"SOLVED {s.mean() if len(s) else float('nan'):+.3f} UNSOLVED {f.mean() if len(f) else float('nan'):+.3f}"
+          f" sep={sep:+.3f} AUROC={auc:.3f} mean_lam={lams.mean():+.3f}")
+    return {'n': int(len(lams)), 'fail_rate': float(fails.mean()), 'auroc': float(auc), 'sep': float(sep),
+            'lam_solved': (float(s.mean()) if len(s) else None),
+            'lam_unsolved': (float(f.mean()) if len(f) else None), 'mean_lam': float(lams.mean())}
+
+
+def lyap_penalty(model, x, ei, batch, target=-0.5):
+    """Differentiable contraction penalty: ``(log ||J v|| - target) ** 2``.
+
+    ``J`` is the Jacobian of one block application ``z -> model.block(z + h0, ei, batch)``,
+    evaluated at the latent reached by one gradient-free ``model.recurse`` from zero.
+    ``v`` is a fresh random unit vector. ``create_graph=True`` keeps the JVP
+    differentiable so the penalty can be backpropagated.
+    """
+    h0 = model.lin_in(x)
+    with torch.no_grad():
+        z_ref = model.recurse(torch.zeros_like(h0), h0.detach(), ei, False, batch)
+    direction = torch.randn_like(z_ref)
+    direction = direction / (direction.norm() + 1e-12)
+
+    def one_block(zz):
+        return model.block(zz + h0, ei, batch)
+
+    _, jv = torch.autograd.functional.jvp(one_block, z_ref, direction, create_graph=True)
+    return (torch.log(jv.norm() + 1e-12) - target) ** 2
+
+
+def main():
+    """CLI entry point: train a recurrent colouring model, or run the Lyapunov diagnostic.
+
+    ``--mode train`` fits the model on the cached train split, evaluates the EMA weights on
+    the first 300 test graphs every 20 epochs (and at the last epoch), and writes
+    ``{tag}.json`` plus ``ckpt_{tag}.pt`` (best EMA state and model config) under ``OUT``.
+
+    ``--mode le`` rebuilds the model from ``--ckpt``, runs ``run_le`` on the test split and
+    writes ``le_{base}.json`` under ``OUT``. ``--ckpt`` is mandatory in this mode.
+    """
+    ap = argparse.ArgumentParser()
+    ap.add_argument('--mode', choices=['train', 'le'], default='train')
+    ap.add_argument('--conv', choices=['gin', 'gcn', 'sage', 'gat', 'pna', 'gps'], default='gin')
+    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+    ap.add_argument('--pe', choices=['none', 'rwse', 'gsn', 'sub', 'lappe', 'all'], default='none')
+    ap.add_argument('--contract', action='store_true')
+    ap.add_argument('--rwse_k', type=int, default=16)
+    ap.add_argument('--ckpt', default=None)
+    ap.add_argument('--n', type=int, default=50)
+    ap.add_argument('--k', type=int, default=3)
+    ap.add_argument('--p', type=float, default=0.2)
+    ap.add_argument('--r', type=int, default=8)
+    ap.add_argument('--hidden', type=int, default=128)
+    ap.add_argument('--T', type=int, default=3)
+    ap.add_argument('--n_sup', type=int, default=3)
+    ap.add_argument('--epochs', type=int, default=150)
+    ap.add_argument('--lr', type=float, default=1e-3)
+    ap.add_argument('--bs', type=int, default=32)
+    ap.add_argument('--seed', type=int, default=0)
+    args = ap.parse_args()
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+    os.makedirs(OUT, exist_ok=True)
+
+    if args.mode == 'le':
+        if args.ckpt is None:
+            # Fail with a usage message instead of an opaque error inside torch.load(None).
+            ap.error('--mode le requires --ckpt')
+        ck = torch.load(args.ckpt, weights_only=False)
+        c = ck['cfg']
+        # Features follow the checkpoint config; graph-size args come from the CLI.
+        te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000),
+                       c.get('pe', 'none'), c.get('rwse_k', 16))
+        deg = torch.tensor(c['deg']) if c.get('deg') else None
+        model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
+                            grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev)
+        model.load_state_dict(ck['state'])
+        model.eval()
+        res = run_le(model, te, dev, c['n_sup'] * c['T'])
+        base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
+        with open(os.path.join(OUT, f"le_{base}.json"), 'w') as fjs:
+            json.dump({'conv': c.get('conv', 'gin'), 'grad_mode': c['grad_mode'], 'pe': c.get('pe', 'none'),
+                       'contract': c.get('contract', False), 'seed': c.get('seed'), **res}, fjs, indent=2)
+        return
+
+    te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), args.pe, args.rwse_k)
+    tr = featurize(make_split('train', args.n, args.k, args.p, args.r, 2000, 0), args.pe, args.rwse_k)
+    in_dim = tr[0]['xin'].shape[1]
+    data = [Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr]
+    trl = DataLoader(data, batch_size=args.bs, shuffle=True, drop_last=True)
+    deg = deg_hist(tr) if args.conv == 'pna' else None
+    model = RecGINColor(in_dim, args.hidden, args.k, args.T, args.n_sup,
+                        grad_mode=args.grad_mode, conv=args.conv, deg=deg).to(dev)
+    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+
+    # Exponential moving average of every state-dict entry, updated after each step.
+    ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+    t0 = time.time()
+    best_solve = -1
+    best = {}
+    best_state = None
+    for ep in range(args.epochs):
+        model.train()
+        for b in trl:
+            b = b.to(dev)
+            opt.zero_grad()
+            outs = model(b.x, b.edge_index, b.batch, noise=False)
+            # Deep supervision: average the conflict loss over all supervision steps.
+            loss = sum(conflict_loss(o, b.edge_index) for o in outs) / len(outs)
+            if args.contract:
+                loss = loss + lyap_penalty(model, b.x, b.edge_index, b.batch)
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            opt.step()
+            with torch.no_grad():
+                for kk, v in model.state_dict().items():
+                    if torch.is_floating_point(v):
+                        ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001)
+                    else:
+                        # Integer buffers cannot be averaged; mirror the latest value.
+                        ema[kk].copy_(v.detach())
+        sched.step()
+        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+            # Evaluate the EMA weights, then restore the raw training weights.
+            backup = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+            model.load_state_dict(ema)
+            sr, mc = solve_stats(model, te, dev, sample=300)
+            if sr > best_solve:
+                best_solve = sr
+                best = {'ep': ep + 1, 'solve_rate': round(sr, 4), 'mean_conflicts': round(mc, 3)}
+                best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
+            model.load_state_dict(backup)
+            print(f"ep{ep+1} solve_rate={sr:.3f} mean_conflicts={mc:.2f}", flush=True)
+
+    sfx = ('_ctr' if args.contract else '')
+    tag = f"color_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}"
+    rep = {'task': 'graph3coloring', 'tag': tag, **vars(args), 'in_dim': in_dim,
+           'sec': round(time.time() - t0, 1), **best}
+    print(f"[{tag}] best solve_rate={best.get('solve_rate')} @ep{best.get('ep')} ({rep['sec']}s)")
+    with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
+        json.dump(rep, f, indent=2)
+    cfg = {'in_dim': in_dim, 'hidden': args.hidden, 'k': args.k,
+           'T': args.T, 'n_sup': args.n_sup, 'grad_mode': args.grad_mode, 'pe': args.pe,
+           'rwse_k': args.rwse_k, 'contract': args.contract, 'conv': args.conv, 'seed': args.seed,
+           'deg': (deg.tolist() if deg is not None else None)}
+    torch.save({'state': best_state, 'cfg': cfg}, os.path.join(OUT, f"ckpt_{tag}.pt"))
+    print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/train_cycle.py b/diag/train_cycle.py
new file mode 100644
index 0000000..d2342f3
--- /dev/null
+++ b/diag/train_cycle.py
@@ -0,0 +1,183 @@
+"""GNN-native ring-counting on real molecules (ZINC): regress [#5-cycles, #6-cycles].
+
+k-cycle counts (k>=3) are provably NOT computable by 1-WL/MPNN (Chen et al. 2020) -> a REAL
+H2 ceiling on REAL graphs. Training-based diagnosis (partition instrument is vacuous on
+feature-rich graphs):
+ GIN(L) 1-WL baseline -> should FAIL to count
+ GCN(L) sub-1-WL reference
+ GIN+RNI random feats = NOISE -> PTRM-style crude symmetry break (eval-averaged)
+ GIN+RWSE random-walk return probs-> structured >1-WL positive control
+Reads: GIN high error + RWSE fixes it = real ceiling exists; RNI also fixes = crude noise
+breaks it (bridge cashed); only RWSE = bridge needs STRUCTURED stochasticity (GRAM>PTRM).
+Targets z-scored for training; per-target MAE reported in RAW ring units.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_cycle.py --conv gin --feat none
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import networkx as nx
+from torch_geometric.datasets import ZINC
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from torch_geometric.utils import to_networkx
+from torch_geometric.nn import GINConv, GCNConv, global_add_pool
+
+ROOT = '/home/yurenh2/rrog/data/zinc'
+CACHE = '/home/yurenh2/rrog/data/cycle_cache'
+OUT = '/home/yurenh2/rrog/runs'
+RWSE_K = 16
+
+
+def rwse(edge_index, n, K=RWSE_K):
+    """Random-walk structural encoding for one graph.
+
+    Builds the symmetric dense adjacency from ``edge_index``, row-normalises it into a
+    transition matrix, and returns an ``(n, K)`` float32 tensor whose column ``j`` is the
+    diagonal of the ``(j+1)``-th matrix power (the return probability after j+1 steps).
+    """
+    adj = np.zeros((n, n), dtype=np.float64)
+    pairs = edge_index.numpy()
+    adj[pairs[0], pairs[1]] = 1.0
+    adj = np.maximum(adj, adj.T)
+    degree = adj.sum(1)
+    # Isolated nodes divide by 1 and keep an all-zero row.
+    trans = adj / np.where(degree > 0, degree, 1.0)[:, None]
+    enc = np.zeros((n, K), dtype=np.float32)
+    walk = np.eye(n)
+    for step in range(K):
+        walk = walk @ trans
+        enc[:, step] = np.diag(walk)
+    return torch.from_numpy(enc)
+
+
+def c56(data):
+    """Count the simple cycles of length 5 and of length 6 in ``data``.
+
+    The graph is converted to an undirected networkx graph and cycles are enumerated with
+    ``length_bound=6``. Returns ``[n5, n6]`` as floats.
+    """
+    graph = to_networkx(data, to_undirected=True)
+    n5 = 0
+    n6 = 0
+    for cyc in nx.simple_cycles(graph, length_bound=6):
+        size = len(cyc)
+        if size == 5:
+            n5 += 1
+        elif size == 6:
+            n6 += 1
+    return [float(n5), float(n6)]
+
+
+def prepare(split):
+    """Return the ZINC-subset records for ``split``, building and caching them on first use.
+
+    Each record is a dict with the flattened long atom ids ``'x'``, ``'edge_index'``, the
+    random-walk encoding ``'rwse'`` and the target ``'y'`` = [#5-cycles, #6-cycles].
+    The list is stored as ``{split}.pt`` under ``CACHE``.
+    """
+    os.makedirs(CACHE, exist_ok=True)
+    cache_file = os.path.join(CACHE, f"{split}.pt")
+    if os.path.exists(cache_file):
+        return torch.load(cache_file, weights_only=False)
+    records = [
+        {'x': g.x.view(-1).long(),
+         'edge_index': g.edge_index,
+         'rwse': rwse(g.edge_index, g.num_nodes),
+         'y': torch.tensor(c56(g), dtype=torch.float)}
+        for g in ZINC(ROOT, subset=True, split=split)
+    ]
+    torch.save(records, cache_file)
+    return records
+
+
+class Net(nn.Module):
+    """GIN/GCN graph regressor with optional RWSE input and random node initialisation.
+
+    Atom ids are embedded, optionally concatenated with the RWSE features and/or ``rni``
+    fresh Gaussian features per forward pass, projected to ``hidden``, passed through
+    ``layers`` conv + BatchNorm + ReLU layers, sum-pooled per graph and mapped to 2 outputs.
+    """
+
+    def __init__(self, n_atom, hidden, layers, conv='gin', rni=0, use_rwse=False):
+        super().__init__()
+        self.emb = nn.Embedding(n_atom, hidden)
+        self.rni, self.use_rwse = rni, use_rwse
+        din = hidden + rni + (RWSE_K if use_rwse else 0)
+        self.lin_in = nn.Linear(din, hidden)
+        self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
+        for _ in range(layers):
+            if conv == 'gin':
+                mlp = nn.Sequential(
+                    nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
+                layer = GINConv(mlp, train_eps=True)
+            else:
+                layer = GCNConv(hidden, hidden)
+            self.convs.append(layer)
+            self.bns.append(nn.BatchNorm1d(hidden))
+        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
+
+    def forward(self, x, edge_index, batch, rwse=None):
+        """Return the ``(num_graphs, 2)`` prediction for a batch."""
+        h = self.emb(x)
+        parts = [h]
+        if self.use_rwse:
+            parts.append(rwse)
+        if self.rni:
+            # Random features are resampled on every call, in training and in evaluation.
+            parts.append(torch.randn(h.size(0), self.rni, device=h.device))
+        h = self.lin_in(torch.cat(parts, dim=1))
+        for conv, bn in zip(self.convs, self.bns):
+            h = bn(conv(h, edge_index)).relu()
+        pooled = global_add_pool(h, batch)
+        return self.head(pooled)
+
+
+def to_loader(recs, bs, shuffle, drop_last=False):
+    """Wrap record dicts as PyG ``Data`` objects and return a ``DataLoader`` over them.
+
+    Targets are reshaped to ``(1, 2)`` so that batching stacks them to ``(num_graphs, 2)``.
+    """
+    graphs = []
+    for rec in recs:
+        graphs.append(Data(x=rec['x'], edge_index=rec['edge_index'], rwse=rec['rwse'],
+                           y=rec['y'].view(1, 2), num_nodes=rec['x'].numel()))
+    return DataLoader(graphs, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+
+
+@torch.no_grad()
+def eval_mae(model, loader, dev, ymu, ysd, samples=1):
+    """Per-target mean absolute error in raw (un-standardised) units.
+
+    Predictions are averaged over ``samples`` forward passes, then predictions and targets
+    are mapped back with ``* ysd + ymu`` before the absolute error is summed per target.
+    Returns a 2-element list (error sum divided by the number of graphs).
+    """
+    model.eval()
+    sd, mu = ysd.to(dev), ymu.to(dev)
+    err_sum = torch.zeros(2)
+    n_graphs = 0
+    for b in loader:
+        b = b.to(dev)
+        runs = [model(b.x, b.edge_index, b.batch, b.rwse) for _ in range(samples)]
+        pred_std = torch.stack(runs).mean(0)
+        pred_raw = pred_std * sd + mu
+        true_raw = b.y * sd + mu
+        err_sum += (pred_raw - true_raw).abs().sum(0).cpu()
+        n_graphs += b.num_graphs
+    return (err_sum / n_graphs).tolist()
+
+
+def main():
+    """CLI entry point: train a GIN/GCN ring-count regressor on ZINC and write a JSON report.
+
+    Targets are z-scored with train statistics; the reported MAEs are converted back to raw
+    ring counts. The model is evaluated on validation every 20 epochs (and at the last
+    epoch); train/val/test MAE are recorded at the evaluation with the lowest summed
+    validation MAE. The report is written to ``cyc_{tag}.json`` under ``OUT``.
+    """
+    ap = argparse.ArgumentParser()
+    ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
+    ap.add_argument('--feat', choices=['none', 'rni', 'rwse'], default='none')
+    ap.add_argument('--layers', type=int, default=5)
+    ap.add_argument('--hidden', type=int, default=128)
+    ap.add_argument('--rni_dim', type=int, default=16)
+    ap.add_argument('--epochs', type=int, default=200)
+    ap.add_argument('--lr', type=float, default=1e-3)
+    ap.add_argument('--bs', type=int, default=128)
+    ap.add_argument('--seed', type=int, default=0)
+    args = ap.parse_args()
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+    os.makedirs(OUT, exist_ok=True)
+
+    tr = prepare('train')
+    va = prepare('val')
+    te = prepare('test')
+    n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
+    Ytr = torch.stack([r['y'] for r in tr])
+    ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+    # Standardise the targets of the in-memory records (the cache file is not rewritten).
+    for recs in (tr, va, te):
+        for r in recs:
+            r['y'] = (r['y'] - ymu) / ysd
+
+    rni = args.rni_dim if args.feat == 'rni' else 0
+    use_rwse = args.feat == 'rwse'
+    # With random node features the evaluation averages 8 stochastic forward passes.
+    samples = 8 if rni else 1
+    model = Net(n_atom, args.hidden, args.layers, args.conv, rni, use_rwse).to(dev)
+    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+    lossf = nn.L1Loss()
+    trl = to_loader(tr, args.bs, True, drop_last=True)
+    trl_e = to_loader(tr, 256, False)
+    val = to_loader(va, 256, False)
+    tel = to_loader(te, 256, False)
+
+    t0 = time.time()
+    best_val = 9e9
+    best = {}
+    for ep in range(args.epochs):
+        model.train()
+        for b in trl:
+            b = b.to(dev)
+            opt.zero_grad()
+            loss = lossf(model(b.x, b.edge_index, b.batch, b.rwse), b.y)
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            opt.step()
+        sched.step()
+        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+            vm = eval_mae(model, val, dev, ymu, ysd, samples)
+            if sum(vm) < best_val:
+                best_val = sum(vm)
+                best = {'ep': ep + 1, 'train_mae': eval_mae(model, trl_e, dev, ymu, ysd, samples),
+                        'val_mae': vm, 'test_mae': eval_mae(model, tel, dev, ymu, ysd, samples)}
+            print(f"ep{ep+1} val_mae(c5,c6)={[round(x,3) for x in vm]}", flush=True)
+
+    tag = f"{args.conv}_{args.feat}_L{args.layers}_s{args.seed}"
+    rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args),
+           'y_std_raw': ysd.tolist(), 'sec': round(time.time() - t0, 1), 'dev': dev, **best}
+    # NOTE(review): if no evaluation ever improved on best_val (e.g. zero epochs), `best`
+    # is empty and the summary print below fails on the missing MAE lists.
+    tm = best.get('test_mae')
+    trm = best.get('train_mae')
+    print(f"[{tag}] train_mae(c5,c6)={[round(x,3) for x in trm]} test_mae={[round(x,3) for x in tm]} "
+          f"(raw rings; std={ [round(x,2) for x in ysd.tolist()] }) @ep{best.get('ep')} ({rep['sec']}s)")
+    with open(os.path.join(OUT, f"cyc_{tag}.json"), 'w') as f:
+        json.dump(rep, f, indent=2)
+    print(" wrote", os.path.join(OUT, f"cyc_{tag}.json"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/train_diag.py b/diag/train_diag.py
new file mode 100644
index 0000000..a9c4ab2
--- /dev/null
+++ b/diag/train_diag.py
@@ -0,0 +1,161 @@
+"""Train a backbone, collect failures, attribute them via the 1-WL instrument.
+
+Node features are CONSTANT (all-ones) so the GIN starts from anonymous nodes -> its
+expressivity ceiling is exactly the anonymous 1-WL partition the instrument computes
+(wl_refine init = all-zero). GIN depth L == L WL rounds. Regression targets are
+standardized (train stats) for stable training; all reported MSEs are in original units.
+Train AND test metrics are reported so non-H2 error can be split into optimization
+(can't even fit train) vs generalization (fits train, fails test).
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_diag.py --task csl --model gin
+"""
+import argparse, json, os, time
+from collections import Counter
+import numpy as np
+import torch
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from diag import wl, datasets as DS, models as M
+
+
+def to_pyg(raw, task, ymu=0.0, ysd=1.0):
+    """Convert raw graph dicts to PyG ``Data`` objects with constant all-ones node features.
+
+    For ``task == 'clf'`` the label is a long tensor of shape ``(1,)``; otherwise it is the
+    standardised target ``(y - ymu) / ysd`` as a float tensor of shape ``(1, 1)``.
+    """
+    is_clf = task == 'clf'
+    graphs = []
+    for d in raw:
+        if is_clf:
+            label = torch.tensor([d['y']], dtype=torch.long)
+        else:
+            label = torch.tensor([[(d['y'] - ymu) / ysd]], dtype=torch.float)
+        graphs.append(Data(x=torch.ones(d['n'], 1),
+                           edge_index=torch.tensor(d['edge_index'], dtype=torch.long),
+                           y=label, num_nodes=d['n']))
+    return graphs
+
+
+def split(n, frac, seed, y=None, stratify=False):
+    """Split ``range(n)`` into sorted ``(train, test)`` index lists.
+
+    Without stratification a seeded shuffle is cut at ``int(frac * n)``. With ``stratify``
+    and labels ``y``, each class contributes ``max(1, round(frac * class_size))`` randomly
+    chosen indices to the test set.
+    """
+    rng = np.random.default_rng(seed)
+    idx = np.arange(n)
+    if stratify and y is not None:
+        y = np.asarray(y)
+        test = []
+        for c in np.unique(y):
+            ci = idx[y == c]
+            rng.shuffle(ci)
+            test += ci[:max(1, int(round(frac * len(ci))))].tolist()
+        # Build the membership set once; rebuilding it per element made this quadratic.
+        test_set = set(test)
+        test = sorted(test_set)
+        train = [i for i in idx.tolist() if i not in test_set]
+    else:
+        rng.shuffle(idx)
+        k = int(frac * n)
+        test = sorted(idx[:k].tolist())
+        train = sorted(idx[k:].tolist())
+    return train, test
+
+
+@torch.no_grad()
+def predict(model, loader, task, dev):
+    """Run ``model`` over ``loader`` in eval mode and return all predictions as a numpy array.
+
+    For ``task == 'clf'`` the prediction is the argmax class per graph; otherwise the
+    flattened model output.
+    """
+    model.eval()
+    chunks = []
+    for b in loader:
+        b = b.to(dev)
+        out = model(b.x, b.edge_index, b.batch)
+        if task == 'clf':
+            out = out.argmax(1)
+        else:
+            out = out.view(-1)
+        chunks.append(out.cpu())
+    return torch.cat(chunks).numpy()
+
+
+def main():
+    """CLI entry point: train a GIN/GCN on CSL or triangle counting and attribute its errors.
+
+    After training, predictions for all graphs are compared against the 1-WL partition
+    computed by ``wl.wl_refine``. Classification failures are bucketed via
+    ``wl.attribute_classification``; regression error is compared against the floors from
+    ``wl.decompose_regression``. A JSON report is written under ``--out``.
+    """
+    ap = argparse.ArgumentParser()
+    ap.add_argument('--task', choices=['csl', 'tri'], required=True)
+    ap.add_argument('--model', choices=['gin', 'gcn'], default='gin')
+    ap.add_argument('--layers', type=int, default=4)
+    ap.add_argument('--hidden', type=int, default=64)
+    ap.add_argument('--epochs', type=int, default=300)
+    ap.add_argument('--lr', type=float, default=1e-3)
+    ap.add_argument('--seed', type=int, default=0)
+    ap.add_argument('--kind', default='er')
+    ap.add_argument('--out', default='/home/yurenh2/rrog/runs')
+    args = ap.parse_args()
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+    os.makedirs(args.out, exist_ok=True)
+
+    if args.task == 'csl':
+        raw = DS.build_csl(n_per_class=15, seed=args.seed)
+        task, out_dim = 'clf', 10
+        y = [d['y'] for d in raw]
+        tr, te = split(len(raw), 0.34, args.seed, y, stratify=True)
+    else:
+        raw = DS.build_triangle_count(n_graphs=800, n_nodes=18, kind=args.kind, deg=3, seed=args.seed)
+        task, out_dim = 'reg', 1
+        tr, te = split(len(raw), 0.3, args.seed)
+
+    # Regression targets are standardised with train-split statistics only.
+    ymu, ysd = 0.0, 1.0
+    if task == 'reg':
+        ytr = np.array([raw[i]['y'] for i in tr], dtype=np.float64)
+        ymu, ysd = float(ytr.mean()), float(ytr.std() + 1e-8)
+
+    pyg = to_pyg(raw, task, ymu, ysd)
+    trl = DataLoader([pyg[i] for i in tr], batch_size=32, shuffle=True, drop_last=True)
+    alll = DataLoader(pyg, batch_size=64)
+
+    Model = M.GIN if args.model == 'gin' else M.GCN
+    model = Model(in_dim=1, hidden=args.hidden, layers=args.layers, out_dim=out_dim).to(dev)
+    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
+    lossf = torch.nn.CrossEntropyLoss() if task == 'clf' else torch.nn.MSELoss()
+
+    t0 = time.time()
+    for ep in range(args.epochs):
+        model.train()
+        for b in trl:
+            b = b.to(dev)
+            opt.zero_grad()
+            o = model(b.x, b.edge_index, b.batch)
+            loss = lossf(o, b.y)
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            opt.step()
+
+    # Predict on every graph once; train/test metrics are sliced out by index below.
+    pred = predict(model, alll, task, dev)
+    if task == 'reg':
+        pred = pred * ysd + ymu
+    yv = np.array([d['y'] for d in raw], dtype=(np.float64 if task == 'reg' else np.int64))
+    adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in raw]
+    _, ghist, conv = wl.wl_refine(adjs)
+
+    rep = {'task': args.task, 'model': args.model, 'layers': args.layers, 'seed': args.seed,
+           'kind': (args.kind if args.task == 'tri' else None),
+           'n': len(raw), 'n_train': len(tr), 'n_test': len(te), 'conv_round': conv,
+           'sec': round(time.time() - t0, 1), 'dev': dev}
+
+    if task == 'clf':
+        test_pred, test_y = pred[te], yv[te]
+        acc = float((test_pred == test_y).mean())
+        train_acc = float((pred[tr] == yv[tr]).mean())
+        att = wl.attribute_classification(ghist, conv, args.layers, yv, tr, te)
+        fails = [te[k] for k in range(len(te)) if test_pred[k] != test_y[k]]
+        fb = Counter(att['buckets'][i] for i in fails)
+        rep.update({'train_acc': round(train_acc, 4), 'test_acc': round(acc, 4),
+                    'wl_ceiling_acc_converged': round(att['wl_optimal_acc_converged'], 4),
+                    'wl_ceiling_acc_Ldepth': round(att['wl_optimal_acc_Ldepth'], 4),
+                    'test_bucket_counts': att['counts'],
+                    'failure_bucket_counts': dict(fb), 'n_failures': len(fails)})
+        print(f"[{args.task}/{args.model}] train_acc={train_acc:.3f} test_acc={acc:.3f} | "
+              f"1-WL ceiling(conv)={att['wl_optimal_acc_converged']:.3f} "
+              f"L-depth={att['wl_optimal_acc_Ldepth']:.3f} | failures={len(fails)} -> {dict(fb)}")
+    else:
+        test_pred, test_y = pred[te], yv[te]
+        mse = float(((test_pred - test_y) ** 2).mean())
+        train_mse = float(((pred[tr] - yv[tr]) ** 2).mean())
+        dec = wl.decompose_regression(ghist, conv, args.layers, yv, tr, te)
+        h2 = dec['mse_floor_oracle_H2']
+        rep.update({'train_mse': round(train_mse, 4), 'test_mse_gin': round(mse, 4),
+                    'mse_floor_oracle_H2': round(h2, 4),
+                    'mse_floor_converged_train': round(dec['mse_floor_converged_train'], 4),
+                    'mse_floor_Ldepth_train': round(dec['mse_floor_Ldepth_train'], 4),
+                    'var_target_test': round(dec['var_target_eval'], 4),
+                    'frac_test_unseen_color': round(dec['frac_test_unseen_color'], 4),
+                    'frac_test_singleton_color': round(dec['frac_test_singleton_color'], 4),
+                    'learn_gap_test': round(max(0.0, mse - h2), 4)})
+        print(f"[{args.task}/{args.model}/{args.kind}] train_mse={train_mse:.3f} test_mse={mse:.3f} | "
+              f"1-WL oracle floor(H2)={h2:.3f} | unseen={dec['frac_test_unseen_color']:.2f} "
+              f"singleton={dec['frac_test_singleton_color']:.2f} | learn_gap={max(0.0, mse - h2):.3f} "
+              f"var_y={dec['var_target_eval']:.3f}")
+
+    tag = f"{args.task}_{args.kind}" if args.task == 'tri' else args.task
+    fn = os.path.join(args.out, f"diag_{tag}_{args.model}_L{args.layers}_s{args.seed}.json")
+    with open(fn, 'w') as f:
+        json.dump(rep, f, indent=2)
+    print(" wrote", fn)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/train_real.py b/diag/train_real.py
new file mode 100644
index 0000000..336d86f
--- /dev/null
+++ b/diag/train_real.py
@@ -0,0 +1,139 @@
+"""Training-based failure diagnosis on LRGB Peptides-struct (real, large, long-range).
+
+The WL partition instrument is vacuous here (graphs ~all distinguishable), so we diagnose
+by TRAINING and comparing:
+ GIN(L) : standard 1-WL backbone at depth L
+ GIN(L)+RNI : random node features = noise = beyond-1-WL symmetry breaker
+ GCN(L) : sub-1-WL reference
+Reads: deeper helps -> long-range/under-reaching; RNI helps -> a real >1-WL ceiling that
+noise breaks; train<<test -> generalization; train high -> compute/optimization ceiling.
+Targets z-scored per dim; metric = standardized MAE (lower better). 11 targets.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_real.py --conv gin --layers 5 --rni 0
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+from torch_geometric.datasets import LRGBDataset
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import GINConv, GCNConv, global_mean_pool
+
+ROOT = '/home/yurenh2/rrog/data/lrgb'
+OUT = '/home/yurenh2/rrog/runs'
+
+
+class Net(nn.Module):
+    """GIN/GCN graph regressor over categorical node features, with optional random features.
+
+    Each input column has its own embedding table; the column embeddings are summed,
+    optionally concatenated with ``rni`` fresh Gaussian features, projected to ``hidden``,
+    passed through ``layers`` conv + BatchNorm + ReLU layers, mean-pooled per graph and
+    mapped to ``out_dim`` outputs.
+    """
+
+    def __init__(self, col_sizes, hidden, layers, out_dim, conv='gin', rni=0):
+        super().__init__()
+        self.embs = nn.ModuleList([nn.Embedding(int(s), hidden) for s in col_sizes])
+        self.rni = rni
+        self.lin_in = nn.Linear(hidden + rni, hidden)
+        self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
+        for _ in range(layers):
+            if conv == 'gin':
+                mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
+                layer = GINConv(mlp, train_eps=True)
+            else:
+                layer = GCNConv(hidden, hidden)
+            self.convs.append(layer)
+            self.bns.append(nn.BatchNorm1d(hidden))
+        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))
+
+    def forward(self, x, edge_index, batch):
+        """Return the ``(num_graphs, out_dim)`` prediction for a batch."""
+        h = sum(emb(x[:, i]) for i, emb in enumerate(self.embs))
+        if self.rni:
+            # Random features are resampled on every call, in training and in evaluation.
+            noise = torch.randn(h.size(0), self.rni, device=h.device)
+            h = torch.cat([h, noise], dim=1)
+        h = self.lin_in(h)
+        for conv, bn in zip(self.convs, self.bns):
+            h = bn(conv(h, edge_index)).relu()
+        pooled = global_mean_pool(h, batch)
+        return self.head(pooled)
+
+
+@torch.no_grad()
+def mae(model, loader, dev, ymu, ysd):
+    """Mean absolute error over every target entry in ``loader``, in the targets' own units.
+
+    ``ymu`` and ``ysd`` are accepted but not used: the error is computed directly between
+    the model output and ``b.y`` as stored in the batches.
+    """
+    model.eval()
+    abs_sum = 0.0
+    count = 0.0
+    for b in loader:
+        b = b.to(dev)
+        out = model(b.x, b.edge_index, b.batch)
+        abs_sum += (out - b.y).abs().sum().item()
+        count += b.y.numel()
+    return abs_sum / count
+
+
+def main():
+    """CLI entry point: train a GIN/GCN on LRGB Peptides-struct and write a JSON report.
+
+    Targets are z-scored per dimension with train statistics and the reported metric is
+    the MAE on those standardised targets. The model is evaluated on validation every 15
+    epochs (and at the last epoch); train/val/test MAE are recorded at the evaluation with
+    the lowest validation MAE. The report is written to ``real_{tag}.json`` under ``OUT``.
+    """
+    ap = argparse.ArgumentParser()
+    ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
+    ap.add_argument('--layers', type=int, default=5)
+    ap.add_argument('--hidden', type=int, default=128)
+    ap.add_argument('--rni', type=int, default=0)
+    ap.add_argument('--epochs', type=int, default=150)
+    ap.add_argument('--lr', type=float, default=1e-3)
+    ap.add_argument('--bs', type=int, default=128)
+    ap.add_argument('--seed', type=int, default=0)
+    args = ap.parse_args()
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+    os.makedirs(OUT, exist_ok=True)
+
+    tr = LRGBDataset(root=ROOT, name='Peptides-struct', split='train')
+    va = LRGBDataset(root=ROOT, name='Peptides-struct', split='val')
+    te = LRGBDataset(root=ROOT, name='Peptides-struct', split='test')
+
+    # per-column embedding sizes + target standardization (train stats)
+    col_max = None
+    Ytr = []
+    for g in tr:
+        m = g.x.max(0).values
+        col_max = m if col_max is None else torch.maximum(col_max, m)
+        Ytr.append(g.y.view(-1))
+    # Validation and test are scanned too, so every category id has an embedding row.
+    for ds in (va, te):
+        for g in ds:
+            col_max = torch.maximum(col_max, g.x.max(0).values)
+    col_sizes = (col_max + 2).tolist()
+    Ytr = torch.stack(Ytr)
+    ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+
+    def norm(ds):
+        """Return cloned graphs whose targets are standardised and shaped ``(1, -1)``."""
+        out = []
+        for g in ds:
+            g = g.clone()
+            g.y = (g.y.view(1, -1) - ymu) / ysd
+            out.append(g)
+        return out
+
+    # Normalise the training set once and share the list between both train loaders,
+    # instead of cloning every training graph a second time.
+    tr_norm = norm(tr)
+    trl = DataLoader(tr_norm, batch_size=args.bs, shuffle=True, drop_last=True)
+    val = DataLoader(norm(va), batch_size=256)
+    tel = DataLoader(norm(te), batch_size=256)
+    trl_eval = DataLoader(tr_norm, batch_size=256)
+
+    model = Net(col_sizes, args.hidden, args.layers, out_dim=11, conv=args.conv, rni=args.rni).to(dev)
+    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+    lossf = nn.L1Loss()
+
+    t0 = time.time()
+    best_val = 9e9
+    best = {}
+    for ep in range(args.epochs):
+        model.train()
+        for b in trl:
+            b = b.to(dev)
+            opt.zero_grad()
+            loss = lossf(model(b.x, b.edge_index, b.batch), b.y)
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            opt.step()
+        sched.step()
+        if (ep + 1) % 15 == 0 or ep == args.epochs - 1:
+            vm = mae(model, val, dev, ymu, ysd)
+            if vm < best_val:
+                best_val = vm
+                best = {'ep': ep + 1, 'train_mae': mae(model, trl_eval, dev, ymu, ysd),
+                        'val_mae': vm, 'test_mae': mae(model, tel, dev, ymu, ysd)}
+            print(f"ep{ep+1} val_mae={vm:.4f}", flush=True)
+
+    tag = f"{args.conv}_L{args.layers}_rni{args.rni}_s{args.seed}"
+    rep = {'dataset': 'Peptides-struct', 'tag': tag, **vars(args),
+           'sec': round(time.time() - t0, 1), 'dev': dev, **best}
+    # NOTE(review): if no evaluation ever improved on best_val (e.g. zero epochs), `best`
+    # is empty and the formatted print below fails on the missing MAE values.
+    print(f"[{tag}] train_mae={best.get('train_mae'):.4f} val_mae={best.get('val_mae'):.4f} "
+          f"test_mae={best.get('test_mae'):.4f} @ep{best.get('ep')} ({rep['sec']}s)")
+    fn = os.path.join(OUT, f"real_{tag}.json")
+    with open(fn, 'w') as f:
+        json.dump(rep, f, indent=2)
+    print(" wrote", fn)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/train_rec.py b/diag/train_rec.py
new file mode 100644
index 0000000..9866f28
--- /dev/null
+++ b/diag/train_rec.py
@@ -0,0 +1,174 @@
+"""Step-2: Recursive (TRM-ish) GNN on ZINC ring-counting + optional PTRM noise/selection.
+
+Recurrent shared-weight GIN block, deep-supervised over n_sup steps (TRM-style: carry latent
+detached between steps). --grad_mode controls the LAST supervision step's recursion:
+ full : backprop through all T inner recursions (TRM)
+ 1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient)
+Optional per-step Gaussian noise (sigma) + K stochastic rollouts selected by a value head
+(best-Q@K) for the PTRM experiments. Saves a checkpoint for the LE diagnostic (diag/lyap.py).
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+from torch_geometric.loader import DataLoader
+from torch_geometric.data import Data
+from torch_geometric.nn import GINConv, global_add_pool
+from diag.train_cycle import prepare
+
+OUT = '/home/yurenh2/rrog/runs'
+
+
def loader(recs, bs, shuffle, drop_last=False):
    """Wrap prepared records in a PyG ``DataLoader``.

    Each record is read through its ``'x'``, ``'edge_index'`` and ``'y'`` entries. The target
    is reshaped to ``(1, 2)`` and ``num_nodes`` is set to the number of elements of ``x``.
    """
    graphs = []
    for rec in recs:
        feats = rec['x']
        graphs.append(Data(x=feats, edge_index=rec['edge_index'],
                           y=rec['y'].view(1, 2), num_nodes=feats.numel()))
    return DataLoader(graphs, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+
+
class RecGIN(nn.Module):
    """Recurrent GIN: one stack of GIN layers re-applied T times per supervision step.

    The same ``convs``/``bns`` stack (``block``) is reused at every inner recursion. ``forward``
    runs ``n_sup`` supervision steps; all but the last are executed without gradient tracking,
    and the last one is run either fully differentiable (``grad_mode='full'``) or with only its
    final inner recursion differentiable (``grad_mode='1step'``). A regression head produces
    one 2-dim prediction per supervision step, and a value head produces one scalar per graph
    from the final latent.
    """

    def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, grad_mode='full'):
        super().__init__()
        self.emb = nn.Embedding(n_atom, hidden)
        # Build order (embedding -> GIN layers -> norms -> heads) is kept so that seeded
        # initialisation and state_dict keys are unchanged.
        gin_layers = []
        for _ in range(inner):
            mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
            gin_layers.append(GINConv(mlp, train_eps=True))
        self.convs = nn.ModuleList(gin_layers)
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)])
        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
        self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.T = T
        self.n_sup = n_sup
        self.sigma = sigma
        self.grad_mode = grad_mode

    def block(self, z, ei):
        """Apply every (GIN layer, batch norm, ReLU) stage once, in order."""
        for layer, norm in zip(self.convs, self.bns):
            z = torch.relu(norm(layer(z, ei)))
        return z

    def _inner(self, z, h0, ei, noise):
        """One inner recursion: ``block(z + h0)`` plus optional Gaussian noise of scale sigma."""
        out = self.block(z + h0, ei)
        if not noise or not (self.sigma > 0):
            return out
        return out + self.sigma * torch.randn_like(out)

    def recurse(self, z, h0, ei, noise, one_step=False):
        """Run the inner recursion T times.

        With ``one_step`` the first T-1 recursions run under ``no_grad`` and the latent is
        detached, so only the final recursion carries gradient (HRM 1-step gradient).
        Otherwise all T recursions are part of the graph (TRM full recursion).
        """
        if not one_step:
            for _ in range(self.T):
                z = self._inner(z, h0, ei, noise)
            return z
        with torch.no_grad():
            for _ in range(self.T - 1):
                z = self._inner(z, h0, ei, noise)
        return self._inner(z.detach(), h0, ei, noise)

    def forward(self, x, ei, batch, noise=False):
        """Return ``(preds, q)``: one prediction per supervision step and the value-head output."""
        h0 = self.emb(x)
        z = torch.zeros_like(h0)
        final_step = self.n_sup - 1
        one_step = self.grad_mode == '1step'
        preds = []
        for s in range(self.n_sup):
            if s == final_step:
                # Only the last supervision step back-propagates through the recursion.
                z = self.recurse(z, h0, ei, noise, one_step=one_step)
            else:
                with torch.no_grad():
                    z = self.recurse(z, h0, ei, noise)
                z = z.detach()
            preds.append(self.head(global_add_pool(z, batch)))
        pooled = global_add_pool(z, batch)
        q = self.qhead(pooled).view(-1)
        return preds, q
+
+
@torch.no_grad()
def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'):
    """Compute per-target MAE in raw (de-normalised) units over a loader.

    With ``K == 1`` a single forward pass is used (noise on only if ``model.sigma > 0``).
    With ``K > 1`` the model is rolled out K times with noise; the reported prediction is the
    rollout with the highest value-head score (``select == 'bestq'``) or the rollout mean
    otherwise. The "oracle" prediction is the rollout closest to the target in L1 distance.

    Returns ``(mae, mae_oracle)``, each a 2-element list. The model is left in eval mode.
    """
    model.eval()
    scale, shift = ysd.to(dev), ymu.to(dev)

    def raw_abs_err(pred, target):
        # Undo the z-scoring on both sides, then sum absolute errors over the batch.
        return ((pred * scale + shift) - (target * scale + shift)).abs().sum(0).cpu()

    err_sel = torch.zeros(2)
    err_best = torch.zeros(2)
    total = 0
    for batch in ld:
        batch = batch.to(dev)
        if K == 1:
            outs, _ = model(batch.x, batch.edge_index, batch.batch, noise=model.sigma > 0)
            picked = best = outs[-1]
        else:
            rollouts = []
            for _ in range(K):
                outs, score = model(batch.x, batch.edge_index, batch.batch, noise=True)
                rollouts.append((outs[-1], score))
            P = torch.stack([p for p, _ in rollouts])
            Q = torch.stack([s for _, s in rollouts])
            cols = torch.arange(P.size(1), device=dev)
            if select == 'bestq':
                picked = P[Q.argmax(0), cols]
            else:
                picked = P.mean(0)
            dist = (P - batch.y.unsqueeze(0)).abs().sum(-1)
            best = P[dist.argmin(0), cols]
        err_sel += raw_abs_err(picked, batch.y)
        err_best += raw_abs_err(best, batch.y)
        total += batch.num_graphs
    return (err_sel / total).tolist(), (err_best / total).tolist()
+
+
def main():
    """Train a ``RecGIN`` on the ZINC ring-counting targets; write a JSON report and a checkpoint.

    Steps, as implemented below:
      * parse CLI flags, seed torch/numpy, use CUDA when available;
      * load the three splits via ``prepare`` and z-score the targets with train statistics;
      * optimise with Adam + cosine annealing. The loss is the mean L1 over all supervision
        predictions plus ``lam_q`` times the MSE between the value head and the negated
        per-graph mean absolute error of the last prediction;
      * every 20 epochs (and on the last one) evaluate on validation; when the summed
        validation MAE improves, evaluate on test and snapshot the weights on CPU;
      * write ``<tag>.json`` and ``ckpt_<tag>.pt`` into ``OUT``.

    If no validation pass ever improves (zero epochs, or a non-finite validation MAE), the
    report carries no best-epoch metrics and the checkpoint holds the final weights. The
    summary line prints ``None`` for the missing metrics instead of raising.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
    ap.add_argument('--sigma', type=float, default=0.0)
    ap.add_argument('--K', type=int, default=1)
    ap.add_argument('--select', choices=['none', 'bestq'], default='bestq')
    ap.add_argument('--T', type=int, default=3)
    ap.add_argument('--n_sup', type=int, default=3)
    ap.add_argument('--hidden', type=int, default=128)
    ap.add_argument('--epochs', type=int, default=200)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--bs', type=int, default=128)
    ap.add_argument('--lam_q', type=float, default=1.0)
    ap.add_argument('--seed', type=int, default=0)
    args = ap.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(OUT, exist_ok=True)

    tr, va, te = prepare('train'), prepare('val'), prepare('test')
    n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
    # Normalise targets with TRAIN statistics only; the epsilon guards a constant target.
    Ytr = torch.stack([r['y'] for r in tr])
    ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
    for recs in (tr, va, te):
        for r in recs:
            r['y'] = (r['y'] - ymu) / ysd
    trl = loader(tr, args.bs, True, drop_last=True)
    val, tel = loader(va, 256, False), loader(te, 256, False)

    model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma,
                   grad_mode=args.grad_mode).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    l1 = nn.L1Loss()

    t0 = time.time()
    best_val = 9e9
    best = {}
    best_state = None
    for ep in range(args.epochs):
        model.train()  # evaluate() switches to eval mode, so re-enable train mode each epoch
        for b in trl:
            b = b.to(dev)
            opt.zero_grad()
            preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
            loss = sum(l1(p, b.y) for p in preds) / len(preds)
            # Value-head target: negated per-graph error of the final prediction (no gradient).
            with torch.no_grad():
                tq = -(preds[-1] - b.y).abs().mean(1)
            loss = loss + args.lam_q * nn.functional.mse_loss(q, tq)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select)
            if sum(vm) < best_val:
                best_val = sum(vm)
                tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select)
                best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo}
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True)

    def _r3(values):
        # `best` is empty when no validation pass improved; print None rather than crash
        # before the report and checkpoint below are written.
        return None if values is None else [round(v, 3) for v in values]

    tag = f"rec_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}_s{args.seed}"
    rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1),
           'dev': dev, 'y_std_raw': ysd.tolist(), **best}
    print(f"[{tag}] test_mae={_r3(best.get('test_mae'))} "
          f"oracle@K={_r3(best.get('test_mae_oracle'))} @ep{best.get('ep')} ({rep['sec']}s)")
    with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
        json.dump(rep, f, indent=2)
    state = best_state if best_state is not None else model.state_dict()
    torch.save({'state': state,
                'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup,
                        'sigma': args.sigma, 'grad_mode': args.grad_mode},
                'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt"))
    print("  wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/wl.py b/diag/wl.py
new file mode 100644
index 0000000..26ab0a3
--- /dev/null
+++ b/diag/wl.py
@@ -0,0 +1,166 @@
+"""1-WL color-refinement instrument for diagnosing GNN failures (H1 vs H2).
+
+A GIN with L layers == L rounds of 1-WL refinement (injective sum aggregation).
+A failure on sample i is attributed by label purity of its WL color classes:
+
+ converged-WL class IMPURE (train labels conflict under same color)
+ -> H2 : 1-WL ceiling. No MPNN at ANY depth separates -> needs >1-WL (noise).
+ converged pure, but L-round class impure
+ -> H1a_depth : separable only with MORE rounds -> deterministic RR-on-graph / depth helps.
+ L-round class pure (info present at depth L) but model wrong
+ -> H1b_opt : optimization / capacity. Train better.
+
+Refinement is dataset-global (shared per-round signature->label map) so node colors and
+graph-color histograms are comparable across graphs.
+"""
+from collections import Counter, defaultdict
+import numpy as np
+
+
def edges_to_adj(n, edge_index):
    """Turn a 2-row edge index into per-node neighbour lists.

    For every column ``(a, b)`` of ``edge_index``, ``b`` is appended to the list of node ``a``.
    Edges are taken as given (directed); nothing is symmetrised or de-duplicated.
    """
    pairs = np.asarray(edge_index)
    sources = pairs[0].tolist()
    targets = pairs[1].tolist()
    neighbours = [[] for _ in range(n)]
    for src, dst in zip(sources, targets):
        neighbours[src].append(dst)
    return neighbours
+
+
def _wl_relabel(sig_lists):
    """Map hashable signatures to dense int labels in first-seen order, using a fresh table.

    Returns the per-graph int64 label arrays and the number of distinct signatures seen.
    """
    table = {}
    labelled = []
    for sigs in sig_lists:
        labelled.append(np.array([table.setdefault(s, len(table)) for s in sigs],
                                 dtype=np.int64))
    return labelled, len(table)


def _wl_signatures(adj, colors):
    """1-WL signature of every node: (own colour, sorted multiset of neighbour colours)."""
    return [(colors[v], tuple(sorted(colors[u] for u in nbrs))) for v, nbrs in enumerate(adj)]


def wl_refine(adjs, inits=None, max_rounds=None):
    """Dataset-level 1-WL. Returns (node_rounds, ghist_rounds, conv_round).

    Args:
        adjs: one adjacency list per graph; ``adjs[g][v]`` lists the neighbours of node ``v``.
        inits: optional per-graph initial integer node colours; every node starts with
            colour 0 when omitted.
        max_rounds: cap on the number of refinement rounds. Defaults to
            ``max(2 * n_max, n_max + 2)`` where ``n_max`` is the largest node count. Labels
            are shared across graphs, so telling apart nodes of two different graphs can
            take close to ``2 * n_max`` rounds; the previous ``n_max + 2`` default could
            stop before the global partition had stabilised. The loop still exits as soon
            as stabilisation is detected, so the larger cap costs nothing in typical runs.

    Returns:
        node_rounds[r][g]  = int64 colour array (global labels) of graph g after r rounds.
        ghist_rounds[r][g] = canonical colour histogram (hashable) of graph g after r rounds.
        conv_round         = index of the last round computed. When the loop stopped because
                             the global number of classes did not grow, the partition at
                             ``conv_round`` is the stable one (identical, as a partition,
                             to the one at ``conv_round - 1``).

    Labels are re-assigned from scratch every round, so integer labels are comparable
    across graphs within a round but not across rounds.
    """
    if inits is None:
        inits = [np.zeros(len(a), dtype=np.int64) for a in adjs]
    else:
        inits = [np.asarray(x, dtype=np.int64) for x in inits]
    if max_rounds is None:
        n_max = max((len(a) for a in adjs), default=0)
        max_rounds = max(2 * n_max, n_max + 2)

    colors, n_classes = _wl_relabel([('i', int(c)) for c in init] for init in inits)
    node_rounds = [colors]

    for _ in range(max_rounds):
        # Plain Python ints once per graph: avoids a numpy scalar conversion per lookup.
        prev = [c.tolist() for c in node_rounds[-1]]
        colors, n_new = _wl_relabel(
            _wl_signatures(adj, prev[g]) for g, adj in enumerate(adjs))
        node_rounds.append(colors)
        stable = n_new == n_classes  # refinement only splits classes: same count == same partition
        n_classes = n_new
        if stable:
            break

    conv_round = len(node_rounds) - 1
    ghist_rounds = [[_hist(c) for c in per_graph] for per_graph in node_rounds]
    return node_rounds, ghist_rounds, conv_round


def _hist(colors):
    """Canonical hashable histogram of an int colour array: sorted ``(colour, count)`` pairs."""
    return tuple(sorted(Counter(colors.tolist()).items()))
+
+
def graph_colors_at(ghist_rounds, conv_round, L):
    """Return the per-graph colour histograms after ``L`` rounds, clamped to ``conv_round``."""
    depth = min(L, conv_round)
    histograms = ghist_rounds[depth]
    return histograms
+
+
+# ---------- classification attribution ----------
def attribute_classification(ghist_rounds, conv_round, L, y, train_idx, eval_idx):
    """Attribute every evaluation graph to a bucket via label purity of its WL colour class.

    Buckets (decided from TRAIN labels sharing the graph's colour):
      'novel'     : converged colour never seen in train;
      'H2'        : converged colour seen with conflicting train labels;
      'H1a_depth' : converged colour pure, but the colour after min(L, conv_round) rounds is not;
      'H1b_opt'   : colour after min(L, conv_round) rounds is already pure.

    Also reports the accuracy of the majority-train-label predictor per converged colour and
    per L-round colour (an unseen colour counts as a miss).

    Args:
        ghist_rounds: per-round lists of hashable graph colour histograms.
        conv_round: index of the converged round in ``ghist_rounds``.
        L: model depth; clamped to ``conv_round``.
        y: integer-castable labels, indexable by graph index.
        train_idx, eval_idx: iterables of graph indices (materialised once, so one-shot
            iterables are accepted).

    Returns:
        dict with 'buckets' (index -> bucket), 'counts', the two accuracies, 'L_used' and
        'conv_round'. Both accuracies are NaN when ``eval_idx`` is empty.
    """
    y = np.asarray(y)
    train_idx = list(train_idx)
    eval_idx = list(eval_idx)
    conv = ghist_rounds[conv_round]
    Lr = min(L, conv_round)
    lr = ghist_rounds[Lr]

    # Count train labels per colour once, instead of rebuilding sets/Counters per eval sample.
    conv_train, lr_train = defaultdict(Counter), defaultdict(Counter)
    for i in train_idx:
        label = int(y[i])
        conv_train[conv[i]][label] += 1
        lr_train[lr[i]][label] += 1

    def pure(counts):
        return counts is not None and len(counts) == 1

    def majority(counts):
        # Ties resolve to the label seen first in train order (Counter keeps insertion order).
        return counts.most_common(1)[0][0] if counts else None

    buckets = {}
    wl_opt = lr_opt = 0
    for i in eval_idx:
        target = int(y[i])
        conv_counts = conv_train.get(conv[i])  # .get: do not create entries for unseen colours
        lr_counts = lr_train.get(lr[i])
        if conv_counts is None:
            buckets[i] = 'novel'
        elif not pure(conv_counts):
            buckets[i] = 'H2'
        elif not pure(lr_counts):
            buckets[i] = 'H1a_depth'
        else:
            buckets[i] = 'H1b_opt'
        if majority(conv_counts) == target:
            wl_opt += 1
        if majority(lr_counts) == target:
            lr_opt += 1

    n = len(eval_idx)
    nan = float('nan')
    return {
        'buckets': buckets,
        'counts': dict(Counter(buckets.values())),
        'wl_optimal_acc_converged': wl_opt / n if n else nan,   # best ANY MPNN can do
        'wl_optimal_acc_Ldepth': lr_opt / n if n else nan,      # best L-layer MPNN can do
        'L_used': Lr, 'conv_round': conv_round,
    }
+
+
+# ---------- regression decomposition ----------
def decompose_regression(ghist_rounds, conv_round, L, y, train_idx, eval_idx):
    """Decompose regression error into WL-colour floors.

    H2 floor = ORACLE within-colour variance on FULL data (best possible function of the WL
    colour: do same-colour graphs share the target?). It is computed from per-colour means
    over all graphs, so it is not confounded by train/test coverage. The train-fitted floors
    use per-colour means from ``train_idx`` only, falling back to the global train mean on
    colours unseen in train; they expose how much apparent error is really novel-colour
    generalisation. Coverage fractions are reported alongside.

    Args:
        ghist_rounds: per-round lists of hashable graph colour histograms.
        conv_round: index of the converged round in ``ghist_rounds``.
        L: model depth; clamped to ``conv_round``.
        y: float-castable targets, one per graph.
        train_idx, eval_idx: iterables of graph indices. They are materialised once because
            each is traversed several times below (a one-shot iterable would otherwise be
            exhausted after the first pass and silently yield empty statistics).

    Returns:
        dict of MSE floors, coverage fractions, 'L_used', 'conv_round' and the target
        variance on the evaluation set.
    """
    y = np.asarray(y, dtype=np.float64)
    train_idx = list(train_idx)
    eval_idx = list(eval_idx)
    conv = ghist_rounds[conv_round]
    Lr = min(L, conv_round)
    lr = ghist_rounds[Lr]
    full_idx = range(len(y))

    # oracle: best constant per converged colour over ALL data -> irreducible by any MPNN
    conv_mean_full = _group_mean(conv, y, full_idx)
    e_oracle = np.array([conv_mean_full[conv[i]] - y[i] for i in eval_idx])

    # train-fitted (achievable with this split); fallback to global mean on unseen colours
    conv_mean_tr = _group_mean(conv, y, train_idx)
    lr_mean_tr = _group_mean(lr, y, train_idx)
    gmean = float(y[train_idx].mean())
    e_conv_tr = np.array([conv_mean_tr.get(conv[i], gmean) - y[i] for i in eval_idx])
    e_lr_tr = np.array([lr_mean_tr.get(lr[i], gmean) - y[i] for i in eval_idx])

    conv_count = Counter(conv[i] for i in full_idx)
    train_colors = {conv[i] for i in train_idx}
    frac_unseen = float(np.mean([conv[i] not in train_colors for i in eval_idx]))
    frac_singleton = float(np.mean([conv_count[conv[i]] == 1 for i in eval_idx]))
    return {
        'mse_floor_oracle_H2': float((e_oracle ** 2).mean()),        # TRUE 1-WL ceiling
        'mse_floor_converged_train': float((e_conv_tr ** 2).mean()),
        'mse_floor_Ldepth_train': float((e_lr_tr ** 2).mean()),
        'frac_test_unseen_color': frac_unseen,
        'frac_test_singleton_color': frac_singleton,
        'L_used': Lr, 'conv_round': conv_round,
        'var_target_eval': float(y[eval_idx].var()),
    }


def _group_mean(colors, y, idx):
    """Mean of ``y`` per colour, restricted to the graph indices in ``idx``."""
    acc = defaultdict(list)
    for i in idx:
        acc[colors[i]].append(float(y[i]))
    return {k: float(np.mean(v)) for k, v in acc.items()}