diff options
Diffstat (limited to 'diag/peptides_depth.py')
| -rw-r--r-- | diag/peptides_depth.py | 66 |
1 files changed, 66 insertions, 0 deletions
# diag/peptides_depth.py (new file, blob d751b31)
"""Depth-resolution analysis on LRGB Peptides-struct (long-range, real, large graphs).

Best-achievable MSE for a depth-L MPNN == within-(L-round-WL-color) variance of the target.
The curve over L localizes failure cause WITHOUT training:
  floor(L=small) high          -> under-reached signal still hidden
  floor(L) - floor(converged)  -> H1a: depth-recoverable (more iteration/depth helps)
  floor(converged)             -> H2 : 1-WL ceiling (irreducible by any MPNN)
Targets are z-scored per dim, so floor is the fraction of variance unexplained (var=1 baseline).
"""
import os
import time
from collections import defaultdict

import numpy as np

# Dataset cache directory; override with the LRGB_ROOT environment variable
# or by passing ``root=`` to main().
DEFAULT_ROOT = os.environ.get('LRGB_ROOT', '/home/yurenh2/rrog/data/lrgb')

S = 4000         # graph subsample (for tractable pure-python WL)
MAX_ROUNDS = 40  # cap (>> typical GIN depth; chains converge near their diameter)

# Depths reported in the table; main() appends the final WL round.
DEPTHS = (0, 1, 2, 3, 4, 5, 8, 16, 32)


def floor_at(ghist, Y):
    """Return the within-group mean squared deviation of Y and the group count.

    Rows of ``Y`` are grouped by the label at the same position in ``ghist``;
    each group is predicted by its own per-column mean. The result is the
    total squared error divided by the total number of entries of ``Y``.

    Args:
        ghist: sequence of hashable labels; entry ``i`` labels row ``i`` of ``Y``.
        Y: array-like of targets, 2-D ``(rows, targets)``. A 1-D input is
            treated as a single target column.

    Returns:
        ``(floor, n_groups)``: the mean squared within-group deviation as a
        float, and the number of distinct labels in ``ghist``.

    Raises:
        ValueError: if ``ghist`` is empty or ``Y`` has no entries.
    """
    Y = np.asarray(Y, dtype=np.float64)
    if Y.ndim == 1:
        Y = Y[:, None]
    n = len(ghist)
    if n == 0 or Y.size == 0:
        raise ValueError("floor_at needs at least one labelled row and one target column")

    groups = defaultdict(list)
    for i, c in enumerate(ghist):
        groups[c].append(i)

    sse = 0.0
    for idxs in groups.values():
        yy = Y[idxs]
        sse += ((yy - yy.mean(0)) ** 2).sum()
    return float(sse / (n * Y.shape[1])), len(groups)


def main(root=DEFAULT_ROOT):
    """Print the floor-vs-depth table for a subsample of Peptides-struct (train split).

    Args:
        root: directory passed to ``LRGBDataset`` as its ``root``.
    """
    # Only main() needs these; importing them here keeps floor_at importable
    # without torch_geometric or the diag package.
    from torch_geometric.datasets import LRGBDataset
    from diag import wl

    ds = LRGBDataset(root=root, name='Peptides-struct', split='train')
    graphs = [ds[i] for i in range(min(S, len(ds)))]

    # Each distinct node-feature row gets a dense integer id in first-seen
    # order; the ids are passed to wl_refine as ``inits``.
    fmap = {}
    adjs, inits, Y = [], [], []
    for g in graphs:
        adjs.append(wl.edges_to_adj(g.num_nodes, g.edge_index.numpy()))
        ids = [fmap.setdefault(tuple(row), len(fmap)) for row in g.x.tolist()]
        inits.append(np.array(ids, dtype=np.int64))
        Y.append(g.y.numpy().reshape(-1))
    Y = np.stack(Y).astype(np.float64)
    Y = (Y - Y.mean(0)) / (Y.std(0) + 1e-8)  # z-score per target
    print(f"subsample={len(graphs)} graphs, {len(fmap)} distinct node-feature ids, targets={Y.shape[1]}")

    t0 = time.perf_counter()  # monotonic clock for the elapsed-time report
    _node_rounds, ghist_rounds, conv = wl.wl_refine(adjs, inits=inits, max_rounds=MAX_ROUNDS)
    print(f"WL refined to round {conv} (cap {MAX_ROUNDS}) in {time.perf_counter()-t0:.1f}s")
    # NOTE(review): if wl_refine can stop at the cap before reaching a fixed
    # point, the last row below over-estimates the H2 ceiling - confirm what
    # ``conv`` means when it equals MAX_ROUNDS.

    print(f"{'L':>4} {'floor_MSE(std)':>14} {'%var_unexpl':>12} {'#graph_colors':>14}")
    floors = {}
    for L in [*DEPTHS, conv]:
        r = min(L, conv)  # depths beyond the final round reuse the final partition
        f, nc = floor_at(ghist_rounds[r], Y)
        floors[L] = f
        print(f"{L:>4} {f:>14.4f} {100*f:>11.1f}% {nc:>14}")

    h2 = floors[conv]
    for Lg in [4, 5]:
        print(f"\nAt GIN depth L={Lg}: H2 ceiling={h2:.3f} | depth-recoverable H1a (floor[{Lg}]-H2)"
              f"={floors[Lg]-h2:.3f} | already-reachable={1-floors[Lg]:.3f} of var")


if __name__ == "__main__":
    main()
