From d12722525fc010a3910b5152c72654a2ade5eac4 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Wed, 17 Jun 2026 11:19:27 -0500 Subject: Initial import --- diag/models.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 diag/models.py (limited to 'diag/models.py') diff --git a/diag/models.py b/diag/models.py new file mode 100644 index 0000000..0aa3106 --- /dev/null +++ b/diag/models.py @@ -0,0 +1,42 @@ +"""GIN (1-WL-tight) and GCN (<1-WL) backbones for the diagnosis.""" +import torch.nn as nn +from torch_geometric.nn import GINConv, GCNConv, global_add_pool, global_mean_pool + + +def _mlp(d_in, d_hid, d_out): + return nn.Sequential(nn.Linear(d_in, d_hid), nn.BatchNorm1d(d_hid), nn.ReLU(), + nn.Linear(d_hid, d_out)) + + +class GIN(nn.Module): + """Sum aggregation + MLP update -> injective on multisets -> matches 1-WL.""" + def __init__(self, in_dim, hidden=64, layers=4, out_dim=10): + super().__init__() + self.convs = nn.ModuleList() + d = in_dim + for _ in range(layers): + self.convs.append(GINConv(_mlp(d, hidden, hidden), train_eps=True)) + d = hidden + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index).relu() + return self.head(global_add_pool(x, batch)) + + +class GCN(nn.Module): + """Mean (normalized) aggregation -> non-injective -> strictly below 1-WL (reference baseline).""" + def __init__(self, in_dim, hidden=64, layers=4, out_dim=10): + super().__init__() + self.convs = nn.ModuleList() + d = in_dim + for _ in range(layers): + self.convs.append(GCNConv(d, hidden)) + d = hidden + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index).relu() + return self.head(global_mean_pool(x, batch)) -- cgit v1.2.3