1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
|
"""GIN (1-WL-tight) and GCN (<1-WL) backbones for the diagnosis."""
import torch.nn as nn
from torch_geometric.nn import GINConv, GCNConv, global_add_pool, global_mean_pool
def _mlp(d_in, d_hid, d_out):
return nn.Sequential(nn.Linear(d_in, d_hid), nn.BatchNorm1d(d_hid), nn.ReLU(),
nn.Linear(d_hid, d_out))
class GIN(nn.Module):
    """Graph Isomorphism Network backbone with a graph-level prediction head.

    Each layer is a ``GINConv``: neighbour features are summed together with
    the node's own features scaled by a learnable ``(1 + eps)``
    (``train_eps=True``) and passed through a two-layer MLP (``_mlp``). Sum
    aggregation followed by an MLP update is injective on multisets, so the
    model can in principle match the 1-WL test. Node embeddings are summed per
    graph (``global_add_pool``) and mapped to logits by an MLP head.

    Args:
        in_dim: width of the input node features.
        hidden: width of every GIN layer and of the head's hidden layer.
        layers: number of GIN message-passing layers. ``0`` is accepted and
            gives a sum-pool-then-MLP baseline over the raw input features.
        out_dim: number of outputs per graph.
    """

    def __init__(self, in_dim, hidden=64, layers=4, out_dim=10):
        super().__init__()
        self.convs = nn.ModuleList()
        d = in_dim
        for _ in range(layers):
            self.convs.append(GINConv(_mlp(d, hidden, hidden), train_eps=True))
            d = hidden
        # Size the head from the width the last layer actually produces
        # (``in_dim`` when ``layers == 0``) rather than assuming ``hidden``;
        # for ``layers >= 1`` this is identical to the previous behaviour.
        self.head = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x, edge_index, batch):
        """Return one output vector per graph in the batch.

        Args:
            x: node feature matrix, presumably (num_nodes, in_dim) - confirm
                against the data loader.
            edge_index: graph connectivity in the format ``GINConv`` expects.
            batch: per-node graph-assignment vector consumed by
                ``global_add_pool``.

        Returns:
            The head applied to the per-graph sum of the final node
            embeddings, i.e. ``out_dim`` values per graph.
        """
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        return self.head(global_add_pool(x, batch))
class GCN(nn.Module):
    """Graph Convolutional Network reference baseline with a graph-level head.

    Each layer is a ``GCNConv``. With PyG's defaults it adds self-loops and
    aggregates neighbours with symmetric degree normalisation
    ``D^-1/2 (A + I) D^-1/2``. That is a degree-weighted sum, not a plain
    mean. This normalised aggregation is not injective on neighbour
    multisets, so the model is in general strictly weaker than 1-WL and than
    a sum-aggregating GIN backbone. Node embeddings are then averaged per
    graph (``global_mean_pool``), which additionally discards graph size,
    and mapped to outputs by an MLP head.

    Args:
        in_dim: width of the input node features.
        hidden: width of every GCN layer and of the head's hidden layer.
        layers: number of GCN message-passing layers. ``0`` is accepted and
            gives a mean-pool-then-MLP baseline over the raw input features.
        out_dim: number of outputs per graph.
    """

    def __init__(self, in_dim, hidden=64, layers=4, out_dim=10):
        super().__init__()
        self.convs = nn.ModuleList()
        d = in_dim
        for _ in range(layers):
            self.convs.append(GCNConv(d, hidden))
            d = hidden
        # Size the head from the width the last layer actually produces
        # (``in_dim`` when ``layers == 0``) rather than assuming ``hidden``;
        # for ``layers >= 1`` this is identical to the previous behaviour.
        self.head = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x, edge_index, batch):
        """Return one output vector per graph in the batch.

        Args:
            x: node feature matrix, presumably (num_nodes, in_dim) - confirm
                against the data loader.
            edge_index: graph connectivity in the format ``GCNConv`` expects.
            batch: per-node graph-assignment vector consumed by
                ``global_mean_pool``.

        Returns:
            The head applied to the per-graph mean of the final node
            embeddings, i.e. ``out_dim`` values per graph.
        """
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        return self.head(global_mean_pool(x, batch))
|