diff options
Diffstat (limited to 'diag/train_real.py')
| -rw-r--r-- | diag/train_real.py | 139 |
1 files changed, 139 insertions, 0 deletions
# Recovered from the diff for diag/train_real.py (new file, mode 100644, index 0000000..336d86f).
"""Training-based failure diagnosis on LRGB Peptides-struct (real, large, long-range).

The WL partition instrument is vacuous here (graphs ~all distinguishable), so we diagnose
by TRAINING and comparing:
  GIN(L)     : standard 1-WL backbone at depth L
  GIN(L)+RNI : random node features = noise = beyond-1-WL symmetry breaker
  GCN(L)     : sub-1-WL reference
Reads: deeper helps -> long-range/under-reaching; RNI helps -> a real >1-WL ceiling that
noise breaks; train<<test -> generalization; train high -> compute/optimization ceiling.
Targets z-scored per dim; metric = standardized MAE (lower better). 11 targets.

Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_real.py --conv gin --layers 5 --rni 0
"""
import argparse
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch_geometric.datasets import LRGBDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool

ROOT = '/home/yurenh2/rrog/data/lrgb'
OUT = '/home/yurenh2/rrog/runs'


class Net(nn.Module):
    """Message-passing network producing one prediction vector per graph.

    Pipeline: sum of per-column embeddings of the integer node features ->
    optional concatenation of ``rni`` Gaussian noise channels -> linear
    projection to ``hidden`` -> ``layers`` x (conv, BatchNorm1d, ReLU) ->
    mean pooling over each graph -> two-layer MLP head.

    Args:
        col_sizes: number of embedding rows for each node-feature column.
        hidden: width of embeddings, conv layers and the head's hidden layer.
        layers: number of conv + batch-norm layers.
        out_dim: number of regression targets per graph.
        conv: ``'gin'`` uses GINConv (two-layer MLP, trainable eps); any
            other value uses GCNConv.
        rni: number of random noise features appended per node (0 disables).
    """

    def __init__(self, col_sizes, hidden, layers, out_dim, conv='gin', rni=0):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(int(s), hidden) for s in col_sizes])
        self.rni = rni
        self.lin_in = nn.Linear(hidden + rni, hidden)
        self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
        for _ in range(layers):
            if conv == 'gin':
                mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
                self.convs.append(GINConv(mlp, train_eps=True))
            else:
                self.convs.append(GCNConv(hidden, hidden))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))

    def forward(self, x, edge_index, batch):
        """Return graph-level predictions for a batch of graphs.

        ``x`` is indexed as ``x[:, i]`` per embedding table, ``edge_index`` is
        passed to every conv layer, and ``batch`` is the node-to-graph
        assignment used by the mean pooling.
        """
        h = sum(emb(x[:, i]) for i, emb in enumerate(self.embs))
        if self.rni:
            # Noise is redrawn on every call; there is no train/eval switch,
            # so evaluation of an RNI model is stochastic as well.
            h = torch.cat([h, torch.randn(h.size(0), self.rni, device=h.device)], dim=1)
        h = self.lin_in(h)
        for conv, bn in zip(self.convs, self.bns):
            h = bn(conv(h, edge_index)).relu()
        return self.head(global_mean_pool(h, batch))


@torch.no_grad()
def mae(model, loader, dev, ymu, ysd):
    """Return the mean absolute error of ``model`` over all targets in ``loader``.

    The model is switched to eval mode. Predictions are compared directly with
    ``b.y`` as stored in the loader, so the result is in whatever units the
    loader's targets use (``main`` passes z-scored targets).

    ``ymu`` and ``ysd`` are accepted for call-site compatibility but are not
    used: no de-standardization happens here.
    """
    model.eval()
    se = n = 0.0
    for b in loader:
        b = b.to(dev)
        o = model(b.x, b.edge_index, b.batch)
        se += (o - b.y).abs().sum().item()
        n += b.y.numel()
    return se / n


def _fmt_metric(value):
    """Format a metric with 4 decimals, or ``'n/a'`` if it was never recorded."""
    return 'n/a' if value is None else f"{value:.4f}"


def main():
    """Train one configuration on Peptides-struct and write a JSON report.

    Parses CLI options, loads the three dataset splits, z-scores the targets
    with training-set statistics, trains with L1 loss / Adam / cosine LR, and
    evaluates every 15 epochs (and on the last one). The train/val/test MAE of
    the evaluation with the lowest validation MAE is written to
    ``OUT/real_<tag>.json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
    ap.add_argument('--layers', type=int, default=5)
    ap.add_argument('--hidden', type=int, default=128)
    ap.add_argument('--rni', type=int, default=0)
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--bs', type=int, default=128)
    ap.add_argument('--seed', type=int, default=0)
    args = ap.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(OUT, exist_ok=True)

    tr = LRGBDataset(root=ROOT, name='Peptides-struct', split='train')
    va = LRGBDataset(root=ROOT, name='Peptides-struct', split='val')
    te = LRGBDataset(root=ROOT, name='Peptides-struct', split='test')

    # per-column embedding sizes + target standardization (train stats)
    col_max = None
    Ytr = []
    for g in tr:
        m = g.x.max(0).values
        col_max = m if col_max is None else torch.maximum(col_max, m)
        Ytr.append(g.y.view(-1))
    # Include val/test maxima so every feature index seen at evaluation time
    # has an embedding row.
    for ds in (va, te):
        for g in ds:
            col_max = torch.maximum(col_max, g.x.max(0).values)
    col_sizes = (col_max + 2).tolist()
    Ytr = torch.stack(Ytr)
    ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8  # epsilon guards constant targets

    def norm(ds):
        """Return a list of cloned graphs whose ``y`` is z-scored with train stats."""
        out = []
        for g in ds:
            g = g.clone()
            g.y = (g.y.view(1, -1) - ymu) / ysd
            out.append(g)
        return out

    # Normalize the training split once and share it between both loaders.
    tr_norm = norm(tr)
    trl = DataLoader(tr_norm, batch_size=args.bs, shuffle=True, drop_last=True)
    val = DataLoader(norm(va), batch_size=256)
    tel = DataLoader(norm(te), batch_size=256)
    trl_eval = DataLoader(tr_norm, batch_size=256)

    # Output width follows the stacked training targets instead of a literal.
    model = Net(col_sizes, args.hidden, args.layers, out_dim=Ytr.size(1),
                conv=args.conv, rni=args.rni).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    lossf = nn.L1Loss()

    t0 = time.time()
    best_val = float('inf')
    best = {}
    for ep in range(args.epochs):
        model.train()
        for b in trl:
            b = b.to(dev)
            opt.zero_grad()
            loss = lossf(model(b.x, b.edge_index, b.batch), b.y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sched.step()
        if (ep + 1) % 15 == 0 or ep == args.epochs - 1:
            vm = mae(model, val, dev, ymu, ysd)
            # A NaN validation MAE never compares as smaller, so it is never
            # recorded as the best evaluation.
            if vm < best_val:
                best_val = vm
                best = {'ep': ep + 1, 'train_mae': mae(model, trl_eval, dev, ymu, ysd),
                        'val_mae': vm, 'test_mae': mae(model, tel, dev, ymu, ysd)}
            print(f"ep{ep+1} val_mae={vm:.4f}", flush=True)

    tag = f"{args.conv}_L{args.layers}_rni{args.rni}_s{args.seed}"
    rep = {'dataset': 'Peptides-struct', 'tag': tag, **vars(args),
           'sec': round(time.time() - t0, 1), 'dev': dev, **best}
    # `best` stays empty when no evaluation improved on the initial bound
    # (zero epochs, or every validation MAE was NaN). Formatting None with
    # ':.4f' would raise TypeError before the report is written, so go through
    # a None-tolerant formatter.
    print(f"[{tag}] train_mae={_fmt_metric(best.get('train_mae'))} "
          f"val_mae={_fmt_metric(best.get('val_mae'))} "
          f"test_mae={_fmt_metric(best.get('test_mae'))} @ep{best.get('ep')} ({rep['sec']}s)")
    fn = os.path.join(OUT, f"real_{tag}.json")
    with open(fn, 'w') as f:
        json.dump(rep, f, indent=2)
    print("  wrote", fn)


if __name__ == "__main__":
    main()
