"""GIN (1-WL-tight) and GCN (<1-WL) backbones for the diagnosis."""
import torch.nn as nn
from torch import Tensor
from torch_geometric.nn import GINConv, GCNConv, global_add_pool, global_mean_pool


def _mlp(d_in: int, d_hid: int, d_out: int) -> nn.Sequential:
    """Return the 2-layer MLP used as GIN's node-update function.

    Layout: Linear -> BatchNorm1d -> ReLU -> Linear. The last Linear has no
    activation; ``GIN.forward`` applies a ReLU after each convolution.
    """
    return nn.Sequential(
        nn.Linear(d_in, d_hid),
        nn.BatchNorm1d(d_hid),
        nn.ReLU(),
        nn.Linear(d_hid, d_out),
    )


def _head(d_in: int, hidden: int, out_dim: int) -> nn.Sequential:
    """Return the graph-level head applied after pooling: Linear -> ReLU -> Linear."""
    return nn.Sequential(nn.Linear(d_in, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))


class GIN(nn.Module):
    """Graph Isomorphism Network: sum aggregation, MLP update, sum readout.

    Sum aggregation + MLP update -> injective on multisets -> can match 1-WL.
    Each layer is a ``GINConv`` with a learnable ``eps`` (``train_eps=True``),
    i.e. h' = MLP((1 + eps) * h + sum of neighbour h), followed by a ReLU.
    Node embeddings are summed per graph (``global_add_pool``) and fed to a
    two-layer head.

    Args:
        in_dim: Number of input node features.
        hidden: Width of every hidden node embedding and of the head.
        layers: Number of ``GINConv`` layers. With ``0`` the raw node
            features are pooled directly.
        out_dim: Size of the per-graph output.
    """

    def __init__(self, in_dim: int, hidden: int = 64, layers: int = 4, out_dim: int = 10):
        super().__init__()
        self.convs = nn.ModuleList()
        d = in_dim
        for _ in range(layers):
            self.convs.append(GINConv(_mlp(d, hidden, hidden), train_eps=True))
            d = hidden
        # Size the head from the final embedding width ``d`` (== hidden when
        # layers >= 1, == in_dim when layers == 0). Hard-coding ``hidden`` here
        # breaks the layers == 0 case whenever in_dim != hidden.
        self.head = _head(d, hidden, out_dim)

    def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor:
        """Embed every graph in the mini-batch.

        Args:
            x: Node features, last dimension ``in_dim``
                (PyG convention: ``[num_nodes, in_dim]``).
            edge_index: Graph connectivity in PyG COO format (``[2, num_edges]``).
            batch: Graph index of every node, used by the pooling step.

        Returns:
            One row of size ``out_dim`` per graph in ``batch``.
        """
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        return self.head(global_add_pool(x, batch))


class GCN(nn.Module):
    """Graph Convolutional Network baseline with mean readout (reference baseline).

    Aggregation uses ``GCNConv``'s symmetric degree normalisation with
    self-loops (D^-1/2 (A + I) D^-1/2 X W under PyG defaults). This is a
    weighted, normalised sum, not a plain sum. It is therefore not injective
    on multisets and sits strictly below 1-WL. Node embeddings are averaged
    per graph (``global_mean_pool``) and fed to a two-layer head.

    Args:
        in_dim: Number of input node features.
        hidden: Width of every hidden node embedding and of the head.
        layers: Number of ``GCNConv`` layers. With ``0`` the raw node
            features are pooled directly.
        out_dim: Size of the per-graph output.
    """

    def __init__(self, in_dim: int, hidden: int = 64, layers: int = 4, out_dim: int = 10):
        super().__init__()
        self.convs = nn.ModuleList()
        d = in_dim
        for _ in range(layers):
            self.convs.append(GCNConv(d, hidden))
            d = hidden
        # As in GIN: the head consumes the final embedding width ``d``, which
        # equals in_dim when there are no conv layers.
        self.head = _head(d, hidden, out_dim)

    def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor:
        """Embed every graph in the mini-batch.

        Args:
            x: Node features, last dimension ``in_dim``
                (PyG convention: ``[num_nodes, in_dim]``).
            edge_index: Graph connectivity in PyG COO format (``[2, num_edges]``).
            batch: Graph index of every node, used by the pooling step.

        Returns:
            One row of size ``out_dim`` per graph in ``batch``.
        """
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        return self.head(global_mean_pool(x, batch))