"""H2: Recursive PPGN (higher-order, 3-WL) backbone on graph 3-coloring.

State is a pair-tensor X[i,j,:] (not node features). The powerful block
multiplies two channel-wise n x n matrices (the 3-WL operation):
P = M1 @ M2; out = m3([X, P]). Recurse the pair-tensor (TRM-style deep
supervision), pool to nodes (mean over j), decode colors.

Self-contained: trains (EMA, best solve), then LE diagnostic + PTRM
noise/lambda-select; writes color_/le_color_/ptrm_color_ JSONs (conv='ppgn')
so diag/aggregate.py folds it into the big table.

Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ppgn_color.py --grad_mode full --seed 0
"""
import argparse
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diag.train_color import make_split, featurize, OUT

try:
    from sklearn.metrics import roc_auc_score
except Exception:
    # Deliberate best-effort: without scikit-learn the AUROC fields are
    # reported as NaN instead of aborting the whole run.
    roc_auc_score = None


def dense_A(edge_index, n: int) -> torch.Tensor:
    """Return a dense, symmetric n x n 0/1 adjacency matrix.

    Args:
        edge_index: edge list indexed as ``edge_index[0]`` (sources) and
            ``edge_index[1]`` (targets); ``edge_index.shape[1]`` is the edge
            count.  Presumably a ``[2, E]`` integer tensor -- confirm against
            ``featurize``.
        n: number of nodes.

    Returns:
        Float tensor of shape ``[n, n]`` with 1.0 wherever an edge exists in
        either direction.
    """
    A = torch.zeros(n, n)
    if edge_index.shape[1]:
        A[edge_index[0], edge_index[1]] = 1.0
    # Symmetrise so a single directed entry marks an undirected edge.
    return torch.maximum(A, A.t())


def mlp(di: int, dh: int, do: int) -> nn.Sequential:
    """Two-layer perceptron ``di -> dh -> do`` with a ReLU in between."""
    return nn.Sequential(nn.Linear(di, dh), nn.ReLU(), nn.Linear(dh, do))


def _count_conflicts(col: torch.Tensor, A: torch.Tensor) -> int:
    """Count edges whose two endpoints received the same color.

    Args:
        col: ``[n]`` tensor of per-node color indices.
        A: ``[n, n]`` adjacency matrix; entries > 0 are treated as edges.

    Returns:
        Number of same-colored (i, j) adjacency entries, halved with integer
        division because a symmetric ``A`` lists every edge twice.
    """
    same_color = col.unsqueeze(0) == col.unsqueeze(1)
    return (same_color & (A > 0)).sum().item() // 2


class RecPPGN(nn.Module):
    """Recursive PPGN operating on a pair tensor ``X[b, i, j, :]``.

    One shared "powerful block" is applied ``T`` times per supervision step
    and ``n_sup`` supervision steps are run per forward pass; the state is
    detached between supervision steps (deep supervision).
    """

    def __init__(self, in_dim, hidden=64, k=3, T=3, n_sup=3, grad_mode='full', sigma=0.0):
        """
        Args:
            in_dim: width of the per-node input features.
            hidden: channel width of the pair tensor.
            k: number of colors (output logits per node).
            T: block applications per supervision step.
            n_sup: number of supervision steps per forward pass.
            grad_mode: ``'1step'`` back-propagates only through the last
                block application of each supervision step; any other value
                back-propagates through all ``T`` applications.
            sigma: std of the Gaussian noise added after each block when
                ``noise=True`` is passed to :meth:`forward`.
        """
        super().__init__()
        self.node_in = nn.Linear(in_dim, hidden)
        self.adj_emb = nn.Embedding(2, hidden)
        self.m1 = mlp(hidden, hidden, hidden)
        self.m2 = mlp(hidden, hidden, hidden)
        self.m3 = mlp(2 * hidden, hidden, hidden)
        self.ln = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, k)
        self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma

    def init_pair(self, xin, A):
        """Build the input pair tensor from node features and adjacency.

        Args:
            xin: ``[B, n, in_dim]`` node features.
            A: ``[B, n, n]`` 0/1 adjacency (cast to long to index the
                two-entry embedding table).

        Returns:
            ``[B, n, n, hidden]`` tensor: ``h_i + h_j + emb(A_ij)``.
        """
        h = self.node_in(xin)
        X = h.unsqueeze(2) + h.unsqueeze(1)  # [B,n,n,hidden]
        return X + self.adj_emb(A.long())

    def block(self, X):
        """Apply one powerful block to ``X`` of shape ``[B, n, n, h]``."""
        M1, M2 = self.m1(X), self.m2(X)
        # Channel-wise matrix product over the shared index k, scaled by n.
        P = torch.einsum('bikc,bkjc->bijc', M1, M2) / X.shape[1]
        return self.ln(self.m3(torch.cat([X, P], dim=-1)))

    def _inner(self, X, X0, noise):
        """One recursion step: block on ``X + X0``, plus optional noise."""
        X = self.block(X + X0)
        if noise and self.sigma > 0:
            X = X + self.sigma * torch.randn_like(X)
        return X

    def recurse(self, X, X0, noise, one_step=False):
        """Run ``T`` recursion steps and return the resulting state.

        With ``one_step=True`` the first ``T - 1`` steps run without
        gradient tracking and only the final step is differentiable.
        """
        if one_step:
            with torch.no_grad():
                for _ in range(self.T - 1):
                    X = self._inner(X, X0, noise)
            X = X.detach()
            return self._inner(X, X0, noise)
        for _ in range(self.T):
            X = self._inner(X, X0, noise)
        return X

    def forward(self, xin, A, noise=False):
        """Return a list of ``n_sup`` color-logit tensors, each ``[B, n, k]``.

        The recurrent state starts at zero and is detached after every
        supervision step, so each output only back-propagates through its
        own step.
        """
        X0 = self.init_pair(xin, A)
        X = torch.zeros_like(X0)
        outs = []
        for _ in range(self.n_sup):
            X = self.recurse(X, X0, noise, one_step=(self.grad_mode == '1step'))
            outs.append(self.head(X.mean(2)))  # pool over j -> [B,n,k]
            X = X.detach()
        return outs


def conflict_loss(logits, A):
    """Expected fraction of adjacency entries whose endpoints share a color.

    Args:
        logits: ``[B, n, k]`` color logits.
        A: ``[B, n, n]`` adjacency weights.

    Returns:
        Scalar tensor: sum over pairs of ``A_ij * <p_i, p_j>`` divided by
        ``A.sum()`` (with a 1e-9 guard against edgeless batches).
    """
    p = F.softmax(logits, dim=-1)
    return (torch.einsum('bik,bjk->bij', p, p) * A).sum() / (A.sum() + 1e-9)


def batches(graphs, bs, shuffle, dev):
    """Yield ``(xin, A)`` mini-batches stacked from a list of graph dicts.

    Args:
        graphs: sequence of dicts with tensor entries ``'xin'`` and ``'A'``.
            They are combined with ``torch.stack``, so all graphs in a batch
            must have the same shapes.
        bs: batch size.
        shuffle: if true, permute the graphs with NumPy's global RNG and
            drop the final incomplete batch; otherwise keep the order and
            the remainder batch.
        dev: device the stacked tensors are moved to.
    """
    idx = np.arange(len(graphs))
    if shuffle:
        np.random.shuffle(idx)
    stop = len(idx) - (len(idx) % bs if shuffle else 0)
    for s in range(0, stop, bs):
        sel = idx[s:s + bs]
        X = torch.stack([graphs[i]['xin'] for i in sel]).to(dev)
        A = torch.stack([graphs[i]['A'] for i in sel]).to(dev)
        yield X, A


@torch.no_grad()
def solve_stats(model, graphs, dev, sample=300):
    """Return the fraction of the first ``sample`` graphs colored conflict-free.

    Puts ``model`` into eval mode (and leaves it there) and decodes colors by
    argmax of the last supervision step's logits.

    Raises:
        ZeroDivisionError: if ``graphs[:sample]`` is empty.
    """
    model.eval()
    solved = 0
    tot = 0
    for g in graphs[:sample]:
        A = g['A'].unsqueeze(0).to(dev)
        col = model(g['xin'].unsqueeze(0).to(dev), A)[-1][0].argmax(-1)
        solved += int(_count_conflicts(col, A[0]) == 0)
        tot += 1
    return solved / tot


def lyap1_and_solved(model, g, dev, seed, sigma=0.0):
    """Estimate the top finite-time Lyapunov exponent of one rollout.

    Re-runs the model's recursion (``n_sup * T`` block applications) on a
    single graph while pushing a random unit tangent vector through the
    Jacobian of each step, accumulating the log growth of its norm.

    Args:
        model: a :class:`RecPPGN`.
        g: graph dict with tensor entries ``'xin'`` and ``'A'``.
        dev: device to run on.
        seed: seed for the private generator that draws the tangent vector
            and the per-step noise.
        sigma: std of Gaussian noise added to the state after every step
            (0 gives the deterministic rollout).

    Returns:
        ``(lam, conf)``: mean log growth rate per step, and the number of
        conflicting edges in the argmax coloring of the final state.
    """
    gen = torch.Generator(device=dev).manual_seed(seed)
    xin = g['xin'].unsqueeze(0).to(dev)
    A = g['A'].unsqueeze(0).to(dev)
    # X0 is a constant of the map being differentiated; detaching it avoids
    # keeping an autograd graph through init_pair alive for the whole rollout.
    X0 = model.init_pair(xin, A).detach()
    X = torch.zeros_like(X0)
    v = torch.randn(X.shape, generator=gen, device=dev)
    v = v / (v.norm() + 1e-12)
    n_steps = model.n_sup * model.T
    lam = 0.0
    for _ in range(n_steps):
        Xd, Jv = torch.autograd.functional.jvp(lambda XX: model.block(XX + X0), X, v)
        nv = Jv.norm()
        lam += torch.log(nv + 1e-12).item()
        # Renormalise the tangent each step so it neither under- nor overflows.
        v = (Jv / (nv + 1e-12)).detach()
        X = Xd.detach()
        if sigma > 0:
            X = X + sigma * torch.randn(X.shape, generator=gen, device=dev)
    col = model.head(X.mean(2))[0].argmax(-1)
    conf = _count_conflicts(col, A[0])
    return lam / n_steps, conf


def _write_json(payload, path):
    """Serialise ``payload`` to ``path`` as indented JSON, closing the file."""
    with open(path, 'w', encoding='utf-8') as fh:
        json.dump(payload, fh, indent=2)


def main():
    """Train the recursive PPGN, run the LE / PTRM diagnostics, write JSONs."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--bs', type=int, default=16)
    ap.add_argument('--hidden', type=int, default=64)
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--sigma', type=float, default=0.2)
    ap.add_argument('--K', type=int, default=16)
    ap.add_argument('--n_graphs', type=int, default=150)
    args = ap.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
    te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
    for g in tr + te:
        g['A'] = dense_A(g['edge_index'], g['n'])
    in_dim = tr[0]['xin'].shape[1]

    model = RecPPGN(in_dim, args.hidden, 3, grad_mode=args.grad_mode).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}

    t0 = time.time()
    best = -1
    best_state = None
    for ep in range(args.epochs):
        model.train()
        for xb, A in batches(tr, args.bs, True, dev):
            opt.zero_grad()
            outs = model(xb, A, noise=False)
            loss = sum(conflict_loss(o, A) for o in outs) / len(outs)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            with torch.no_grad():
                for kk, v in model.state_dict().items():
                    if torch.is_floating_point(v):
                        ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001)
                    else:
                        # Non-float entries cannot be averaged; track the latest value.
                        ema[kk].copy_(v)
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            # Evaluate the EMA weights, then restore the raw training weights.
            # NOTE(review): the checkpoint is selected on the test split.
            bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
            model.load_state_dict(ema)
            sr = solve_stats(model, te, dev)
            if sr > best:
                best = sr
                best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
            model.load_state_dict(bk)
            print(f"ep{ep+1} solve={sr:.3f}", flush=True)

    if best_state is None:
        # No evaluation ever ran (e.g. --epochs 0); use the EMA weights rather
        # than crashing on a missing checkpoint below.
        model.load_state_dict(ema)
        best = solve_stats(model, te, dev)
        best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
    model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state})
    model.eval()

    # LE + PTRM (noise + lambda-select) on test
    lams, fails = [], []
    passk = lamsel = rand = 0
    Lr, Sr = [], []
    for gi, g in enumerate(te[:args.n_graphs]):
        # Deterministic rollout: exponent and failure flag for the LE AUROC.
        lam0, c0 = lyap1_and_solved(model, g, dev, seed=gi, sigma=0.0)
        lams.append(lam0)
        fails.append(int(c0 > 0))
        # K noisy rollouts: pass@K, lowest-lambda selection, and first-roll baseline.
        res = [lyap1_and_solved(model, g, dev, seed=1000 * gi + j, sigma=args.sigma)
               for j in range(args.K)]
        confs = np.array([c for _, c in res])
        rl = np.array([l for l, _ in res])
        solved = confs == 0
        passk += int(solved.any())
        lamsel += int(solved[rl.argmin()])
        rand += int(solved[0])
        Lr += rl.tolist()
        Sr += solved.tolist()

    lams, fails = np.array(lams), np.array(fails)
    s, f = lams[fails == 0], lams[fails == 1]
    # AUROC needs both classes present (and scikit-learn installed).
    auc = (roc_auc_score(fails, lams)
           if roc_auc_score and len(s) and len(f) else float('nan'))
    Lr, Sr = np.array(Lr), np.array(Sr)
    pauc = (roc_auc_score(Sr.astype(int), -Lr)
            if roc_auc_score and Sr.any() and (~Sr).any() else float('nan'))
    n = len(te[:args.n_graphs])
    print(f"[ppgn/{args.grad_mode}] solve={best:.3f} | LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} | "
          f"PTRM det={1 - fails.mean():.3f} passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")

    base = f"ppgn_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
    com = {'conv': 'ppgn', 'pe': 'none', 'grad_mode': args.grad_mode,
           'contract': False, 'seed': args.seed}
    os.makedirs(OUT, exist_ok=True)
    _write_json({**com, 'solve_rate': best},
                os.path.join(OUT, f"color_{base}.json"))
    _write_json({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean()),
                 'lam_solved': (float(s.mean()) if len(s) else None),
                 'lam_unsolved': (float(f.mean()) if len(f) else None)},
                os.path.join(OUT, f"le_color_{base}.json"))
    # Key the PTRM results by the sigma actually used; the default (0.2)
    # still yields the key '0.2'.
    _write_json({**com, 'det': 1 - float(fails.mean()),
                 'sigmas': {str(args.sigma): {'passk': passk / n, 'lamsel': lamsel / n,
                                              'random': rand / n,
                                              'perRoll': float(Sr.mean()),
                                              'auroc': float(pauc)}}},
                os.path.join(OUT, f"ptrm_color_{base}.json"))
    print(" wrote color_/le_color_/ptrm_color_", base)


if __name__ == "__main__":
    main()