summaryrefslogtreecommitdiff
path: root/diag/train_diag.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/train_diag.py')
-rw-r--r--diag/train_diag.py161
1 files changed, 161 insertions, 0 deletions
diff --git a/diag/train_diag.py b/diag/train_diag.py
new file mode 100644
index 0000000..a9c4ab2
--- /dev/null
+++ b/diag/train_diag.py
@@ -0,0 +1,161 @@
+"""Train a backbone, collect failures, attribute them via the 1-WL instrument.
+
+Node features are CONSTANT (all-ones) so the GIN starts from anonymous nodes -> its
+expressivity ceiling is exactly the anonymous 1-WL partition the instrument computes
+(wl_refine init = all-zero). GIN depth L == L WL rounds. Regression targets are
+standardized (train stats) for stable training; all reported MSEs are in original units.
+Train AND test metrics are reported so non-H2 error can be split into optimization
+(can't even fit train) vs generalization (fits train, fails test).
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_diag.py --task csl --model gin
+"""
+import argparse, json, os, time
+from collections import Counter
+import numpy as np
+import torch
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from diag import wl, datasets as DS, models as M
+
+
def to_pyg(raw, task, ymu=0.0, ysd=1.0):
    """Convert raw graph dicts into a list of PyG ``Data`` objects.

    Every node receives the same single feature (1.0), so graphs carry no
    node-identifying input. For ``task == 'clf'`` the label is stored as a
    length-1 long tensor; for any other task it is stored as a 1x1 float
    tensor holding ``(y - ymu) / ysd``.

    Args:
        raw: iterable of dicts providing the keys ``'n'``, ``'edge_index'``
            and ``'y'``.
        task: ``'clf'`` selects integer labels; anything else is treated as
            a regression target.
        ymu: value subtracted from regression targets.
        ysd: value the centered regression targets are divided by.

    Returns:
        A list of ``Data`` objects, in the same order as ``raw``.
    """
    def _label(value):
        # Classification keeps the raw label; regression standardizes it.
        if task == 'clf':
            return torch.tensor([value], dtype=torch.long)
        return torch.tensor([[(value - ymu) / ysd]], dtype=torch.float)

    return [
        Data(
            x=torch.ones(graph['n'], 1),
            # presumably a 2 x E list of endpoints -- confirm against the dataset builders
            edge_index=torch.tensor(graph['edge_index'], dtype=torch.long),
            y=_label(graph['y']),
            num_nodes=graph['n'],
        )
        for graph in raw
    ]
+
+
def split(n, frac, seed, y=None, stratify=False):
    """Split the indices ``0..n-1`` into train and test lists.

    Args:
        n: number of items to split.
        frac: fraction of items assigned to the test side.
        seed: seed for the NumPy generator that drives the shuffling.
        y: optional per-item labels; only used when ``stratify`` is true.
        stratify: when true and ``y`` is given, the test set is drawn class
            by class, taking ``round(frac * class_size)`` items from every
            class but never fewer than one. Otherwise the first
            ``int(frac * n)`` indices of one global shuffle form the test set.

    Returns:
        ``(train, test)``: two disjoint, ascending lists of Python ints that
        together cover ``range(n)``.
    """
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    if stratify and y is not None:
        y = np.asarray(y)
        test = []
        for c in np.unique(y):
            ci = idx[y == c]  # boolean-mask indexing copies, so idx stays ordered
            rng.shuffle(ci)
            test += ci[:max(1, int(round(frac * len(ci))))].tolist()
        # Build the membership set once instead of once per candidate index.
        test_set = set(test)
        test = sorted(test_set)
        train = [i for i in idx.tolist() if i not in test_set]
    else:
        rng.shuffle(idx)
        k = int(frac * n)
        test = sorted(idx[:k].tolist())
        train = sorted(idx[k:].tolist())
    return train, test
+
+
@torch.no_grad()
def predict(model, loader, task, dev):
    """Run ``model`` over every batch of ``loader`` and gather the predictions.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: callable invoked as ``model(x, edge_index, batch)``; must
            provide ``eval()``.
        loader: iterable of batches exposing ``to(dev)``, ``x``,
            ``edge_index`` and ``batch``.
        task: ``'clf'`` returns the argmax over dimension 1 of the output;
            any other value returns the output flattened to one dimension.
        dev: device the batches are moved to before the forward pass.

    Returns:
        A 1-D NumPy array with the predictions in loader order. An empty
        loader yields an empty array rather than an error.
    """
    model.eval()
    outs = []
    for b in loader:
        b = b.to(dev)
        o = model(b.x, b.edge_index, b.batch)
        outs.append((o.argmax(1) if task == 'clf' else o.view(-1)).cpu())
    if not outs:
        # torch.cat rejects an empty list; argmax yields int64, and the
        # regression dtype assumes float32 model outputs -- TODO confirm.
        return np.empty(0, dtype=np.int64 if task == 'clf' else np.float32)
    return torch.cat(outs).numpy()
+
+
def _parse_args():
    """Define the command-line options of this script and parse them."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', choices=['csl', 'tri'], required=True)
    parser.add_argument('--model', choices=['gin', 'gcn'], default='gin')
    parser.add_argument('--layers', type=int, default=4)
    parser.add_argument('--hidden', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--kind', default='er')
    parser.add_argument('--out', default='/home/yurenh2/rrog/runs')
    return parser.parse_args()


def _fit(model, loader, optimizer, criterion, epochs, dev):
    """Train ``model`` for ``epochs`` passes over ``loader``, clipping the gradient norm at 1.0."""
    for _ in range(epochs):
        model.train()
        for batch in loader:
            batch = batch.to(dev)
            optimizer.zero_grad()
            output = model(batch.x, batch.edge_index, batch.batch)
            loss = criterion(output, batch.y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()


def main():
    """Train one backbone on one task, attribute its errors, and write a JSON report.

    Steps: parse the CLI, seed torch and NumPy, build the dataset and its
    train/test split, train the selected model, predict on every graph, run
    the WL refinement from ``wl`` on all graphs, and store the resulting
    metrics as ``diag_<tag>_<model>_L<layers>_s<seed>.json`` under ``--out``.
    """
    args = _parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(args.out, exist_ok=True)

    # Dataset + split: CSL is a 10-way classification, 'tri' a scalar regression.
    if args.task == 'csl':
        raw = DS.build_csl(n_per_class=15, seed=args.seed)
        task, out_dim = 'clf', 10
        labels = [g['y'] for g in raw]
        tr, te = split(len(raw), 0.34, args.seed, labels, stratify=True)
    else:
        raw = DS.build_triangle_count(
            n_graphs=800, n_nodes=18, kind=args.kind, deg=3, seed=args.seed)
        task, out_dim = 'reg', 1
        tr, te = split(len(raw), 0.3, args.seed)

    # Regression targets are standardized with statistics of the train split only.
    ymu, ysd = 0.0, 1.0
    if task == 'reg':
        train_targets = np.array([raw[i]['y'] for i in tr], dtype=np.float64)
        ymu = float(train_targets.mean())
        ysd = float(train_targets.std() + 1e-8)

    graphs = to_pyg(raw, task, ymu, ysd)
    train_loader = DataLoader([graphs[i] for i in tr], batch_size=32, shuffle=True, drop_last=True)
    # Unshuffled loader over every graph, so predictions line up with ``raw``.
    full_loader = DataLoader(graphs, batch_size=64)

    backbone = M.GIN if args.model == 'gin' else M.GCN
    model = backbone(in_dim=1, hidden=args.hidden, layers=args.layers, out_dim=out_dim).to(dev)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
    if task == 'clf':
        criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = torch.nn.MSELoss()

    t0 = time.time()
    _fit(model, train_loader, optimizer, criterion, args.epochs, dev)

    pred = predict(model, full_loader, task, dev)
    if task == 'reg':
        # Undo the standardization so errors are expressed in original units.
        pred = pred * ysd + ymu
    target_dtype = np.float64 if task == 'reg' else np.int64
    yv = np.array([g['y'] for g in raw], dtype=target_dtype)
    adjs = [wl.edges_to_adj(g['n'], g['edge_index']) for g in raw]
    _, ghist, conv = wl.wl_refine(adjs)

    rep = {
        'task': args.task,
        'model': args.model,
        'layers': args.layers,
        'seed': args.seed,
        'kind': (args.kind if args.task == 'tri' else None),
        'n': len(raw),
        'n_train': len(tr),
        'n_test': len(te),
        'conv_round': conv,
        'sec': round(time.time() - t0, 1),
        'dev': dev,
    }

    test_pred, test_y = pred[te], yv[te]
    if task == 'clf':
        acc = float((test_pred == test_y).mean())
        train_acc = float((pred[tr] == yv[tr]).mean())
        att = wl.attribute_classification(ghist, conv, args.layers, yv, tr, te)
        # Dataset-level indices of the misclassified test graphs.
        fails = [gi for gi, p, t in zip(te, test_pred, test_y) if p != t]
        fail_buckets = Counter(att['buckets'][gi] for gi in fails)
        rep.update({
            'train_acc': round(train_acc, 4),
            'test_acc': round(acc, 4),
            'wl_ceiling_acc_converged': round(att['wl_optimal_acc_converged'], 4),
            'wl_ceiling_acc_Ldepth': round(att['wl_optimal_acc_Ldepth'], 4),
            'test_bucket_counts': att['counts'],
            'failure_bucket_counts': dict(fail_buckets),
            'n_failures': len(fails),
        })
        print(f"[{args.task}/{args.model}] train_acc={train_acc:.3f} test_acc={acc:.3f} | "
              f"1-WL ceiling(conv)={att['wl_optimal_acc_converged']:.3f} "
              f"L-depth={att['wl_optimal_acc_Ldepth']:.3f} | failures={len(fails)} -> {dict(fail_buckets)}")
    else:
        mse = float(((test_pred - test_y) ** 2).mean())
        train_mse = float(((pred[tr] - yv[tr]) ** 2).mean())
        dec = wl.decompose_regression(ghist, conv, args.layers, yv, tr, te)
        h2 = dec['mse_floor_oracle_H2']
        # Test error above the H2 floor, clamped at zero.
        gap = max(0.0, mse - h2)
        rep.update({
            'train_mse': round(train_mse, 4),
            'test_mse_gin': round(mse, 4),
            'mse_floor_oracle_H2': round(h2, 4),
            'mse_floor_converged_train': round(dec['mse_floor_converged_train'], 4),
            'mse_floor_Ldepth_train': round(dec['mse_floor_Ldepth_train'], 4),
            'var_target_test': round(dec['var_target_eval'], 4),
            'frac_test_unseen_color': round(dec['frac_test_unseen_color'], 4),
            'frac_test_singleton_color': round(dec['frac_test_singleton_color'], 4),
            'learn_gap_test': round(gap, 4),
        })
        print(f"[{args.task}/{args.model}/{args.kind}] train_mse={train_mse:.3f} test_mse={mse:.3f} | "
              f"1-WL oracle floor(H2)={h2:.3f} | unseen={dec['frac_test_unseen_color']:.2f} "
              f"singleton={dec['frac_test_singleton_color']:.2f} | learn_gap={gap:.3f} "
              f"var_y={dec['var_target_eval']:.3f}")

    if args.task == 'tri':
        tag = f"{args.task}_{args.kind}"
    else:
        tag = args.task
    fn = os.path.join(args.out, f"diag_{tag}_{args.model}_L{args.layers}_s{args.seed}.json")
    with open(fn, 'w') as f:
        json.dump(rep, f, indent=2)
    print(" wrote", fn)


if __name__ == "__main__":
    main()