"""Train a backbone, collect failures, attribute them via the 1-WL instrument.

Node features are CONSTANT (all-ones) so the GIN starts from anonymous nodes -> its
expressivity ceiling is exactly the anonymous 1-WL partition the instrument computes
(wl_refine init = all-zero). GIN depth L == L WL rounds.

Regression targets are standardized (train stats) for stable training; all reported
MSEs are in original units. Train AND test metrics are reported so non-H2 error can be
split into optimization (can't even fit train) vs generalization (fits train, fails
test).

Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_diag.py --task csl --model gin
"""
import argparse
import json
import os
import time
from collections import Counter

import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from diag import wl, datasets as DS, models as M


def to_pyg(raw, task, ymu=0.0, ysd=1.0):
    """Convert raw graph dicts into a list of PyG ``Data`` objects.

    Args:
        raw: iterable of dicts; each is read for the keys ``'n'``, ``'edge_index'``
            and ``'y'``.
        task: ``'clf'`` stores the label as a length-1 long tensor; any other value
            stores the standardized target ``(y - ymu) / ysd`` as a 1x1 float tensor.
        ymu: target mean used for standardization (non-``'clf'`` tasks only).
        ysd: target standard deviation used for standardization (non-``'clf'`` only).

    Returns:
        A list with one ``Data`` per input dict, in input order. Node features are a
        single all-ones column, so every node starts out identical.
    """
    out = []
    for d in raw:
        x = torch.ones(d['n'], 1)
        ei = torch.tensor(d['edge_index'], dtype=torch.long)
        if task == 'clf':
            y = torch.tensor([d['y']], dtype=torch.long)
        else:
            y = torch.tensor([[(d['y'] - ymu) / ysd]], dtype=torch.float)
        out.append(Data(x=x, edge_index=ei, y=y, num_nodes=d['n']))
    return out


def split(n, frac, seed, y=None, stratify=False):
    """Split ``range(n)`` into train and test index lists.

    Args:
        n: number of items.
        frac: fraction of items assigned to the test set.
        seed: seed for the local ``numpy`` generator; the global RNG is not touched.
        y: per-item labels, only consulted when ``stratify`` is true.
        stratify: when true (and ``y`` is given), each label value contributes
            ``max(1, round(frac * class_size))`` shuffled members to the test set.
            Otherwise the first ``int(frac * n)`` indices of one global shuffle form
            the test set.

    Returns:
        ``(train, test)``: two disjoint lists of Python ints, each in ascending order.
    """
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    if stratify and y is not None:
        y = np.asarray(y)
        test = []
        for c in np.unique(y):
            ci = idx[y == c]  # boolean indexing copies, so idx itself stays ordered
            rng.shuffle(ci)
            test += ci[:max(1, int(round(frac * len(ci))))].tolist()
        # Build the membership set once instead of once per candidate index.
        held_out = set(test)
        test = sorted(held_out)
        train = [i for i in idx.tolist() if i not in held_out]
    else:
        rng.shuffle(idx)
        k = int(frac * n)
        test = sorted(idx[:k].tolist())
        train = sorted(idx[k:].tolist())
    return train, test


@torch.no_grad()
def predict(model, loader, task, dev):
    """Run ``model`` over ``loader`` in eval mode and return predictions as numpy.

    Args:
        model: callable invoked as ``model(x, edge_index, batch)``.
        loader: iterable of batches exposing ``.to(dev)``, ``.x``, ``.edge_index``
            and ``.batch``.
        task: ``'clf'`` returns the argmax over dim 1; any other value returns the
            flattened raw outputs.
        dev: device the batches are moved to.

    Returns:
        A 1-D numpy array with the per-batch results concatenated in loader order.
    """
    model.eval()
    outs = []
    for b in loader:
        b = b.to(dev)
        o = model(b.x, b.edge_index, b.batch)
        outs.append((o.argmax(1) if task == 'clf' else o.view(-1)).cpu())
    return torch.cat(outs).numpy()


def _build_parser():
    """Return the command-line parser for this script."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--task', choices=['csl', 'tri'], required=True)
    ap.add_argument('--model', choices=['gin', 'gcn'], default='gin')
    ap.add_argument('--layers', type=int, default=4)
    ap.add_argument('--hidden', type=int, default=64)
    ap.add_argument('--epochs', type=int, default=300)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--kind', default='er')
    ap.add_argument('--out', default='/home/yurenh2/rrog/runs')
    return ap


def _train(model, loader, opt, lossf, epochs, dev):
    """Run ``epochs`` passes over ``loader``, clipping the gradient norm to 1.0."""
    for _ in range(epochs):
        model.train()
        for b in loader:
            b = b.to(dev)
            opt.zero_grad()
            o = model(b.x, b.edge_index, b.batch)
            loss = lossf(o, b.y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()


def _classification_report(args, pred, yv, tr, te, ghist, conv):
    """Build the classification report fields and the console summary line.

    Returns ``(fields, message)``; ``fields`` is merged into the JSON report.
    """
    test_pred, test_y = pred[te], yv[te]
    acc = float((test_pred == test_y).mean())
    train_acc = float((pred[tr] == yv[tr]).mean())
    att = wl.attribute_classification(ghist, conv, args.layers, yv, tr, te)
    # Dataset-level indices of the misclassified test graphs.
    fails = [i for i, p, t in zip(te, test_pred, test_y) if p != t]
    # NOTE(review): assumes att['buckets'] is indexable by dataset-level graph
    # index -- confirm against wl.attribute_classification.
    fb = Counter(att['buckets'][i] for i in fails)
    fields = {
        'train_acc': round(train_acc, 4),
        'test_acc': round(acc, 4),
        'wl_ceiling_acc_converged': round(att['wl_optimal_acc_converged'], 4),
        'wl_ceiling_acc_Ldepth': round(att['wl_optimal_acc_Ldepth'], 4),
        'test_bucket_counts': att['counts'],
        'failure_bucket_counts': dict(fb),
        'n_failures': len(fails),
    }
    message = (
        f"[{args.task}/{args.model}] train_acc={train_acc:.3f} test_acc={acc:.3f} | "
        f"1-WL ceiling(conv)={att['wl_optimal_acc_converged']:.3f} "
        f"L-depth={att['wl_optimal_acc_Ldepth']:.3f} | failures={len(fails)} -> {dict(fb)}"
    )
    return fields, message


def _regression_report(args, pred, yv, tr, te, ghist, conv):
    """Build the regression report fields and the console summary line.

    ``pred`` and ``yv`` must already be in original (de-standardized) units.
    Returns ``(fields, message)``; ``fields`` is merged into the JSON report.
    """
    test_pred, test_y = pred[te], yv[te]
    mse = float(((test_pred - test_y) ** 2).mean())
    train_mse = float(((pred[tr] - yv[tr]) ** 2).mean())
    dec = wl.decompose_regression(ghist, conv, args.layers, yv, tr, te)
    h2 = dec['mse_floor_oracle_H2']
    fields = {
        'train_mse': round(train_mse, 4),
        'test_mse_gin': round(mse, 4),
        'mse_floor_oracle_H2': round(h2, 4),
        'mse_floor_converged_train': round(dec['mse_floor_converged_train'], 4),
        'mse_floor_Ldepth_train': round(dec['mse_floor_Ldepth_train'], 4),
        'var_target_test': round(dec['var_target_eval'], 4),
        'frac_test_unseen_color': round(dec['frac_test_unseen_color'], 4),
        'frac_test_singleton_color': round(dec['frac_test_singleton_color'], 4),
        # Test error above the H2 floor, clamped at zero.
        'learn_gap_test': round(max(0.0, mse - h2), 4),
    }
    message = (
        f"[{args.task}/{args.model}/{args.kind}] train_mse={train_mse:.3f} test_mse={mse:.3f} | "
        f"1-WL oracle floor(H2)={h2:.3f} | unseen={dec['frac_test_unseen_color']:.2f} "
        f"singleton={dec['frac_test_singleton_color']:.2f} | learn_gap={max(0.0, mse - h2):.3f} "
        f"var_y={dec['var_target_eval']:.3f}"
    )
    return fields, message


def main():
    """Train one model on one task, then write a JSON diagnostic report.

    Steps: parse arguments, seed torch and numpy, build the dataset and split it,
    train, predict on every graph, run the WL refinement over all graphs, and dump
    train/test metrics plus the WL-based attribution to
    ``<out>/diag_<tag>_<model>_L<layers>_s<seed>.json``.
    """
    args = _build_parser().parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(args.out, exist_ok=True)

    if args.task == 'csl':
        raw = DS.build_csl(n_per_class=15, seed=args.seed)
        task, out_dim = 'clf', 10
        y = [d['y'] for d in raw]
        tr, te = split(len(raw), 0.34, args.seed, y, stratify=True)
    else:
        raw = DS.build_triangle_count(n_graphs=800, n_nodes=18, kind=args.kind, deg=3,
                                      seed=args.seed)
        task, out_dim = 'reg', 1
        tr, te = split(len(raw), 0.3, args.seed)

    # Standardize regression targets with statistics from the training split only.
    ymu, ysd = 0.0, 1.0
    if task == 'reg':
        ytr = np.array([raw[i]['y'] for i in tr], dtype=np.float64)
        ymu, ysd = float(ytr.mean()), float(ytr.std() + 1e-8)

    pyg = to_pyg(raw, task, ymu, ysd)
    # drop_last=True discards the final partial training batch in every epoch.
    trl = DataLoader([pyg[i] for i in tr], batch_size=32, shuffle=True, drop_last=True)
    alll = DataLoader(pyg, batch_size=64)  # every graph, in dataset order

    Model = M.GIN if args.model == 'gin' else M.GCN
    model = Model(in_dim=1, hidden=args.hidden, layers=args.layers, out_dim=out_dim).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
    lossf = torch.nn.CrossEntropyLoss() if task == 'clf' else torch.nn.MSELoss()

    t0 = time.time()
    _train(model, trl, opt, lossf, args.epochs, dev)

    # Predictions cover the full dataset so they can be indexed by tr / te directly.
    pred = predict(model, alll, task, dev)
    if task == 'reg':
        pred = pred * ysd + ymu  # back to original units
    yv = np.array([d['y'] for d in raw],
                  dtype=(np.float64 if task == 'reg' else np.int64))

    adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in raw]
    _, ghist, conv = wl.wl_refine(adjs)

    # 'sec' spans training, prediction and the WL refinement above.
    rep = {'task': args.task, 'model': args.model, 'layers': args.layers,
           'seed': args.seed, 'kind': (args.kind if args.task == 'tri' else None),
           'n': len(raw), 'n_train': len(tr), 'n_test': len(te), 'conv_round': conv,
           'sec': round(time.time() - t0, 1), 'dev': dev}

    if task == 'clf':
        fields, message = _classification_report(args, pred, yv, tr, te, ghist, conv)
    else:
        fields, message = _regression_report(args, pred, yv, tr, te, ghist, conv)
    rep.update(fields)
    print(message)

    tag = f"{args.task}_{args.kind}" if args.task == 'tri' else args.task
    fn = os.path.join(args.out, f"diag_{tag}_{args.model}_L{args.layers}_s{args.seed}.json")
    with open(fn, 'w', encoding='utf-8') as f:
        json.dump(rep, f, indent=2)
    print(" wrote", fn)


if __name__ == "__main__":
    main()