diff options
Diffstat (limited to 'diag/esan_color.py')
| -rw-r--r-- | diag/esan_color.py | 145 |
1 files changed, 145 insertions, 0 deletions
diff --git a/diag/esan_color.py b/diag/esan_color.py new file mode 100644 index 0000000..7be050c --- /dev/null +++ b/diag/esan_color.py @@ -0,0 +1,145 @@ +"""H3: Recursive ESAN (subgraph GNN, DS-GNN node-marking bag) on graph 3-coloring. + +Per graph, pick S anchor nodes; each view = graph + a 1-hot mark on the anchor. Run the SHARED +recursive GIN on all views, average node-logits over views (DeepSets). Marking breaks node +symmetry -> >1-WL. Self-contained: train (EMA, best solve) + LE (lambda on a marked view, +bucket by aggregate solve) + PTRM (K noisy aggregate forwards); writes color_/le_/ptrm_ JSON +(conv='esan') for diag/aggregate.py. + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/esan_color.py --grad_mode full --seed 0 +""" +import argparse, json, os, time +import numpy as np +import torch +import torch.nn.functional as F +from torch_geometric.data import Data, Batch +from diag.train_color import make_split, featurize, RecGINColor, lyap1, OUT +try: + from sklearn.metrics import roc_auc_score +except Exception: + roc_auc_score = None + +S = 4 + + +def dense_A(edge_index, n): + A = torch.zeros(n, n) + if edge_index.shape[1]: + A[edge_index[0], edge_index[1]] = 1.0 + return torch.maximum(A, A.t()) + + +def views_of(g, anchors): + out = [] + for a in anchors: + mark = torch.zeros(g['n'], 1); mark[a] = 1.0 + out.append(Data(x=torch.cat([g['rfeat'], mark], dim=1), edge_index=g['edge_index'], num_nodes=g['n'])) + return out + + +def anchors_for(g, rng): + return rng.choice(g['n'], size=min(S, g['n']), replace=False) + + +def esan_logits(model, g, dev, anchors, noise=False): + b = Batch.from_data_list(views_of(g, anchors)).to(dev) + out = model(b.x, b.edge_index, b.batch, noise=noise)[-1] # [S*n, k] + return out.view(len(anchors), g['n'], -1).mean(0) # [n, k] + + +def conf_of(logits, A): + col = logits.argmax(-1) + return int(((col.unsqueeze(0) == col.unsqueeze(1)) & (A > 0)).sum().item() // 2) + + +@torch.no_grad() +def solve_rate(model, graphs, dev, rng, sample=300): + model.eval(); solved = 0 + for g in graphs[:sample]: + solved += int(conf_of(esan_logits(model, g, dev, anchors_for(g, rng)), g['A'].to(dev)) == 0) + return solved / len(graphs[:sample]) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') + ap.add_argument('--epochs', type=int, default=150); ap.add_argument('--M', type=int, default=16) + ap.add_argument('--seed', type=int, default=0); ap.add_argument('--sigma', type=float, default=0.2) + ap.add_argument('--K', type=int, default=16); ap.add_argument('--n_graphs', type=int, default=150) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + rng = np.random.default_rng(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + + tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16) + te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16) + for g in tr + te: + g['A'] = dense_A(g['edge_index'], g['n']) + in_dim = tr[0]['rfeat'].shape[1] + 1 + model = RecGINColor(in_dim, 128, 3, grad_mode=args.grad_mode, conv='gin').to(dev) + opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + + t0 = time.time(); best = -1; best_state = None + for ep in range(args.epochs): + model.train(); order = rng.permutation(len(tr)) + for s0 in range(0, len(order) - args.M, args.M): + sel = order[s0:s0 + args.M] + views, As = [], [] + for i in sel: + g = tr[i]; views += views_of(g, anchors_for(g, rng)); As.append(g['A']) + b = Batch.from_data_list(views).to(dev) + opt.zero_grad() + logits = model(b.x, b.edge_index, b.batch, noise=False)[-1].view(args.M, S, 50, 3).mean(1) + Ab = torch.stack(As).to(dev) + p = F.softmax(logits, -1) + loss = (torch.einsum('bik,bjk->bij', p, p) * Ab).sum() / (Ab.sum() + 1e-9) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + with torch.no_grad(): + for kk, v in model.state_dict().items(): + ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v) + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + model.load_state_dict(ema); sr = solve_rate(model, te, dev, rng) + if sr > best: + best = sr; best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema} + model.load_state_dict(bk) + print(f"ep{ep+1} solve={sr:.3f}", flush=True) + model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state}); model.eval() + + lams, fails = [], []; passk = lamsel = rand = 0; Lr, Sr = [], [] + nstep = model.n_sup * model.T + for gi, g in enumerate(te[:args.n_graphs]): + anc = anchors_for(g, rng) + mark = torch.zeros(g['n'], 1); mark[anc[0]] = 1.0 + xin = torch.cat([g['rfeat'], mark], dim=1).to(dev); ei = g['edge_index'].to(dev) + lam0 = lyap1(model, xin, ei, nstep, dev, seed=gi) + c0 = conf_of(esan_logits(model, g, dev, anc), g['A'].to(dev)) + lams.append(lam0); fails.append(int(c0 > 0)) + confs, rl = [], [] + for j in range(args.K): + confs.append(conf_of(esan_logits(model, g, dev, anc, noise=True), g['A'].to(dev))) + rl.append(lyap1(model, xin, ei, nstep, dev, seed=1000 * gi + j)) + confs, rl = np.array(confs), np.array(rl); sv = confs == 0 + passk += int(sv.any()); lamsel += int(sv[rl.argmin()]); rand += int(sv[0]) + Lr += rl.tolist(); Sr += sv.tolist() + lams, fails = np.array(lams), np.array(fails); s, f = lams[fails == 0], lams[fails == 1] + auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan')) + n = len(te[:args.n_graphs]) + print(f"[esan/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} " + f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)") + base = f"esan_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}" + com = {'conv': 'esan', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed} + json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w')) + json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())}, + open(os.path.join(OUT, f"le_color_{base}.json"), 'w')) + json.dump({**com, 'det': 1 - float(fails.mean()), + 'sigmas': {'0.2': {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n}}}, + open(os.path.join(OUT, f"ptrm_color_{base}.json"), 'w')) + print(" wrote", base) + + +if __name__ == "__main__": + main() |
