summaryrefslogtreecommitdiff
path: root/rrog/backbones.py
blob: 81bfb9be0d2db9a71637dbeef1ba28399fe39acb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from rrog.registry import ComputeSpec, ModifierSpec, ViewSpec, by_name


# Registry of view backbones. Each tuple below is forwarded positionally to
# ViewSpec, i.e. (name, category, "2d"/"3d", ordinal[, status[, notes]]).
# NOTE(review): field meanings are read off the values here; confirm against
# rrog.registry.ViewSpec. Ordinals run 1..N in list order.
VIEWS = [
    ViewSpec(*fields)
    for fields in (
        # --- entries explicitly marked "implemented" ---
        ("gin", "message-passing", "2d", 1, "implemented",
         "Plain GINConv message passing."),
        ("gine", "message-passing", "2d", 2, "implemented",
         "Edge-aware GIN variant. ZINC uses a learned constant edge token; OGB uses bond features."),
        ("gcn", "message-passing", "2d", 3, "implemented"),
        ("graphsage", "message-passing", "2d", 4, "implemented"),
        ("gatv2", "attention-mpnn", "2d", 5, "implemented"),
        ("graphconv", "message-passing", "2d", 6, "implemented"),
        ("transformer", "attention-mpnn", "2d", 7, "implemented"),
        ("pna", "message-passing", "2d", 8, "implemented",
         "Requires degree histogram from the train split."),
        ("gen", "message-passing", "2d", 9, "implemented"),
        ("film", "message-passing", "2d", 10, "implemented"),
        ("resgated", "message-passing", "2d", 11, "implemented"),
        ("tag", "higher-order-hop", "2d", 12, "implemented"),
        ("sgc", "propagation", "2d", 13, "implemented"),
        ("cheb", "spectral", "2d", 14, "implemented"),
        ("arma", "spectral", "2d", 15, "implemented"),
        ("mf", "message-passing", "2d", 16, "implemented"),
        ("appnp", "propagation", "2d", 17, "implemented"),
        # --- entries with no explicit status (ViewSpec's default applies) ---
        ("mixhop", "higher-order-hop", "2d", 18),
        ("gps", "hybrid-local-global", "2d", 19),
        ("graphormer", "global-attention", "2d", 20),
        ("san", "spectral-attention", "2d", 21),
        ("mpnn", "message-passing", "2d", 22),
        # 3D (geometry-aware) backbones
        ("schnet", "continuous-filter", "3d", 23),
        ("dimenetpp", "angle-aware", "3d", 24),
        ("painn", "equivariant", "3d", 25),
        ("gemnet", "equivariant", "3d", 26),
        ("egnn", "equivariant", "3d", 27),
        ("equiformer", "equivariant", "3d", 28),
        ("mace", "equivariant", "3d", 29),
    )
]


# Registry of compute modules. Each entry is passed positionally to
# ComputeSpec as (name, kind, ordinal[, status[, notes]]); entries without a
# status rely on ComputeSpec's default.
# Ordinals are kept unique and sequential (1..N in list order), the same
# convention VIEWS and MODIFIERS follow. "node-mlp" had been sharing ordinal 4
# with "rrog-act", so everything after "rrog-act" is numbered one higher.
# NOTE(review): confirm nothing outside this file persists these ordinals as
# stable IDs for node-mlp and later entries.
COMPUTES = [
    ComputeSpec("classic", "baseline", 1, "implemented", "Standard one-forward GNN baseline; no RRoG compute."),
    ComputeSpec("view-only", "none", 2, "implemented", "RRoG view module only; no recursive compute."),
    ComputeSpec("fixed-rrog", "recursive", 3, "implemented", "Fixed-depth edge-free y/z compute."),
    ComputeSpec("rrog-act", "recursive-act", 4, "implemented", "Persistent full ACT recycling for graph batches."),
    ComputeSpec("node-mlp", "recursive", 5),
    ComputeSpec("gru-rrog", "recursive-gated", 6),
    ComputeSpec("set-attn-core", "edge-free-attention", 7),
    ComputeSpec("perceiver-core", "latent-attention", 8),
    ComputeSpec("global-token-mixer", "token-mixer", 9),
    ComputeSpec("equivariant-core", "3d-equivariant", 10),
]


# Registry of run modifiers. Each tuple is forwarded positionally to
# ModifierSpec as (name, category, ordinal[, status[, notes]]); tuples that
# stop after the ordinal use ModifierSpec's default status/notes.
MODIFIERS = [
    ModifierSpec(*fields)
    for fields in (
        ("none", "none", 1, "implemented"),
        # --- explicitly planned, with notes ---
        ("dfa-gnn", "backward", 2, "planned",
         "Non-BP/direct-feedback-style training; start on node classification."),
        ("kaft", "backward", 3, "planned",
         "User project in ../graph-grape; low priority until main table is stable."),
        # --- no explicit status ---
        ("deep-supervision", "training", 4),
        ("sam", "optimizer", 5),
        ("lap-pe", "feature", 6),
        ("rwse", "feature", 7),
        ("virtual-node", "feature", 8),
        ("dropedge", "regularization", 9),
        ("flag", "augmentation", 10),
    )
]


# One lookup table per registry, each built by rrog.registry.by_name
# (presumably keyed by spec name — confirm in rrog.registry).
# by_name is applied to VIEWS, COMPUTES, MODIFIERS in that order.
VIEW_BY_NAME, COMPUTE_BY_NAME, MODIFIER_BY_NAME = map(
    by_name, (VIEWS, COMPUTES, MODIFIERS)
)