summary | refs | log | tree | commit | diff
path: root/diag/train_real.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/train_real.py')
-rw-r--r-- diag/train_real.py 139
1 files changed, 139 insertions, 0 deletions
diff --git a/diag/train_real.py b/diag/train_real.py
new file mode 100644
index 0000000..336d86f
--- /dev/null
+++ b/diag/train_real.py
@@ -0,0 +1,139 @@
+"""Training-based failure diagnosis on LRGB Peptides-struct (real, large, long-range).
+
+The WL partition instrument is vacuous here (graphs ~all distinguishable), so we diagnose
+by TRAINING and comparing:
+ GIN(L) : standard 1-WL backbone at depth L
+ GIN(L)+RNI : random node features = noise = beyond-1-WL symmetry breaker
+ GCN(L) : sub-1-WL reference
+Reads: deeper helps -> long-range/under-reaching; RNI helps -> a real >1-WL ceiling that
+noise breaks; train MAE << test MAE -> generalization gap; train MAE stays high ->
+compute/optimization ceiling.
+Targets z-scored per dim; metric = standardized MAE (lower better). 11 targets.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_real.py --conv gin --layers 5 --rni 0
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+from torch_geometric.datasets import LRGBDataset
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import GINConv, GCNConv, global_mean_pool
+
+ROOT = '/home/yurenh2/rrog/data/lrgb'
+OUT = '/home/yurenh2/rrog/runs'
+
+
class Net(nn.Module):
    """Message-passing regressor: GIN or GCN backbone with optional random node features.

    Each integer feature column of ``x`` has its own embedding table; the
    per-column embeddings are summed into one node vector. When ``rni`` > 0,
    that many Gaussian noise channels are concatenated to every node before
    the input projection. The noise is redrawn on every forward pass,
    regardless of train/eval mode.

    Args:
        col_sizes: Number of rows of the embedding table for each feature
            column (one entry per column of ``x``).
        hidden: Width of the node representation in every layer.
        layers: Number of message-passing layers.
        out_dim: Number of regression targets produced per graph.
        conv: ``'gin'`` for ``GINConv`` (two-layer MLP, trainable eps) or
            ``'gcn'`` for ``GCNConv``.
        rni: Number of random noise channels appended per node (0 disables).

    Raises:
        ValueError: If ``conv`` is not ``'gin'``/``'gcn'`` or ``rni`` is negative.
    """

    def __init__(self, col_sizes, hidden, layers, out_dim, conv='gin', rni=0):
        super().__init__()
        # Fail loudly: previously any unknown ``conv`` string silently built a GCN.
        if conv not in ('gin', 'gcn'):
            raise ValueError(f"conv must be 'gin' or 'gcn', got {conv!r}")
        if rni < 0:
            raise ValueError(f"rni must be >= 0, got {rni}")
        self.embs = nn.ModuleList([nn.Embedding(int(s), hidden) for s in col_sizes])
        self.rni = rni
        # Input projection absorbs the extra noise channels back down to ``hidden``.
        self.lin_in = nn.Linear(hidden + rni, hidden)
        self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
        for _ in range(layers):
            if conv == 'gin':
                mlp = nn.Sequential(
                    nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)
                )
                self.convs.append(GINConv(mlp, train_eps=True))
            else:
                self.convs.append(GCNConv(hidden, hidden))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.head = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim)
        )

    def forward(self, x, edge_index, batch):
        """Return one ``out_dim``-sized prediction row per graph in the batch.

        Args:
            x: 2-D node feature tensor; column ``i`` is looked up in the i-th
                embedding table (assumed to hold integer indices — confirm
                against the dataset).
            edge_index: Edge index tensor passed unchanged to each conv layer.
            batch: Node-to-graph assignment passed to ``global_mean_pool``.
        """
        h = sum(emb(x[:, i]) for i, emb in enumerate(self.embs))
        if self.rni:
            # Fresh noise each call; this is what breaks node symmetry.
            noise = torch.randn(h.size(0), self.rni, device=h.device)
            h = torch.cat([h, noise], dim=1)
        h = self.lin_in(h)
        for conv, bn in zip(self.convs, self.bns):
            h = bn(conv(h, edge_index)).relu()
        return self.head(global_mean_pool(h, batch))
+
+
@torch.no_grad()
def mae(model, loader, dev, ymu, ysd):
    """Return the mean absolute error of ``model`` over every target entry in ``loader``.

    The model is switched to eval mode and left there; the caller is
    responsible for calling ``model.train()`` again. Errors are measured
    against ``b.y`` exactly as stored in each batch, with no rescaling.

    Args:
        model: Callable module invoked as ``model(b.x, b.edge_index, b.batch)``.
        loader: Iterable of batches exposing ``to``, ``x``, ``edge_index``,
            ``batch`` and ``y``.
        dev: Device passed to ``b.to``.
        ymu: Unused; kept so existing call sites stay valid.
        ysd: Unused; kept so existing call sites stay valid.

    Returns:
        Sum of absolute errors divided by the number of target entries, or
        NaN when the loader yields no target entries (previously this case
        raised ``ZeroDivisionError``).
    """
    model.eval()
    abs_err = 0.0
    count = 0
    for b in loader:
        b = b.to(dev)
        o = model(b.x, b.edge_index, b.batch)
        abs_err += (o - b.y).abs().sum().item()
        count += b.y.numel()
    if count == 0:
        return float('nan')
    return abs_err / count
+
+
def main():
    """Train one configuration on LRGB Peptides-struct and write a JSON report.

    Parses the command line, loads the train/val/test splits, z-scores the
    targets with train-split statistics, trains with L1 loss, and every 15
    epochs (plus the final epoch) evaluates on validation. Train/val/test MAE
    are recorded at the evaluation with the lowest validation MAE. The report
    is written to ``OUT/real_<tag>.json``.

    The report is written even when no evaluation ever improved on the initial
    best (for example ``--epochs 0`` or a NaN validation MAE); in that case the
    metric keys are simply absent instead of the run crashing while formatting
    the summary line.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
    ap.add_argument('--layers', type=int, default=5)
    ap.add_argument('--hidden', type=int, default=128)
    ap.add_argument('--rni', type=int, default=0)
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--bs', type=int, default=128)
    ap.add_argument('--seed', type=int, default=0)
    args = ap.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(OUT, exist_ok=True)

    tr = LRGBDataset(root=ROOT, name='Peptides-struct', split='train')
    va = LRGBDataset(root=ROOT, name='Peptides-struct', split='val')
    te = LRGBDataset(root=ROOT, name='Peptides-struct', split='test')

    # Per-column embedding sizes use the max index over ALL splits so that no
    # val/test index falls outside a table; target statistics use train only.
    col_max = None
    Ytr = []
    for g in tr:
        m = g.x.max(0).values
        col_max = m if col_max is None else torch.maximum(col_max, m)
        Ytr.append(g.y.view(-1))
    for ds in (va, te):
        for g in ds:
            col_max = torch.maximum(col_max, g.x.max(0).values)
    col_sizes = (col_max + 2).tolist()
    Ytr = torch.stack(Ytr)
    ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8  # epsilon guards constant targets

    def norm(ds):
        """Return cloned graphs whose targets are z-scored with train statistics."""
        out = []
        for g in ds:
            g = g.clone()
            g.y = (g.y.view(1, -1) - ymu) / ysd
            out.append(g)
        return out

    # Normalise the train split once and share it between both train loaders.
    tr_norm = norm(tr)
    trl = DataLoader(tr_norm, batch_size=args.bs, shuffle=True, drop_last=True)
    val = DataLoader(norm(va), batch_size=256)
    tel = DataLoader(norm(te), batch_size=256)
    trl_eval = DataLoader(tr_norm, batch_size=256)

    # Output width follows the data instead of a hard-coded target count.
    model = Net(col_sizes, args.hidden, args.layers, out_dim=Ytr.size(1),
                conv=args.conv, rni=args.rni).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    lossf = nn.L1Loss()

    t0 = time.time()
    best_val = float('inf')
    best = {}
    for ep in range(args.epochs):
        model.train()
        for b in trl:
            b = b.to(dev)
            opt.zero_grad()
            loss = lossf(model(b.x, b.edge_index, b.batch), b.y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sched.step()
        if (ep + 1) % 15 == 0 or ep == args.epochs - 1:
            vm = mae(model, val, dev, ymu, ysd)
            if vm < best_val:
                best_val = vm
                best = {'ep': ep + 1, 'train_mae': mae(model, trl_eval, dev, ymu, ysd),
                        'val_mae': vm, 'test_mae': mae(model, tel, dev, ymu, ysd)}
            print(f"ep{ep+1} val_mae={vm:.4f}", flush=True)

    tag = f"{args.conv}_L{args.layers}_rni{args.rni}_s{args.seed}"
    rep = {'dataset': 'Peptides-struct', 'tag': tag, **vars(args),
           'sec': round(time.time() - t0, 1), 'dev': dev, **best}
    if best:
        print(f"[{tag}] train_mae={best.get('train_mae'):.4f} val_mae={best.get('val_mae'):.4f} "
              f"test_mae={best.get('test_mae'):.4f} @ep{best.get('ep')} ({rep['sec']}s)")
    else:
        # Formatting None with ':.4f' would raise and lose the report.
        print(f"[{tag}] no evaluation improved on the initial best; "
              f"metrics omitted ({rep['sec']}s)")
    fn = os.path.join(OUT, f"real_{tag}.json")
    with open(fn, 'w') as f:
        json.dump(rep, f, indent=2)
    print(" wrote", fn)
+
+
+if __name__ == "__main__":
+ main()