diff options
Diffstat (limited to 'rrog/backbones.py')
| -rw-r--r-- | rrog/backbones.py | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/rrog/backbones.py b/rrog/backbones.py new file mode 100644 index 0000000..81bfb9b --- /dev/null +++ b/rrog/backbones.py @@ -0,0 +1,72 @@ +from rrog.registry import ComputeSpec, ModifierSpec, ViewSpec, by_name + + +VIEWS = [ + ViewSpec("gin", "message-passing", "2d", 1, "implemented", + "Plain GINConv message passing."), + ViewSpec("gine", "message-passing", "2d", 2, "implemented", + "Edge-aware GIN variant. ZINC uses a learned constant edge token; OGB uses bond features."), + ViewSpec("gcn", "message-passing", "2d", 3, "implemented"), + ViewSpec("graphsage", "message-passing", "2d", 4, "implemented"), + ViewSpec("gatv2", "attention-mpnn", "2d", 5, "implemented"), + ViewSpec("graphconv", "message-passing", "2d", 6, "implemented"), + ViewSpec("transformer", "attention-mpnn", "2d", 7, "implemented"), + ViewSpec("pna", "message-passing", "2d", 8, "implemented", + "Requires degree histogram from the train split."), + ViewSpec("gen", "message-passing", "2d", 9, "implemented"), + ViewSpec("film", "message-passing", "2d", 10, "implemented"), + ViewSpec("resgated", "message-passing", "2d", 11, "implemented"), + ViewSpec("tag", "higher-order-hop", "2d", 12, "implemented"), + ViewSpec("sgc", "propagation", "2d", 13, "implemented"), + ViewSpec("cheb", "spectral", "2d", 14, "implemented"), + ViewSpec("arma", "spectral", "2d", 15, "implemented"), + ViewSpec("mf", "message-passing", "2d", 16, "implemented"), + ViewSpec("appnp", "propagation", "2d", 17, "implemented"), + ViewSpec("mixhop", "higher-order-hop", "2d", 18), + ViewSpec("gps", "hybrid-local-global", "2d", 19), + ViewSpec("graphormer", "global-attention", "2d", 20), + ViewSpec("san", "spectral-attention", "2d", 21), + ViewSpec("mpnn", "message-passing", "2d", 22), + ViewSpec("schnet", "continuous-filter", "3d", 23), + ViewSpec("dimenetpp", "angle-aware", "3d", 24), + ViewSpec("painn", "equivariant", "3d", 25), + ViewSpec("gemnet", "equivariant", "3d", 26), + ViewSpec("egnn", "equivariant", "3d", 27), + ViewSpec("equiformer", "equivariant", "3d", 28), + ViewSpec("mace", "equivariant", "3d", 29), +] + + +COMPUTES = [ + ComputeSpec("classic", "baseline", 1, "implemented", "Standard one-forward GNN baseline; no RRoG compute."), + ComputeSpec("view-only", "none", 2, "implemented", "RRoG view module only; no recursive compute."), + ComputeSpec("fixed-rrog", "recursive", 3, "implemented", "Fixed-depth edge-free y/z compute."), + ComputeSpec("rrog-act", "recursive-act", 4, "implemented", "Persistent full ACT recycling for graph batches."), + ComputeSpec("node-mlp", "recursive", 4), + ComputeSpec("gru-rrog", "recursive-gated", 5), + ComputeSpec("set-attn-core", "edge-free-attention", 6), + ComputeSpec("perceiver-core", "latent-attention", 7), + ComputeSpec("global-token-mixer", "token-mixer", 8), + ComputeSpec("equivariant-core", "3d-equivariant", 9), +] + + +MODIFIERS = [ + ModifierSpec("none", "none", 1, "implemented"), + ModifierSpec("dfa-gnn", "backward", 2, "planned", + "Non-BP/direct-feedback-style training; start on node classification."), + ModifierSpec("kaft", "backward", 3, "planned", + "User project in ../graph-grape; low priority until main table is stable."), + ModifierSpec("deep-supervision", "training", 4), + ModifierSpec("sam", "optimizer", 5), + ModifierSpec("lap-pe", "feature", 6), + ModifierSpec("rwse", "feature", 7), + ModifierSpec("virtual-node", "feature", 8), + ModifierSpec("dropedge", "regularization", 9), + ModifierSpec("flag", "augmentation", 10), +] + + +VIEW_BY_NAME = by_name(VIEWS) +COMPUTE_BY_NAME = by_name(COMPUTES) +MODIFIER_BY_NAME = by_name(MODIFIERS) |
