summaryrefslogtreecommitdiff
path: root/diag/ff_color.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/ff_color.py')
-rw-r--r--diag/ff_color.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/diag/ff_color.py b/diag/ff_color.py
new file mode 100644
index 0000000..29c3b45
--- /dev/null
+++ b/diag/ff_color.py
@@ -0,0 +1,75 @@
+"""Control: plain FEEDFORWARD (independent-weight, NO recursion, NO deep-supervision) GNN on
+graph 3-coloring, to test whether the recursive gcn/gat collapse (.003/.04) is caused by the
+RECURSION (shared-weight repeated -> oversmoothing) or is intrinsic to gcn/gat on coloring.
+
+Sweep conv in {gcn,gat} x depth L in {4,8,16}. L=4 ~ normal usage; L=16 tests whether deep
+feedforward also oversmooths. If shallow FF colors well but recursive collapses -> recursion's
+fault. If FF is also ~0 at all L -> intrinsic.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ff_color.py --conv gcn --L 4 --seed 0
+"""
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from diag.train_color import make_split, featurize, make_conv, conflict_loss, solve_stats
+
+
+class FF(nn.Module):
+ def __init__(self, in_dim, hidden, k, L, conv):
+ super().__init__()
+ self.conv_type = conv
+ self.lin_in = nn.Linear(in_dim, hidden)
+ self.convs = nn.ModuleList([make_conv(conv, hidden) for _ in range(L)])
+ self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(L)])
+ self.head = nn.Linear(hidden, k)
+
+ def forward(self, xin, ei, batch=None, noise=False):
+ h = self.lin_in(xin)
+ for conv, bn in zip(self.convs, self.bns):
+ h = bn(conv(h, ei)).relu()
+ return [self.head(h)]
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--conv', choices=['gcn', 'gat', 'gin', 'sage'], default='gcn')
+ ap.add_argument('--L', type=int, default=4)
+ ap.add_argument('--epochs', type=int, default=150)
+ ap.add_argument('--seed', type=int, default=0)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
+ te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
+ in_dim = tr[0]['xin'].shape[1]
+ trl = DataLoader([Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr],
+ batch_size=32, shuffle=True, drop_last=True)
+ model = FF(in_dim, 128, 3, args.L, args.conv).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ best = -1
+ for ep in range(args.epochs):
+ model.train()
+ for b in trl:
+ b = b.to(dev); opt.zero_grad()
+ loss = conflict_loss(model(b.x, b.edge_index)[-1], b.edge_index)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ with torch.no_grad():
+ for kk, v in model.state_dict().items():
+ ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v)
+ sched.step()
+ if (ep + 1) % 30 == 0 or ep == args.epochs - 1:
+ bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ model.load_state_dict(ema); sr, _ = solve_stats(model, te, dev, sample=300)
+ best = max(best, sr); model.load_state_dict(bk)
+ print(f"[ff/{args.conv}/L{args.L}/s{args.seed}] solve={best:.3f}", flush=True)
+
+
+if __name__ == "__main__":
+ main()