diff options
Diffstat (limited to 'diag/models.py')
| -rw-r--r-- | diag/models.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/diag/models.py b/diag/models.py new file mode 100644 index 0000000..0aa3106 --- /dev/null +++ b/diag/models.py @@ -0,0 +1,42 @@ +"""GIN (1-WL-tight) and GCN (<1-WL) backbones for the diagnosis.""" +import torch.nn as nn +from torch_geometric.nn import GINConv, GCNConv, global_add_pool, global_mean_pool + + +def _mlp(d_in, d_hid, d_out): + return nn.Sequential(nn.Linear(d_in, d_hid), nn.BatchNorm1d(d_hid), nn.ReLU(), + nn.Linear(d_hid, d_out)) + + +class GIN(nn.Module): + """Sum aggregation + MLP update -> injective on multisets -> matches 1-WL.""" + def __init__(self, in_dim, hidden=64, layers=4, out_dim=10): + super().__init__() + self.convs = nn.ModuleList() + d = in_dim + for _ in range(layers): + self.convs.append(GINConv(_mlp(d, hidden, hidden), train_eps=True)) + d = hidden + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index).relu() + return self.head(global_add_pool(x, batch)) + + +class GCN(nn.Module): + """Mean (normalized) aggregation -> non-injective -> strictly below 1-WL (reference baseline).""" + def __init__(self, in_dim, hidden=64, layers=4, out_dim=10): + super().__init__() + self.convs = nn.ModuleList() + d = in_dim + for _ in range(layers): + self.convs.append(GCNConv(d, hidden)) + d = hidden + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index).relu() + return self.head(global_mean_pool(x, batch)) |
