summaryrefslogtreecommitdiff
path: root/diag/models.py
blob: 0aa31064d0b574ce26edb956c4cc370083f840fa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""GIN (1-WL-tight) and GCN (<1-WL) backbones for the diagnosis."""
import torch.nn as nn
from torch_geometric.nn import GINConv, GCNConv, global_add_pool, global_mean_pool


def _mlp(d_in, d_hid, d_out):
    """Build the two-layer MLP that GIN uses as its per-layer update function.

    Layout: ``Linear(d_in, d_hid) -> BatchNorm1d(d_hid) -> ReLU -> Linear(d_hid, d_out)``.
    No activation is applied after the final ``Linear``.

    Args:
        d_in: input feature width.
        d_hid: hidden width (also the BatchNorm feature count).
        d_out: output feature width.

    Returns:
        An ``nn.Sequential`` holding the four stages above, indexed 0..3.
    """
    stages = [
        nn.Linear(d_in, d_hid),
        nn.BatchNorm1d(d_hid),
        nn.ReLU(),
        nn.Linear(d_hid, d_out),
    ]
    return nn.Sequential(*stages)


class GIN(nn.Module):
    """GIN backbone: sum aggregation + MLP update -> injective on multisets -> matches 1-WL.

    Each layer is a ``GINConv`` whose update function is a two-layer MLP
    (see ``_mlp``), followed by a ReLU. Node embeddings are summed per graph
    (``global_add_pool``) and passed through a small MLP head.

    Args:
        in_dim: width of the input node features.
        hidden: width of every message-passing layer and of the head's hidden layer.
        layers: number of ``GINConv`` layers. ``0`` yields a readout-only
            model that pools the raw input features.
        out_dim: width of the per-graph output.
    """

    def __init__(self, in_dim, hidden=64, layers=4, out_dim=10):
        super().__init__()
        self.convs = nn.ModuleList()
        d = in_dim  # width of the features entering the next layer
        for _ in range(layers):
            # train_eps=True makes GIN's self-loop weight epsilon learnable.
            self.convs.append(GINConv(_mlp(d, hidden, hidden), train_eps=True))
            d = hidden
        # Size the head from the width that actually reaches the readout. This is
        # ``hidden`` whenever at least one conv exists, but ``in_dim`` when
        # ``layers == 0``; a hard-coded ``hidden`` would mismatch in that case.
        self.head = nn.Sequential(nn.Linear(d, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))

    def forward(self, x, edge_index, batch):
        """Run message passing, sum-pool per graph and apply the head.

        Args:
            x: node feature matrix with ``in_dim`` columns (presumably shape
                (num_nodes, in_dim), per PyG convention).
            edge_index: graph connectivity passed unchanged to every ``GINConv``.
            batch: node-to-graph assignment vector consumed by ``global_add_pool``.

        Returns:
            One row of ``out_dim`` outputs per graph identified by ``batch``.
        """
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        return self.head(global_add_pool(x, batch))


class GCN(nn.Module):
    """Mean (normalized) aggregation -> non-injective -> strictly below 1-WL (reference baseline).

    Each layer is a ``GCNConv`` followed by a ReLU. Node embeddings are
    averaged per graph (``global_mean_pool``) and passed through a small MLP head.

    Args:
        in_dim: width of the input node features.
        hidden: width of every message-passing layer and of the head's hidden layer.
        layers: number of ``GCNConv`` layers. ``0`` yields a readout-only
            model that pools the raw input features.
        out_dim: width of the per-graph output.
    """

    def __init__(self, in_dim, hidden=64, layers=4, out_dim=10):
        super().__init__()
        self.convs = nn.ModuleList()
        d = in_dim  # width of the features entering the next layer
        for _ in range(layers):
            self.convs.append(GCNConv(d, hidden))
            d = hidden
        # Size the head from the width that actually reaches the readout. This is
        # ``hidden`` whenever at least one conv exists, but ``in_dim`` when
        # ``layers == 0``; a hard-coded ``hidden`` would mismatch in that case.
        self.head = nn.Sequential(nn.Linear(d, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))

    def forward(self, x, edge_index, batch):
        """Run message passing, mean-pool per graph and apply the head.

        Args:
            x: node feature matrix with ``in_dim`` columns (presumably shape
                (num_nodes, in_dim), per PyG convention).
            edge_index: graph connectivity passed unchanged to every ``GCNConv``.
            batch: node-to-graph assignment vector consumed by ``global_mean_pool``.

        Returns:
            One row of ``out_dim`` outputs per graph identified by ``batch``.
        """
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        return self.head(global_mean_pool(x, batch))