diff options
Diffstat (limited to 'diag/ppgn_color.py')
| -rw-r--r-- | diag/ppgn_color.py | 209 |
1 files changed, 209 insertions, 0 deletions
# Recovered from: diff --git a/diag/ppgn_color.py b/diag/ppgn_color.py (new file mode 100644, index 0000000..25db01b)
"""H2: Recursive PPGN (higher-order, 3-WL) backbone on graph 3-coloring.

State is a pair-tensor X[i,j,:] (not node features). The powerful block multiplies two
channel-wise n x n matrices (the 3-WL operation): P = M1 @ M2; out = m3([X, P]). Recurse the
pair-tensor (TRM-style deep supervision), pool to nodes (mean over j), decode colors.
Self-contained: trains (EMA, best solve), then LE diagnostic + PTRM noise/lambda-select;
writes color_/le_/ptrm_ JSONs (conv='ppgn') so diag/aggregate.py folds it into the big table.

Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ppgn_color.py --grad_mode full --seed 0
"""
import argparse
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diag.train_color import make_split, featurize, OUT

try:
    from sklearn.metrics import roc_auc_score
except Exception:  # sklearn is optional; the AUROC fields fall back to NaN without it
    roc_auc_score = None


def dense_A(edge_index, n):
    """Build a dense, symmetric ``n x n`` 0/1 adjacency matrix from an edge list.

    ``edge_index[0]`` / ``edge_index[1]`` are used as row / column indices and
    ``edge_index.shape[1]`` as the number of edges. The matrix is symmetrised with an
    element-wise maximum against its transpose, so a one-directional edge list still
    produces an undirected adjacency.
    """
    A = torch.zeros(n, n)
    if edge_index.shape[1]:  # skip the indexed write for an edgeless graph
        A[edge_index[0], edge_index[1]] = 1.0
    return torch.maximum(A, A.t())


def mlp(di, dh, do):
    """Return a two-layer perceptron ``di -> dh -> do`` with a ReLU in between."""
    return nn.Sequential(nn.Linear(di, dh), nn.ReLU(), nn.Linear(dh, do))


class RecPPGN(nn.Module):
    """Recursive PPGN: one shared pair-tensor block applied repeatedly.

    Args:
        in_dim: size of the per-node input features.
        hidden: channel width of the pair tensor.
        k: number of classes produced by the node-level head (colors).
        T: number of block applications per supervision stage.
        n_sup: number of supervision stages; ``forward`` returns one logit tensor per stage.
        grad_mode: ``'1step'`` back-propagates through the last block application of each
            stage only; any other value back-propagates through all ``T`` applications.
        sigma: standard deviation of the Gaussian noise added after each block application
            when ``noise=True`` is passed to ``forward``.
    """

    def __init__(self, in_dim, hidden=64, k=3, T=3, n_sup=3, grad_mode='full', sigma=0.0):
        super().__init__()
        self.node_in = nn.Linear(in_dim, hidden)
        self.adj_emb = nn.Embedding(2, hidden)  # one embedding for "no edge", one for "edge"
        self.m1 = mlp(hidden, hidden, hidden)
        self.m2 = mlp(hidden, hidden, hidden)
        self.m3 = mlp(2 * hidden, hidden, hidden)
        self.ln = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, k)
        self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma

    def init_pair(self, xin, A):  # xin [B,n,in], A [B,n,n]
        """Lift node features to the initial pair tensor: h_i + h_j + adjacency embedding."""
        h = self.node_in(xin)
        X = h.unsqueeze(2) + h.unsqueeze(1)  # [B,n,n,hidden]
        return X + self.adj_emb(A.long())

    def block(self, X):  # X [B,n,n,h]
        """Apply one PPGN block: channel-wise matrix product, concat with X, MLP, LayerNorm."""
        M1, M2 = self.m1(X), self.m2(X)
        # Per-channel n x n matmul, divided by n (X.shape[1]) to keep the scale size-independent.
        P = torch.einsum('bikc,bkjc->bijc', M1, M2) / X.shape[1]
        return self.ln(self.m3(torch.cat([X, P], dim=-1)))

    def _inner(self, X, X0, noise):
        """One recursion step: re-inject X0, apply the block, optionally add Gaussian noise."""
        X = self.block(X + X0)
        if noise and self.sigma > 0:
            X = X + self.sigma * torch.randn_like(X)
        return X

    def recurse(self, X, X0, noise, one_step=False):
        """Run ``T`` recursion steps; with ``one_step`` only the last step carries gradient."""
        if one_step:
            with torch.no_grad():
                for _ in range(self.T - 1):
                    X = self._inner(X, X0, noise)
            X = X.detach()
            return self._inner(X, X0, noise)
        for _ in range(self.T):
            X = self._inner(X, X0, noise)
        return X

    def forward(self, xin, A, noise=False):
        """Return a list of ``n_sup`` node-logit tensors, one per supervision stage."""
        X0 = self.init_pair(xin, A)
        X = torch.zeros_like(X0)
        outs = []
        for s in range(self.n_sup):
            X = self.recurse(X, X0, noise, one_step=(self.grad_mode == '1step'))
            outs.append(self.head(X.mean(2)))  # pool over j -> [B,n,k]
            X = X.detach()  # stages are supervised independently; no gradient across stages
        return outs


def conflict_loss(logits, A):  # logits [B,n,k], A [B,n,n]
    """Soft conflict rate: adjacency-weighted probability that two nodes share a color.

    The sum over all (i, j) of ``A[i, j] * <p_i, p_j>`` is divided by ``A.sum()``
    (plus 1e-9 so an all-zero adjacency does not divide by zero).
    """
    p = F.softmax(logits, dim=-1)
    return (torch.einsum('bik,bjk->bij', p, p) * A).sum() / (A.sum() + 1e-9)


def batches(graphs, bs, shuffle, dev):
    """Yield ``(xin, A)`` mini-batches stacked from ``graphs`` and moved to ``dev``.

    With ``shuffle`` the order is permuted via ``np.random`` and the trailing partial
    batch is dropped; without it every graph is yielded, including a short last batch.
    Tensors are combined with ``torch.stack``, so all graphs in a batch must share shapes.
    """
    idx = np.arange(len(graphs))
    if shuffle:
        np.random.shuffle(idx)
    stop = len(idx) - (len(idx) % bs if shuffle else 0)
    for s in range(0, stop, bs):
        sel = idx[s:s + bs]
        if len(sel) == 0:
            continue
        X = torch.stack([graphs[i]['xin'] for i in sel]).to(dev)
        A = torch.stack([graphs[i]['A'] for i in sel]).to(dev)
        yield X, A


@torch.no_grad()
def solve_stats(model, graphs, dev, sample=300):
    """Return the fraction of the first ``sample`` graphs colored without any conflict.

    Puts ``model`` into eval mode (and leaves it there). Colors are the argmax of the
    last supervision stage's logits.
    """
    model.eval()
    solved = 0
    tot = 0
    for g in graphs[:sample]:
        A = g['A'].unsqueeze(0).to(dev)
        col = model(g['xin'].unsqueeze(0).to(dev), A)[-1][0].argmax(-1)
        # Each conflicting edge is seen as (i, j) and (j, i) in the symmetric A, hence // 2.
        conf = ((col.unsqueeze(0) == col.unsqueeze(1)) & (A[0] > 0)).sum().item() // 2
        solved += int(conf == 0)
        tot += 1
    return solved / tot


def lyap1_and_solved(model, g, dev, seed, sigma=0.0):
    """Estimate the leading Lyapunov exponent of the recursion on one graph.

    Iterates ``X -> block(X + X0)`` for ``n_sup * T`` steps while pushing a unit tangent
    vector through the Jacobian (via a JVP) and renormalising it after every step.
    With ``sigma > 0`` Gaussian noise drawn from a generator seeded with ``seed`` is added
    to the state after each step.

    Returns:
        ``(lam, conf)``: the mean log growth of the tangent norm per step, and the number
        of conflicting edges in the argmax coloring decoded from the final state.
    """
    gen = torch.Generator(device=dev).manual_seed(seed)
    xin = g['xin'].unsqueeze(0).to(dev)
    A = g['A'].unsqueeze(0).to(dev)
    X0 = model.init_pair(xin, A)
    X = torch.zeros_like(X0)
    v = torch.randn(X.shape, generator=gen, device=dev)
    v = v / (v.norm() + 1e-12)
    lam = 0.0
    n_steps = model.n_sup * model.T
    for _ in range(n_steps):
        Xd, Jv = torch.autograd.functional.jvp(lambda XX: model.block(XX + X0), X, v)
        nv = Jv.norm()
        lam += torch.log(nv + 1e-12).item()
        v = (Jv / (nv + 1e-12)).detach()  # renormalise so the tangent never over/underflows
        X = Xd.detach()
        if sigma > 0:
            X = X + sigma * torch.randn(X.shape, generator=gen, device=dev)
    col = model.head(X.mean(2))[0].argmax(-1)
    conf = ((col.unsqueeze(0) == col.unsqueeze(1)) & (A[0] > 0)).sum().item() // 2
    return lam / n_steps, conf


def _write_json(name, payload):
    """Write ``payload`` as indented JSON to ``OUT/name``, closing the file afterwards."""
    with open(os.path.join(OUT, name), 'w') as fh:
        json.dump(payload, fh, indent=2)


def main():
    """Train RecPPGN on 3-coloring, run the LE and PTRM diagnostics, and write three JSONs."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--bs', type=int, default=16)
    ap.add_argument('--hidden', type=int, default=64)
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--sigma', type=float, default=0.2)
    ap.add_argument('--K', type=int, default=16)
    ap.add_argument('--n_graphs', type=int, default=150)
    args = ap.parse_args()
    # Non-positive values would otherwise fail late and obscurely (no checkpoint selected,
    # zero batch stride, argmin over an empty array, division by zero in the summary).
    for flag in ('epochs', 'bs', 'K', 'n_graphs'):
        if getattr(args, flag) < 1:
            ap.error(f"--{flag} must be >= 1")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): the positional arguments presumably encode n=50, k=3, p=0.2 (matching the
    # output file name below) plus split size and seed offset -- confirm in diag.train_color.
    tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
    te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
    for g in tr + te:
        g['A'] = dense_A(g['edge_index'], g['n'])
    in_dim = tr[0]['xin'].shape[1]

    # The model keeps its default sigma=0.0; --sigma is only used by the PTRM rollouts below.
    model = RecPPGN(in_dim, args.hidden, 3, grad_mode=args.grad_mode).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
    t0 = time.time()
    best = -1
    best_state = None
    for ep in range(args.epochs):
        model.train()
        for X, A in batches(tr, args.bs, True, dev):
            opt.zero_grad()
            outs = model(X, A, noise=False)
            loss = sum(conflict_loss(o, A) for o in outs) / len(outs)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            with torch.no_grad():
                for kk, v in model.state_dict().items():
                    if torch.is_floating_point(v):
                        ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001)
                    else:
                        ema[kk].copy_(v)  # integer buffers cannot be averaged; mirror them
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            # Evaluate the EMA weights, then restore the live training weights.
            bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
            model.load_state_dict(ema)
            sr = solve_stats(model, te, dev)
            if sr > best:
                best = sr
                best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
            model.load_state_dict(bk)
            print(f"ep{ep+1} solve={sr:.3f}", flush=True)
    model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state})
    model.eval()

    # LE + PTRM (noise + lambda-select) on test
    lams, fails = [], []
    passk = lamsel = rand = 0
    Lr, Sr = [], []
    for gi, g in enumerate(te[:args.n_graphs]):
        lam0, c0 = lyap1_and_solved(model, g, dev, seed=gi, sigma=0.0)
        lams.append(lam0)
        fails.append(int(c0 > 0))
        res = [lyap1_and_solved(model, g, dev, seed=1000 * gi + j, sigma=args.sigma)
               for j in range(args.K)]
        confs = np.array([c for _, c in res])
        rl = np.array([l for l, _ in res])
        solved = confs == 0
        passk += int(solved.any())            # any of the K noisy rollouts solved
        lamsel += int(solved[rl.argmin()])    # the rollout with the smallest exponent solved
        rand += int(solved[0])                # the first rollout solved (random baseline)
        Lr += rl.tolist()
        Sr += solved.tolist()
    lams, fails = np.array(lams), np.array(fails)
    s, f = lams[fails == 0], lams[fails == 1]
    auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
    Lr, Sr = np.array(Lr), np.array(Sr)
    pauc = (roc_auc_score(Sr.astype(int), -Lr)
            if roc_auc_score and Sr.any() and (~Sr).any() else float('nan'))
    n = len(te[:args.n_graphs])
    print(f"[ppgn/{args.grad_mode}] solve={best:.3f} | LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} | "
          f"PTRM det={1 - fails.mean():.3f} passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")

    # grad_mode is restricted to 'full' / '1step' by argparse, so it can be interpolated directly.
    base = f"ppgn_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
    com = {'conv': 'ppgn', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}
    os.makedirs(OUT, exist_ok=True)
    _write_json(f"color_{base}.json", {**com, 'solve_rate': best})
    _write_json(f"le_color_{base}.json",
                {**com, 'auroc': float(auc), 'mean_lam': float(lams.mean()),
                 'lam_solved': (float(s.mean()) if len(s) else None),
                 'lam_unsolved': (float(f.mean()) if len(f) else None)})
    # Key the PTRM results by the sigma actually used; str(0.2) == '0.2' for the default.
    _write_json(f"ptrm_color_{base}.json",
                {**com, 'det': 1 - float(fails.mean()),
                 'sigmas': {str(args.sigma): {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n,
                                              'perRoll': float(Sr.mean()), 'auroc': float(pauc)}}})
    print(" wrote color_/le_color_/ptrm_color_", base)


if __name__ == "__main__":
    main()
