summaryrefslogtreecommitdiff
path: root/diag/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/models.py')
-rw-r--r--diag/models.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/diag/models.py b/diag/models.py
new file mode 100644
index 0000000..0aa3106
--- /dev/null
+++ b/diag/models.py
@@ -0,0 +1,42 @@
+"""GIN (1-WL-tight) and GCN (<1-WL) backbones for the diagnosis."""
+import torch.nn as nn
+from torch_geometric.nn import GINConv, GCNConv, global_add_pool, global_mean_pool
+
+
+def _mlp(d_in, d_hid, d_out):
+ return nn.Sequential(nn.Linear(d_in, d_hid), nn.BatchNorm1d(d_hid), nn.ReLU(),
+ nn.Linear(d_hid, d_out))
+
+
+class GIN(nn.Module):
+ """Sum aggregation + MLP update -> injective on multisets -> matches 1-WL."""
+ def __init__(self, in_dim, hidden=64, layers=4, out_dim=10):
+ super().__init__()
+ self.convs = nn.ModuleList()
+ d = in_dim
+ for _ in range(layers):
+ self.convs.append(GINConv(_mlp(d, hidden, hidden), train_eps=True))
+ d = hidden
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))
+
+ def forward(self, x, edge_index, batch):
+ for conv in self.convs:
+ x = conv(x, edge_index).relu()
+ return self.head(global_add_pool(x, batch))
+
+
+class GCN(nn.Module):
+ """Mean (normalized) aggregation -> non-injective -> strictly below 1-WL (reference baseline)."""
+ def __init__(self, in_dim, hidden=64, layers=4, out_dim=10):
+ super().__init__()
+ self.convs = nn.ModuleList()
+ d = in_dim
+ for _ in range(layers):
+ self.convs.append(GCNConv(d, hidden))
+ d = hidden
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))
+
+ def forward(self, x, edge_index, batch):
+ for conv in self.convs:
+ x = conv(x, edge_index).relu()
+ return self.head(global_mean_pool(x, batch))