"""Depth-resolution analysis on LRGB Peptides-struct (long-range, real, large graphs). Best-achievable MSE for a depth-L MPNN == within-(L-round-WL-color) variance of the target. The curve over L localizes failure cause WITHOUT training: floor(L=small) high -> under-reached signal still hidden floor(L) - floor(converged) -> H1a: depth-recoverable (more iteration/depth helps) floor(converged) -> H2 : 1-WL ceiling (irreducible by any MPNN) Targets are z-scored per dim, so floor is the fraction of variance unexplained (var=1 baseline). """ import numpy as np from collections import defaultdict from torch_geometric.datasets import LRGBDataset from diag import wl S = 4000 # graph subsample (for tractable pure-python WL) MAX_ROUNDS = 40 # cap (>> typical GIN depth; chains converge near their diameter) def floor_at(ghist, Y): groups = defaultdict(list) for i, c in enumerate(ghist): groups[c].append(i) sse = 0.0 for idxs in groups.values(): yy = Y[idxs] sse += ((yy - yy.mean(0)) ** 2).sum() return sse / (len(ghist) * Y.shape[1]), len(groups) def main(): ds = LRGBDataset(root='/home/yurenh2/rrog/data/lrgb', name='Peptides-struct', split='train') graphs = [ds[i] for i in range(min(S, len(ds)))] fmap = {} def fid(row): t = tuple(row) if t not in fmap: fmap[t] = len(fmap) return fmap[t] adjs, inits, Y = [], [], [] for g in graphs: adjs.append(wl.edges_to_adj(g.num_nodes, g.edge_index.numpy())) inits.append(np.array([fid(r) for r in g.x.tolist()], dtype=np.int64)) Y.append(g.y.numpy().reshape(-1)) Y = np.stack(Y).astype(np.float64) Y = (Y - Y.mean(0)) / (Y.std(0) + 1e-8) # z-score per target print(f"subsample={len(graphs)} graphs, {len(fmap)} distinct node-feature ids, targets={Y.shape[1]}") import time; t0 = time.time() node_rounds, ghist_rounds, conv = wl.wl_refine(adjs, inits=inits, max_rounds=MAX_ROUNDS) print(f"WL refined to round {conv} (cap {MAX_ROUNDS}) in {time.time()-t0:.1f}s") print(f"{'L':>4} {'floor_MSE(std)':>14} {'%var_unexpl':>12} {'#graph_colors':>14}") floors = {} for L in [0, 1, 2, 3, 4, 5, 8, 16, 32, conv]: r = min(L, conv) f, nc = floor_at(ghist_rounds[r], Y) floors[L] = f print(f"{L:>4} {f:>14.4f} {100*f:>11.1f}% {nc:>14}") h2 = floors[conv] for Lg in [4, 5]: print(f"\nAt GIN depth L={Lg}: H2 ceiling={h2:.3f} | depth-recoverable H1a (floor[{Lg}]-H2)" f"={floors[Lg]-h2:.3f} | already-reachable={1-floors[Lg]:.3f} of var") if __name__ == "__main__": main()