"""Control experiment: plain feedforward GNN on graph 3-coloring.

The model here uses independent weights per layer, with NO recursion and NO
deep supervision.  It exists to test whether the collapse of the recursive
gcn/gat models (.003/.04) is caused by the RECURSION (shared weights applied
repeatedly -> oversmoothing) or is intrinsic to gcn/gat on coloring.

Sweep: conv in {gcn, gat} x depth L in {4, 8, 16}.  L=4 is roughly normal
usage; L=16 tests whether a deep feedforward stack also oversmooths.

Reading the outcome:
  * shallow FF colors well but the recursive model collapses -> the recursion
    is at fault;
  * FF is also ~0 at every L -> the failure is intrinsic.

Run:
    PYTHONPATH=/home/yurenh2/rrog python3 diag/ff_color.py --conv gcn --L 4 --seed 0
"""
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from diag.train_color import (
    conflict_loss,
    featurize,
    make_conv,
    make_split,
    solve_stats,
)


class FF(nn.Module):
    """Feedforward stack of ``L`` independently-parameterised graph convs.

    Pipeline: linear input projection -> ``L`` x (conv -> BatchNorm1d -> ReLU)
    -> linear head producing ``k`` outputs per node.
    """

    def __init__(self, in_dim, hidden, k, L, conv):
        """Build the layers.

        Args:
            in_dim: Width of the per-node input features.
            hidden: Width used by every conv / batch-norm layer.
            k: Number of outputs per node produced by the head.
            L: Number of conv layers (each with its own weights).
            conv: Conv-type key forwarded to ``make_conv``.
        """
        super().__init__()
        self.conv_type = conv
        self.lin_in = nn.Linear(in_dim, hidden)
        self.convs = nn.ModuleList([make_conv(conv, hidden) for _ in range(L)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(L)])
        self.head = nn.Linear(hidden, k)

    def forward(self, xin, ei, batch=None, noise=False):
        """Run the stack and return a one-element list holding the head output.

        ``batch`` and ``noise`` are accepted but never read here.
        NOTE(review): presumably they and the list-wrapped return exist so this
        model is call-compatible with the recursive models that
        ``solve_stats`` evaluates -- confirm against diag.train_color.
        """
        hidden_state = self.lin_in(xin)
        for layer, norm in zip(self.convs, self.bns):
            hidden_state = norm(layer(hidden_state, ei)).relu()
        logits = self.head(hidden_state)
        return [logits]


def _clone_state(model):
    """Return a detached, cloned copy of every entry in ``model.state_dict()``."""
    return {key: value.detach().clone() for key, value in model.state_dict().items()}


def main():
    """Train one FF model and print the best EMA-weight solve rate.

    Command-line flags: ``--conv``, ``--L``, ``--epochs``, ``--seed``.
    Evaluation happens every 30 epochs and on the final epoch, always with the
    EMA weights swapped in; the live weights are restored afterwards.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--conv', choices=['gcn', 'gat', 'gin', 'sage'], default='gcn')
    parser.add_argument('--L', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=150)
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The positional arguments below are interpreted by diag.train_color;
    # their meaning is not visible from this file.
    train_set = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
    test_set = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
    in_dim = train_set[0]['xin'].shape[1]

    graphs = [
        Data(x=rec['xin'], edge_index=rec['edge_index'], num_nodes=rec['n'])
        for rec in train_set
    ]
    loader = DataLoader(graphs, batch_size=32, shuffle=True, drop_last=True)

    model = FF(in_dim, 128, 3, args.L, args.conv).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    # Shadow copy of the weights, updated after every optimiser step.
    ema_state = _clone_state(model)
    best = -1

    for epoch in range(args.epochs):
        model.train()
        for batch in loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            logits = model(batch.x, batch.edge_index)[-1]
            loss = conflict_loss(logits, batch.edge_index)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            with torch.no_grad():
                for key, value in model.state_dict().items():
                    shadow = ema_state[key]
                    if torch.is_floating_point(value):
                        # EMA with decay 0.999.
                        shadow.mul_(0.999).add_(value.detach(), alpha=0.001)
                    else:
                        # Non-float entries cannot be averaged; mirror them.
                        shadow.copy_(value)

        # One scheduler step per epoch (T_max is the epoch count).
        scheduler.step()

        at_interval = (epoch + 1) % 30 == 0
        at_end = epoch == args.epochs - 1
        if not (at_interval or at_end):
            continue

        # Evaluate with the EMA weights, then put the live weights back.
        live_state = _clone_state(model)
        model.load_state_dict(ema_state)
        solve_rate, _ = solve_stats(model, test_set, device, sample=300)
        best = max(best, solve_rate)
        model.load_state_dict(live_state)

    print(f"[ff/{args.conv}/L{args.L}/s{args.seed}] solve={best:.3f}", flush=True)


if __name__ == "__main__":
    main()