"""Training-based failure diagnosis on LRGB Peptides-struct (real, large, long-range).

The WL partition instrument is vacuous here (graphs are ~all distinguishable), so we
diagnose by TRAINING and comparing:

    GIN(L)      : standard 1-WL backbone at depth L
    GIN(L)+RNI  : random node features = noise = beyond-1-WL symmetry breaker
    GCN(L)      : sub-1-WL reference

Reading the results:
    deeper helps            -> long-range / under-reaching
    RNI helps               -> a real >1-WL ceiling that noise breaks
    train MAE << val MAE    -> generalization gap
    train MAE high          -> compute / optimization ceiling

Targets are z-scored per dimension using train statistics; metric = standardized MAE
(lower is better). 11 targets.

Run:
    PYTHONPATH=/home/yurenh2/rrog python3 diag/train_real.py --conv gin --layers 5 --rni 0
"""
import argparse
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch_geometric.datasets import LRGBDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, GCNConv, global_mean_pool

ROOT = '/home/yurenh2/rrog/data/lrgb'
OUT = '/home/yurenh2/rrog/runs'


class Net(nn.Module):
    """Graph-level regressor.

    Pipeline: summed per-column node embeddings (+ optional random features) ->
    linear projection -> `layers` x (conv -> BatchNorm -> ReLU) -> mean pool ->
    2-layer MLP head.

    Args:
        col_sizes: embedding vocabulary size for each node-feature column.
        hidden: hidden width used throughout.
        layers: number of message-passing layers (depth L).
        out_dim: number of regression outputs.
        conv: 'gin' selects GINConv with a trainable eps; any other value selects GCNConv.
        rni: number of standard-normal features appended to every node (0 disables RNI).
    """

    def __init__(self, col_sizes, hidden, layers, out_dim, conv='gin', rni=0):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(int(s), hidden) for s in col_sizes])
        self.rni = rni
        self.lin_in = nn.Linear(hidden + rni, hidden)
        self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
        for _ in range(layers):
            if conv == 'gin':
                mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
                self.convs.append(GINConv(mlp, train_eps=True))
            else:
                self.convs.append(GCNConv(hidden, hidden))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))

    def forward(self, x, edge_index, batch):
        # Column i of x indexes embedding table i; the per-column embeddings are summed.
        # (x must hold integer indices for nn.Embedding -- assumed from the dataset.)
        h = sum(emb(x[:, i]) for i, emb in enumerate(self.embs))
        if self.rni:
            # Noise is resampled on every forward call, in train AND eval mode, so
            # metrics reported for rni > 0 are themselves stochastic.
            h = torch.cat([h, torch.randn(h.size(0), self.rni, device=h.device)], dim=1)
        h = self.lin_in(h)
        for conv, bn in zip(self.convs, self.bns):
            h = bn(conv(h, edge_index)).relu()
        return self.head(global_mean_pool(h, batch))


@torch.no_grad()
def mae(model, loader, dev, ymu, ysd):
    """Mean absolute error over every target entry yielded by `loader`.

    The loaders built in `main` already carry z-scored targets, so this is the
    standardized MAE. `ymu` / `ysd` are accepted for interface compatibility but are
    not used. Returns NaN for an empty loader instead of dividing by zero. Leaves the
    model in eval mode.
    """
    model.eval()
    abs_err = count = 0.0
    for b in loader:
        b = b.to(dev)
        o = model(b.x, b.edge_index, b.batch)
        abs_err += (o - b.y).abs().sum().item()
        count += b.y.numel()
    return abs_err / count if count else float('nan')


def _embedding_sizes_and_target_stats(tr, others):
    """Return (per-column embedding sizes, target mean, target std).

    Embedding sizes are (column-wise max index over `tr` and every split in
    `others`) + 2. Target mean/std are computed on the training split only; a small
    epsilon keeps the std strictly positive.
    """
    col_max = None
    ys = []
    for g in tr:
        m = g.x.max(0).values
        col_max = m if col_max is None else torch.maximum(col_max, m)
        ys.append(g.y.view(-1))
    for ds in others:
        for g in ds:
            col_max = torch.maximum(col_max, g.x.max(0).values)
    col_sizes = (col_max + 2).tolist()
    y_train = torch.stack(ys)
    return col_sizes, y_train.mean(0), y_train.std(0) + 1e-8


def _standardize(ds, ymu, ysd):
    """Return cloned graphs whose `y` is reshaped to (1, -1) and z-scored."""
    out = []
    for g in ds:
        g = g.clone()
        g.y = (g.y.view(1, -1) - ymu) / ysd
        out.append(g)
    return out


def main():
    """Train one configuration, track the best-val checkpoint metrics, write a JSON report."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
    ap.add_argument('--layers', type=int, default=5)
    ap.add_argument('--hidden', type=int, default=128)
    ap.add_argument('--rni', type=int, default=0)
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--bs', type=int, default=128)
    ap.add_argument('--seed', type=int, default=0)
    args = ap.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(OUT, exist_ok=True)

    tr = LRGBDataset(root=ROOT, name='Peptides-struct', split='train')
    va = LRGBDataset(root=ROOT, name='Peptides-struct', split='val')
    te = LRGBDataset(root=ROOT, name='Peptides-struct', split='test')

    # Per-column embedding sizes + target standardization (train stats).
    col_sizes, ymu, ysd = _embedding_sizes_and_target_stats(tr, (va, te))

    # Standardize the training split once and share it between the shuffled training
    # loader and the ordered train-evaluation loader; loaders do not mutate the list.
    tr_std = _standardize(tr, ymu, ysd)
    trl = DataLoader(tr_std, batch_size=args.bs, shuffle=True, drop_last=True)
    val = DataLoader(_standardize(va, ymu, ysd), batch_size=256)
    tel = DataLoader(_standardize(te, ymu, ysd), batch_size=256)
    trl_eval = DataLoader(tr_std, batch_size=256)

    model = Net(col_sizes, args.hidden, args.layers, out_dim=11,
                conv=args.conv, rni=args.rni).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    lossf = nn.L1Loss()

    t0 = time.time()
    best_val = float('inf')
    best = {}
    for ep in range(args.epochs):
        model.train()
        for b in trl:
            b = b.to(dev)
            opt.zero_grad()
            loss = lossf(model(b.x, b.edge_index, b.batch), b.y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sched.step()
        # Evaluate every 15 epochs and at the final epoch; record train/val/test MAE
        # at the epoch with the best validation MAE (model selection on val).
        if (ep + 1) % 15 == 0 or ep == args.epochs - 1:
            vm = mae(model, val, dev, ymu, ysd)
            if vm < best_val:
                best_val = vm
                best = {'ep': ep + 1,
                        'train_mae': mae(model, trl_eval, dev, ymu, ysd),
                        'val_mae': vm,
                        'test_mae': mae(model, tel, dev, ymu, ysd)}
            print(f"ep{ep+1} val_mae={vm:.4f}", flush=True)

    tag = f"{args.conv}_L{args.layers}_rni{args.rni}_s{args.seed}"
    rep = {'dataset': 'Peptides-struct', 'tag': tag, **vars(args),
           'sec': round(time.time() - t0, 1), 'dev': dev, **best}

    def _fmt(v):
        # `best` stays empty when no evaluation ran or none improved on +inf
        # (e.g. --epochs 0, NaN validation MAE); f"{None:.4f}" would raise TypeError
        # and the JSON report below would never be written.
        return 'None' if v is None else f"{v:.4f}"

    print(f"[{tag}] train_mae={_fmt(best.get('train_mae'))} val_mae={_fmt(best.get('val_mae'))} "
          f"test_mae={_fmt(best.get('test_mae'))} @ep{best.get('ep')} ({rep['sec']}s)")
    fn = os.path.join(OUT, f"real_{tag}.json")
    with open(fn, 'w') as f:
        json.dump(rep, f, indent=2)
    print(" wrote", fn)


if __name__ == "__main__":
    main()