diff options
Diffstat (limited to 'diag/train_diag.py')
| -rw-r--r-- | diag/train_diag.py | 161 |
1 files changed, 161 insertions, 0 deletions
# Reconstructed from the whitespace-collapsed unified diff that added
# diag/train_diag.py (new file mode 100644, index 0000000..a9c4ab2).
"""Train a backbone, collect failures, attribute them via the 1-WL instrument.

Node features are CONSTANT (all-ones) so the GIN starts from anonymous nodes -> its
expressivity ceiling is exactly the anonymous 1-WL partition the instrument computes
(wl_refine init = all-zero). GIN depth L == L WL rounds. Regression targets are
standardized (train stats) for stable training; all reported MSEs are in original units.
Train AND test metrics are reported so non-H2 error can be split into optimization
(can't even fit train) vs generalization (fits train, fails test).

Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_diag.py --task csl --model gin
"""
import argparse
import json
import os
import time
from collections import Counter

import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from diag import wl, datasets as DS, models as M


def to_pyg(raw, task, ymu=0.0, ysd=1.0):
    """Convert raw graph records into torch_geometric ``Data`` objects.

    Every node receives the same single feature (1.0), so the network cannot
    tell nodes apart from their inputs.

    Args:
        raw: iterable of dicts with keys ``'n'`` (node count), ``'edge_index'``
            and ``'y'`` (target).
        task: ``'clf'`` stores ``y`` as a long tensor of shape (1,); any other
            value stores the standardized target ``(y - ymu) / ysd`` as a float
            tensor of shape (1, 1).
        ymu: mean subtracted from regression targets.
        ysd: scale regression targets are divided by.

    Returns:
        A list of ``Data`` objects in the same order as ``raw``.
    """
    out = []
    for d in raw:
        x = torch.ones(d['n'], 1)
        ei = torch.tensor(d['edge_index'], dtype=torch.long)
        if task == 'clf':
            y = torch.tensor([d['y']], dtype=torch.long)
        else:
            y = torch.tensor([[(d['y'] - ymu) / ysd]], dtype=torch.float)
        out.append(Data(x=x, edge_index=ei, y=y, num_nodes=d['n']))
    return out


def split(n, frac, seed, y=None, stratify=False):
    """Split ``range(n)`` into train and test index lists.

    Args:
        n: number of items.
        frac: fraction of items assigned to the test set.
        seed: seed for the NumPy generator that shuffles the indices.
        y: per-item labels; only used when ``stratify`` is true.
        stratify: when true (and ``y`` is given), sample the test set per
            label so that every label contributes at least one test item.

    Returns:
        ``(train, test)``: two lists of ints, both in ascending order.
    """
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    if stratify and y is not None:
        y = np.asarray(y)
        test = []
        for c in np.unique(y):
            ci = idx[y == c]  # boolean indexing copies, so idx is not shuffled
            rng.shuffle(ci)
            test += ci[:max(1, int(round(frac * len(ci))))].tolist()
        # Build the membership set once instead of once per candidate index.
        test_set = set(test)
        test = sorted(test_set)
        train = [i for i in idx.tolist() if i not in test_set]
    else:
        rng.shuffle(idx)
        k = int(frac * n)
        test = sorted(idx[:k].tolist())
        train = sorted(idx[k:].tolist())
    return train, test


@torch.no_grad()
def predict(model, loader, task, dev):
    """Run ``model`` in eval mode over ``loader`` and collect its predictions.

    Args:
        model: callable invoked as ``model(x, edge_index, batch)``.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``to(dev)``.
        task: ``'clf'`` returns the argmax over dimension 1; any other value
            returns the flattened raw outputs.
        dev: device the batches are moved to.

    Returns:
        A NumPy array of predictions concatenated in loader order.
    """
    model.eval()
    outs = []
    for b in loader:
        b = b.to(dev)
        o = model(b.x, b.edge_index, b.batch)
        outs.append((o.argmax(1) if task == 'clf' else o.view(-1)).cpu())
    return torch.cat(outs).numpy()


def _parse_args():
    """Define the command-line options and parse them."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--task', choices=['csl', 'tri'], required=True)
    ap.add_argument('--model', choices=['gin', 'gcn'], default='gin')
    ap.add_argument('--layers', type=int, default=4)
    ap.add_argument('--hidden', type=int, default=64)
    ap.add_argument('--epochs', type=int, default=300)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--kind', default='er')
    ap.add_argument('--out', default='/home/yurenh2/rrog/runs')
    return ap.parse_args()


def _fit(model, loader, opt, lossf, epochs, dev):
    """Train ``model`` for ``epochs`` passes over ``loader`` (grad norm clipped to 1.0)."""
    for _ in range(epochs):
        model.train()
        for b in loader:
            b = b.to(dev)
            opt.zero_grad()
            o = model(b.x, b.edge_index, b.batch)
            loss = lossf(o, b.y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()


def main():
    """Train one model on one task, attribute its errors via 1-WL, write a JSON report.

    The report is printed as a one-line summary and written to
    ``<out>/diag_<tag>_<model>_L<layers>_s<seed>.json``.
    """
    args = _parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(args.out, exist_ok=True)

    if args.task == 'csl':
        raw = DS.build_csl(n_per_class=15, seed=args.seed)
        task, out_dim = 'clf', 10
        y = [d['y'] for d in raw]
        tr, te = split(len(raw), 0.34, args.seed, y, stratify=True)
    else:
        raw = DS.build_triangle_count(n_graphs=800, n_nodes=18, kind=args.kind,
                                      deg=3, seed=args.seed)
        task, out_dim = 'reg', 1
        tr, te = split(len(raw), 0.3, args.seed)

    # Standardize regression targets with statistics of the training split only.
    ymu, ysd = 0.0, 1.0
    if task == 'reg':
        ytr = np.array([raw[i]['y'] for i in tr], dtype=np.float64)
        ymu, ysd = float(ytr.mean()), float(ytr.std() + 1e-8)

    pyg = to_pyg(raw, task, ymu, ysd)
    # NOTE: drop_last=True discards the final partial training batch each epoch.
    trl = DataLoader([pyg[i] for i in tr], batch_size=32, shuffle=True, drop_last=True)
    alll = DataLoader(pyg, batch_size=64)

    Model = M.GIN if args.model == 'gin' else M.GCN
    model = Model(in_dim=1, hidden=args.hidden, layers=args.layers, out_dim=out_dim).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
    lossf = torch.nn.CrossEntropyLoss() if task == 'clf' else torch.nn.MSELoss()

    t0 = time.time()
    _fit(model, trl, opt, lossf, args.epochs, dev)

    # Predict on every graph once; train/test metrics are sliced out by index below.
    pred = predict(model, alll, task, dev)
    if task == 'reg':
        pred = pred * ysd + ymu  # back to original target units
    yv = np.array([d['y'] for d in raw], dtype=(np.float64 if task == 'reg' else np.int64))
    adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in raw]
    _, ghist, conv = wl.wl_refine(adjs)

    rep = {'task': args.task, 'model': args.model, 'layers': args.layers, 'seed': args.seed,
           'kind': (args.kind if args.task == 'tri' else None),
           'n': len(raw), 'n_train': len(tr), 'n_test': len(te), 'conv_round': conv,
           'sec': round(time.time() - t0, 1), 'dev': dev}

    test_pred, test_y = pred[te], yv[te]
    if task == 'clf':
        acc = float((test_pred == test_y).mean())
        train_acc = float((pred[tr] == yv[tr]).mean())
        att = wl.attribute_classification(ghist, conv, args.layers, yv, tr, te)
        # Dataset-level indices of the misclassified test graphs.
        fails = [i for i, p, t in zip(te, test_pred, test_y) if p != t]
        fb = Counter(att['buckets'][i] for i in fails)
        rep.update({'train_acc': round(train_acc, 4), 'test_acc': round(acc, 4),
                    'wl_ceiling_acc_converged': round(att['wl_optimal_acc_converged'], 4),
                    'wl_ceiling_acc_Ldepth': round(att['wl_optimal_acc_Ldepth'], 4),
                    'test_bucket_counts': att['counts'],
                    'failure_bucket_counts': dict(fb), 'n_failures': len(fails)})
        print(f"[{args.task}/{args.model}] train_acc={train_acc:.3f} test_acc={acc:.3f} | "
              f"1-WL ceiling(conv)={att['wl_optimal_acc_converged']:.3f} "
              f"L-depth={att['wl_optimal_acc_Ldepth']:.3f} | failures={len(fails)} -> {dict(fb)}")
    else:
        mse = float(((test_pred - test_y) ** 2).mean())
        train_mse = float(((pred[tr] - yv[tr]) ** 2).mean())
        dec = wl.decompose_regression(ghist, conv, args.layers, yv, tr, te)
        h2 = dec['mse_floor_oracle_H2']
        # Test error in excess of the H2 floor, clamped at zero.
        learn_gap = max(0.0, mse - h2)
        rep.update({'train_mse': round(train_mse, 4), 'test_mse_gin': round(mse, 4),
                    'mse_floor_oracle_H2': round(h2, 4),
                    'mse_floor_converged_train': round(dec['mse_floor_converged_train'], 4),
                    'mse_floor_Ldepth_train': round(dec['mse_floor_Ldepth_train'], 4),
                    'var_target_test': round(dec['var_target_eval'], 4),
                    'frac_test_unseen_color': round(dec['frac_test_unseen_color'], 4),
                    'frac_test_singleton_color': round(dec['frac_test_singleton_color'], 4),
                    'learn_gap_test': round(learn_gap, 4)})
        print(f"[{args.task}/{args.model}/{args.kind}] train_mse={train_mse:.3f} test_mse={mse:.3f} | "
              f"1-WL oracle floor(H2)={h2:.3f} | unseen={dec['frac_test_unseen_color']:.2f} "
              f"singleton={dec['frac_test_singleton_color']:.2f} | learn_gap={learn_gap:.3f} "
              f"var_y={dec['var_target_eval']:.3f}")

    tag = f"{args.task}_{args.kind}" if args.task == 'tri' else args.task
    fn = os.path.join(args.out, f"diag_{tag}_{args.model}_L{args.layers}_s{args.seed}.json")
    with open(fn, 'w', encoding='utf-8') as f:
        json.dump(rep, f, indent=2)
    print(" wrote", fn)


if __name__ == "__main__":
    main()
