# path: diag/peptides_depth.py
# blob: d751b317c5ffe49e557edd6d11d166fa1cfaf82e
"""Depth-resolution analysis on LRGB Peptides-struct (long-range, real, large graphs).

Best-achievable MSE for a depth-L MPNN == within-(L-round-WL-color) variance of the target.
The curve over L localizes failure cause WITHOUT training:
  floor(L=small)              high  -> under-reached signal still hidden
  floor(L) - floor(converged)       -> H1a: depth-recoverable (more iteration/depth helps)
  floor(converged)                  -> H2 : 1-WL ceiling (irreducible by any MPNN)
Targets are z-scored per dim, so floor is the fraction of variance unexplained (var=1 baseline).
"""
import numpy as np
from collections import defaultdict
from torch_geometric.datasets import LRGBDataset
from diag import wl

# Number of graphs taken from the dataset (keeps the pure-python WL tractable).
S = 4_000
# Upper bound on WL refinement rounds (>> typical GIN depth; chains converge
# near their diameter).
MAX_ROUNDS = 40


def floor_at(ghist, Y):
    """Return the within-class mean squared error of ``Y`` and the class count.

    Row ``i`` of ``Y`` is assigned to the class keyed by ``ghist[i]`` (entries
    are used as dict keys, so they must be hashable). Inside each class the
    squared deviation from that class's per-column mean is summed; the grand
    total is divided by ``len(ghist) * Y.shape[1]``.

    Returns:
        A ``(mse, n_classes)`` pair.
    """
    members = {}
    for row, colour in enumerate(ghist):
        members.setdefault(colour, []).append(row)

    def _class_sse(rows):
        # Sum of squared residuals around the class mean (taken per column).
        block = Y[rows]
        centred = block - block.mean(0)
        return (centred ** 2).sum()

    total = 0.0
    for rows in members.values():
        total += _class_sse(rows)

    n_cells = len(ghist) * Y.shape[1]
    return total / n_cells, len(members)


def main(root='/home/yurenh2/rrog/data/lrgb'):
    """Print the depth-resolved MSE floor for a Peptides-struct subsample.

    Loads the train split, keeps at most ``S`` leading graphs, z-scores the
    targets per column, runs ``wl.wl_refine`` for up to ``MAX_ROUNDS`` rounds
    and prints ``floor_at`` for a fixed list of depths, followed by a summary
    line for depths 4 and 5.

    Args:
        root: Directory handed to ``LRGBDataset`` as the dataset root. The
            default is the location that used to be hard-coded here.
    """
    import time  # kept function-local; only the timing below needs it

    ds = LRGBDataset(root=root, name='Peptides-struct', split='train')
    graphs = [ds[i] for i in range(min(S, len(ds)))]

    fmap = {}

    def fid(row):
        # Dense integer id per distinct node-feature row; ``len(fmap)`` is
        # evaluated before insertion, so new rows get the next free id.
        return fmap.setdefault(tuple(row), len(fmap))

    adjs, inits, Y = [], [], []
    for g in graphs:
        adjs.append(wl.edges_to_adj(g.num_nodes, g.edge_index.numpy()))
        inits.append(np.array([fid(r) for r in g.x.tolist()], dtype=np.int64))
        Y.append(g.y.numpy().reshape(-1))
    Y = np.stack(Y).astype(np.float64)
    Y = (Y - Y.mean(0)) / (Y.std(0) + 1e-8)   # z-score per target
    print(f"subsample={len(graphs)} graphs, {len(fmap)} distinct node-feature ids, targets={Y.shape[1]}")

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during the (potentially long) refinement.
    t0 = time.perf_counter()
    _, ghist_rounds, conv = wl.wl_refine(adjs, inits=inits, max_rounds=MAX_ROUNDS)
    elapsed = time.perf_counter() - t0
    print(f"WL refined to round {conv} (cap {MAX_ROUNDS}) in {elapsed:.1f}s")

    print(f"{'L':>4} {'floor_MSE(std)':>14} {'%var_unexpl':>12} {'#graph_colors':>14}")
    floors = {}
    for L in [0, 1, 2, 3, 4, 5, 8, 16, 32, conv]:
        # Depths past the reported round ``conv`` are clamped to it.
        # NOTE(review): assumes ghist_rounds is indexable by round 0..conv
        # with one hashable entry per graph -- confirm against wl.wl_refine.
        r = min(L, conv)
        f, nc = floor_at(ghist_rounds[r], Y)
        floors[L] = f
        print(f"{L:>4} {f:>14.4f} {100*f:>11.1f}% {nc:>14}")

    # The floor at round ``conv`` is reported as the H2 ceiling; the gap from
    # a given depth down to it is reported as depth-recoverable (H1a).
    h2 = floors[conv]
    for Lg in [4, 5]:
        print(f"\nAt GIN depth L={Lg}:  H2 ceiling={h2:.3f}  | depth-recoverable H1a (floor[{Lg}]-H2)"
              f"={floors[Lg]-h2:.3f}  | already-reachable={1-floors[Lg]:.3f} of var")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()