summaryrefslogtreecommitdiff
path: root/diag/peptides_depth.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/peptides_depth.py')
-rw-r--r--diag/peptides_depth.py66
1 files changed, 66 insertions, 0 deletions
diff --git a/diag/peptides_depth.py b/diag/peptides_depth.py
new file mode 100644
index 0000000..d751b31
--- /dev/null
+++ b/diag/peptides_depth.py
@@ -0,0 +1,66 @@
+"""Depth-resolution analysis on LRGB Peptides-struct (long-range, real, large graphs).
+
+Best-achievable MSE for a depth-L MPNN == within-(L-round-WL-color) variance of the target.
+The curve over L localizes failure cause WITHOUT training:
+ floor(L=small) high -> under-reached signal still hidden
+ floor(L) - floor(converged) -> H1a: depth-recoverable (more iteration/depth helps)
+ floor(converged) -> H2 : 1-WL ceiling (irreducible by any MPNN)
+Targets are z-scored per dim, so floor is the fraction of variance unexplained (var=1 baseline).
+"""
+import numpy as np
+from collections import defaultdict
+from torch_geometric.datasets import LRGBDataset
+from diag import wl
+
+S = 4000 # graph subsample (for tractable pure-python WL)
+MAX_ROUNDS = 40 # cap (>> typical GIN depth; chains converge near their diameter)
+
+
+def floor_at(ghist, Y):
+ groups = defaultdict(list)
+ for i, c in enumerate(ghist):
+ groups[c].append(i)
+ sse = 0.0
+ for idxs in groups.values():
+ yy = Y[idxs]
+ sse += ((yy - yy.mean(0)) ** 2).sum()
+ return sse / (len(ghist) * Y.shape[1]), len(groups)
+
+
+def main():
+ ds = LRGBDataset(root='/home/yurenh2/rrog/data/lrgb', name='Peptides-struct', split='train')
+ graphs = [ds[i] for i in range(min(S, len(ds)))]
+ fmap = {}
+ def fid(row):
+ t = tuple(row)
+ if t not in fmap:
+ fmap[t] = len(fmap)
+ return fmap[t]
+ adjs, inits, Y = [], [], []
+ for g in graphs:
+ adjs.append(wl.edges_to_adj(g.num_nodes, g.edge_index.numpy()))
+ inits.append(np.array([fid(r) for r in g.x.tolist()], dtype=np.int64))
+ Y.append(g.y.numpy().reshape(-1))
+ Y = np.stack(Y).astype(np.float64)
+ Y = (Y - Y.mean(0)) / (Y.std(0) + 1e-8) # z-score per target
+ print(f"subsample={len(graphs)} graphs, {len(fmap)} distinct node-feature ids, targets={Y.shape[1]}")
+
+ import time; t0 = time.time()
+ node_rounds, ghist_rounds, conv = wl.wl_refine(adjs, inits=inits, max_rounds=MAX_ROUNDS)
+ print(f"WL refined to round {conv} (cap {MAX_ROUNDS}) in {time.time()-t0:.1f}s")
+
+ print(f"{'L':>4} {'floor_MSE(std)':>14} {'%var_unexpl':>12} {'#graph_colors':>14}")
+ floors = {}
+ for L in [0, 1, 2, 3, 4, 5, 8, 16, 32, conv]:
+ r = min(L, conv)
+ f, nc = floor_at(ghist_rounds[r], Y)
+ floors[L] = f
+ print(f"{L:>4} {f:>14.4f} {100*f:>11.1f}% {nc:>14}")
+ h2 = floors[conv]
+ for Lg in [4, 5]:
+ print(f"\nAt GIN depth L={Lg}: H2 ceiling={h2:.3f} | depth-recoverable H1a (floor[{Lg}]-H2)"
+ f"={floors[Lg]-h2:.3f} | already-reachable={1-floors[Lg]:.3f} of var")
+
+
+if __name__ == "__main__":
+ main()