diff options
Diffstat (limited to 'diag/ff_color.py')
| -rw-r--r-- | diag/ff_color.py | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/diag/ff_color.py b/diag/ff_color.py new file mode 100644 index 0000000..29c3b45 --- /dev/null +++ b/diag/ff_color.py @@ -0,0 +1,75 @@ +"""Control: plain FEEDFORWARD (independent-weight, NO recursion, NO deep-supervision) GNN on +graph 3-coloring, to test whether the recursive gcn/gat collapse (.003/.04) is caused by the +RECURSION (shared-weight repeated -> oversmoothing) or is intrinsic to gcn/gat on coloring. + +Sweep conv in {gcn,gat} x depth L in {4,8,16}. L=4 ~ normal usage; L=16 tests whether deep +feedforward also oversmooths. If shallow FF colors well but recursive collapses -> recursion's +fault. If FF is also ~0 at all L -> intrinsic. + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ff_color.py --conv gcn --L 4 --seed 0 +""" +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.data import Data +from torch_geometric.loader import DataLoader +from diag.train_color import make_split, featurize, make_conv, conflict_loss, solve_stats + + +class FF(nn.Module): + def __init__(self, in_dim, hidden, k, L, conv): + super().__init__() + self.conv_type = conv + self.lin_in = nn.Linear(in_dim, hidden) + self.convs = nn.ModuleList([make_conv(conv, hidden) for _ in range(L)]) + self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(L)]) + self.head = nn.Linear(hidden, k) + + def forward(self, xin, ei, batch=None, noise=False): + h = self.lin_in(xin) + for conv, bn in zip(self.convs, self.bns): + h = bn(conv(h, ei)).relu() + return [self.head(h)] + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--conv', choices=['gcn', 'gat', 'gin', 'sage'], default='gcn') + ap.add_argument('--L', type=int, default=4) + ap.add_argument('--epochs', type=int, default=150) + ap.add_argument('--seed', type=int, default=0) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + + tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16) + te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16) + in_dim = tr[0]['xin'].shape[1] + trl = DataLoader([Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr], + batch_size=32, shuffle=True, drop_last=True) + model = FF(in_dim, 128, 3, args.L, args.conv).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + best = -1 + for ep in range(args.epochs): + model.train() + for b in trl: + b = b.to(dev); opt.zero_grad() + loss = conflict_loss(model(b.x, b.edge_index)[-1], b.edge_index) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + with torch.no_grad(): + for kk, v in model.state_dict().items(): + ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v) + sched.step() + if (ep + 1) % 30 == 0 or ep == args.epochs - 1: + bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + model.load_state_dict(ema); sr, _ = solve_stats(model, te, dev, sample=300) + best = max(best, sr); model.load_state_dict(bk) + print(f"[ff/{args.conv}/L{args.L}/s{args.seed}] solve={best:.3f}", flush=True) + + +if __name__ == "__main__": + main() |
