1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
|
"""Control: plain FEEDFORWARD (independent-weight, NO recursion, NO deep-supervision) GNN on
graph 3-coloring, to test whether the recursive gcn/gat collapse (.003/.04) is caused by the
RECURSION (shared-weight repeated -> oversmoothing) or is intrinsic to gcn/gat on coloring.
Sweep conv in {gcn,gat} x depth L in {4,8,16}. L=4 ~ normal usage; L=16 tests whether deep
feedforward also oversmooths. If shallow FF colors well but recursive collapses -> recursion's
fault. If FF is also ~0 at all L -> intrinsic.
Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ff_color.py --conv gcn --L 4 --seed 0
"""
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from diag.train_color import make_split, featurize, make_conv, conflict_loss, solve_stats
class FF(nn.Module):
    """Feedforward GNN control: input projection, ``L`` conv/BatchNorm/ReLU layers, linear head.

    Each of the ``L`` layers owns its own weights (no weight sharing, no
    recursion, no per-layer outputs), which is what makes this the
    non-recursive control.

    Args:
        in_dim: width of the per-node input features.
        hidden: hidden width used by every layer.
        k: number of output logits per node (the number of colours).
        L: number of message-passing layers.
        conv: conv-type name forwarded to ``make_conv``.
    """

    def __init__(self, in_dim, hidden, k, L, conv):
        super().__init__()
        self.conv_type = conv  # stored for reference only; forward() never reads it
        self.lin_in = nn.Linear(in_dim, hidden)
        # Creation order (lin_in -> convs -> bns -> head) determines both the
        # state_dict layout and RNG consumption during init; keep it as is.
        self.convs = nn.ModuleList(make_conv(conv, hidden) for _ in range(L))
        self.bns = nn.ModuleList(nn.BatchNorm1d(hidden) for _ in range(L))
        self.head = nn.Linear(hidden, k)

    def forward(self, xin, ei, batch=None, noise=False):
        """Return ``[logits]``: a one-element list holding the head output for every node.

        ``batch`` and ``noise`` are accepted but ignored; presumably they exist
        so the signature matches the recursive model's forward — confirm.
        """
        hid = self.lin_in(xin)
        for idx, layer in enumerate(self.convs):
            hid = F.relu(self.bns[idx](layer(hid, ei)))
        logits = self.head(hid)
        return [logits]
def _ema_update(ema, model, decay=0.999):
    """Blend ``model``'s current state_dict into ``ema`` in place.

    Floating-point entries are exponentially averaged as
    ``ema = decay * ema + (1 - decay) * current``; non-floating entries (for
    example BatchNorm's integer step counter) cannot be averaged and are copied.
    """
    with torch.no_grad():
        for key, cur in model.state_dict().items():
            if torch.is_floating_point(cur):
                ema[key].mul_(decay).add_(cur.detach(), alpha=1.0 - decay)
            else:
                ema[key].copy_(cur)


def _eval_with_weights(model, weights, te, dev):
    """Return the ``solve_stats`` solve rate of ``model`` loaded with ``weights``.

    The live training weights are snapshotted first and restored in a
    ``finally``, so the model is never left holding ``weights`` after this call.
    """
    backup = {key: val.detach().clone() for key, val in model.state_dict().items()}
    model.load_state_dict(weights)
    try:
        solve_rate, _ = solve_stats(model, te, dev, sample=300)
    finally:
        model.load_state_dict(backup)
    return solve_rate


def main():
    """Train one ``FF`` configuration and print its best EMA-weight solve rate.

    CLI: ``--conv``, ``--L`` (depth), ``--epochs`` (>= 1), ``--seed``.
    The model (hidden width 128, 3 output colours) is trained with
    ``conflict_loss``; every 30 epochs and at the last epoch the EMA weights
    are scored with ``solve_stats(..., sample=300)`` and the maximum is printed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--conv', choices=['gcn', 'gat', 'gin', 'sage'], default='gcn')
    ap.add_argument('--L', type=int, default=4)
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--seed', type=int, default=0)
    args = ap.parse_args()
    if args.epochs < 1:
        # Without this, the loop never runs and the script prints solve=-1.000.
        ap.error('--epochs must be >= 1')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): the last make_split argument looks like a fixed data seed
    # (0 / 100000) independent of --seed, so all runs share data — confirm.
    tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
    te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
    in_dim = tr[0]['xin'].shape[1]
    trl = DataLoader(
        [Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr],
        batch_size=32, shuffle=True, drop_last=True)

    model = FF(in_dim, 128, 3, args.L, args.conv).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    # EMA copy starts from the (already device-placed) initial weights.
    ema = {key: val.detach().clone() for key, val in model.state_dict().items()}

    # NOTE(review): `best` is the max over repeated test-split evaluations, i.e.
    # checkpoint selection on the test set. It is optimistic; comparisons are
    # only fair if the recursive runs use the same protocol — confirm.
    best = -1
    for ep in range(args.epochs):
        model.train()
        for b in trl:
            b = b.to(dev)
            opt.zero_grad()
            # FF.forward returns a list of logits; the last (only) entry is scored.
            loss = conflict_loss(model(b.x, b.edge_index)[-1], b.edge_index)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            _ema_update(ema, model)
        sched.step()
        if (ep + 1) % 30 == 0 or ep == args.epochs - 1:
            best = max(best, _eval_with_weights(model, ema, te, dev))
    print(f"[ff/{args.conv}/L{args.L}/s{args.seed}] solve={best:.3f}", flush=True)


if __name__ == "__main__":
    main()
|