summaryrefslogtreecommitdiff
path: root/diag/ppgn_color.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/ppgn_color.py')
-rw-r--r-- diag/ppgn_color.py 209
1 files changed, 209 insertions, 0 deletions
diff --git a/diag/ppgn_color.py b/diag/ppgn_color.py
new file mode 100644
index 0000000..25db01b
--- /dev/null
+++ b/diag/ppgn_color.py
@@ -0,0 +1,209 @@
+"""H2: Recursive PPGN (higher-order, 3-WL) backbone on graph 3-coloring.
+
+State is a pair-tensor X[i,j,:] (not node features). The powerful block multiplies two
+channel-wise n x n matrices (the 3-WL operation): P = M1 @ M2; out = m3([X, P]). Recurse the
+pair-tensor (TRM-style deep supervision), pool to nodes (mean over j), decode colors.
+Self-contained: trains (EMA, best solve), then LE diagnostic + PTRM noise/lambda-select;
+writes color_/le_/ptrm_ JSONs (conv='ppgn') so diag/aggregate.py folds it into the big table.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ppgn_color.py --grad_mode full --seed 0
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diag.train_color import make_split, featurize, OUT
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+
def dense_A(edge_index, n):
    """Build a dense, symmetric 0/1 adjacency matrix from an edge list.

    Args:
        edge_index: index tensor of shape ``[2, E]``; row 0 and row 1 hold the
            two endpoints of each edge.
        n: number of nodes (the result is ``n x n``).

    Returns:
        A float ``[n, n]`` tensor holding 1.0 wherever an edge is listed in
        either direction and 0.0 elsewhere.
    """
    adj = torch.zeros(n, n)
    num_edges = edge_index.shape[1]
    if num_edges:
        rows, cols = edge_index[0], edge_index[1]
        adj[rows, cols] = 1.0
    # Symmetrise: an edge listed in only one direction still counts both ways.
    return torch.maximum(adj, adj.t())
+
+
def mlp(di, dh, do):
    """Return a two-layer perceptron: Linear(di, dh) -> ReLU -> Linear(dh, do)."""
    layers = [
        nn.Linear(di, dh),
        nn.ReLU(),
        nn.Linear(dh, do),
    ]
    return nn.Sequential(*layers)
+
+
class RecPPGN(nn.Module):
    """Recurrent PPGN-style network whose state is a pair tensor ``X[b, i, j, :]``.

    Each block maps the pair tensor through two MLPs, multiplies the resulting
    channel-wise ``n x n`` matrices, and mixes the product back with the input
    through a third MLP followed by LayerNorm. The block is applied ``T`` times
    per supervision round; after every round the pair tensor is mean-pooled
    over ``j`` and decoded into per-node logits.

    Args:
        in_dim: size of the per-node input features.
        hidden: channel width of the pair tensor.
        k: number of output logits per node.
        T: number of block applications per supervision round.
        n_sup: number of supervision rounds (one output per round).
        grad_mode: ``'1step'`` back-propagates only through the last block
            application of each round; any other value back-propagates
            through all ``T`` applications.
        sigma: standard deviation of the Gaussian noise added after each block
            application when ``noise`` is requested.
    """

    def __init__(self, in_dim, hidden=64, k=3, T=3, n_sup=3, grad_mode='full', sigma=0.0):
        super().__init__()
        # Submodules are created in this fixed order so that seeded parameter
        # initialisation and state_dict key order stay stable.
        self.node_in = nn.Linear(in_dim, hidden)
        self.adj_emb = nn.Embedding(2, hidden)
        self.m1 = mlp(hidden, hidden, hidden)
        self.m2 = mlp(hidden, hidden, hidden)
        self.m3 = mlp(2 * hidden, hidden, hidden)
        self.ln = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, k)
        self.T = T
        self.n_sup = n_sup
        self.grad_mode = grad_mode
        self.sigma = sigma

    def init_pair(self, xin, A):
        """Build the initial pair tensor from node features and adjacency.

        ``xin`` is ``[B, n, in]`` and ``A`` is ``[B, n, n]``; the result is
        ``[B, n, n, hidden]``: the sum of both endpoint embeddings plus an
        embedding of the (0/1) adjacency entry.
        """
        node = self.node_in(xin)
        pair = node[:, :, None, :] + node[:, None, :, :]
        return pair + self.adj_emb(A.long())

    def block(self, X):
        """Apply one block to a ``[B, n, n, h]`` pair tensor."""
        left = self.m1(X)
        right = self.m2(X)
        # Channel-wise matrix product over the shared index, scaled by n.
        product = torch.einsum('bikc,bkjc->bijc', left, right) / X.shape[1]
        mixed = self.m3(torch.cat([X, product], dim=-1))
        return self.ln(mixed)

    def _inner(self, X, X0, noise):
        """One recurrence step: block on ``X + X0``, optionally followed by noise."""
        out = self.block(X + X0)
        if not noise or not (self.sigma > 0):
            return out
        return out + self.sigma * torch.randn_like(out)

    def recurse(self, X, X0, noise, one_step=False):
        """Run ``T`` recurrence steps; with ``one_step`` only the last keeps gradients."""
        if not one_step:
            for _ in range(self.T):
                X = self._inner(X, X0, noise)
            return X
        # Warm up without building a graph, then take a single differentiable step.
        with torch.no_grad():
            for _ in range(self.T - 1):
                X = self._inner(X, X0, noise)
        return self._inner(X.detach(), X0, noise)

    def forward(self, xin, A, noise=False):
        """Return a list of ``n_sup`` logit tensors, each ``[B, n, k]``."""
        X0 = self.init_pair(xin, A)
        state = torch.zeros_like(X0)
        truncated = self.grad_mode == '1step'
        outs = []
        for _ in range(self.n_sup):
            state = self.recurse(state, X0, noise, one_step=truncated)
            pooled = state.mean(2)  # pool over j -> [B, n, hidden]
            outs.append(self.head(pooled))
            # Gradients do not flow from one supervision round into the previous one.
            state = state.detach()
        return outs
+
+
def conflict_loss(logits, A):
    """Expected fraction of adjacency entries whose endpoints share a color.

    Args:
        logits: ``[B, n, k]`` per-node color logits.
        A: ``[B, n, n]`` adjacency weights.

    Returns:
        A scalar tensor: the adjacency-weighted sum of the probability that
        nodes ``i`` and ``j`` pick the same color, divided by ``A.sum()``
        (with a small epsilon so an edgeless batch gives 0 instead of NaN).
    """
    probs = torch.softmax(logits, dim=-1)
    same_color = torch.einsum('bik,bjk->bij', probs, probs)
    conflict_mass = (same_color * A).sum()
    return conflict_mass / (A.sum() + 1e-9)
+
+
def batches(graphs, bs, shuffle, dev):
    """Yield stacked ``(xin, A)`` mini-batches from a list of graph dicts.

    Args:
        graphs: sequence of dicts with tensor entries ``'xin'`` and ``'A'``.
            The tensors are combined with ``torch.stack``, so all graphs in a
            batch must have identically shaped entries.
        bs: batch size.
        shuffle: if true, visit graphs in a random order (drawn from NumPy's
            global RNG) and drop the trailing partial batch; otherwise keep
            the given order and emit the final short batch as well.
        dev: device the stacked tensors are moved to.

    Yields:
        ``(xin, A)`` tuples of stacked tensors on ``dev``.
    """
    order = np.arange(len(graphs))
    if shuffle:
        np.random.shuffle(order)
    # Shuffled passes drop the remainder; ordered passes cover every graph.
    stop = len(order) - (len(order) % bs) if shuffle else len(order)
    # Every start index is < stop <= len(order), so each slice is non-empty and
    # no empty-batch guard is needed.
    for start in range(0, stop, bs):
        chosen = order[start:start + bs]
        xin = torch.stack([graphs[i]['xin'] for i in chosen]).to(dev)
        adj = torch.stack([graphs[i]['A'] for i in chosen]).to(dev)
        yield xin, adj
+
+
@torch.no_grad()
def solve_stats(model, graphs, dev, sample=300):
    """Fraction of the first ``sample`` graphs the model colors without conflicts.

    The model is switched to eval mode and left there. For each graph the last
    supervision output is decoded with an argmax per node; a graph counts as
    solved when no adjacency entry joins two nodes of the same color.

    Args:
        model: callable taking ``(xin, A)`` batches and returning a list of
            ``[B, n, k]`` logit tensors.
        graphs: sequence of dicts with tensor entries ``'xin'`` and ``'A'``.
        dev: device the per-graph tensors are moved to.
        sample: maximum number of graphs (taken from the front) to evaluate.

    Returns:
        The solved fraction as a float, or NaN when there are no graphs to
        evaluate (previously a ``ZeroDivisionError``).
    """
    model.eval()
    subset = graphs[:sample]
    if len(subset) == 0:
        return float('nan')
    solved = 0
    for g in subset:
        A = g['A'].unsqueeze(0).to(dev)
        logits = model(g['xin'].unsqueeze(0).to(dev), A)[-1][0]
        col = logits.argmax(-1)
        same = col.unsqueeze(0) == col.unsqueeze(1)
        # Each undirected edge appears twice in the symmetric adjacency.
        conflicts = (same & (A[0] > 0)).sum().item() // 2
        solved += int(conflicts == 0)
    return solved / len(subset)
+
+
def lyap1_and_solved(model, g, dev, seed, sigma=0.0):
    """Estimate the leading Lyapunov exponent of the recurrence and count conflicts.

    The pair-tensor recurrence ``X -> model.block(X + X0)`` is iterated
    ``model.n_sup * model.T`` times from a zero state. A random unit tangent
    vector is pushed through the Jacobian at every step and renormalised; the
    mean log growth of its norm is the exponent estimate. When ``sigma > 0``
    Gaussian noise is added to the state after every step.

    Args:
        model: object exposing ``init_pair``, ``block``, ``head``, ``n_sup``
            and ``T``.
        g: graph dict with tensor entries ``'xin'`` and ``'A'``.
        dev: device used for tensors and for the random generator.
        seed: seed of the private generator that draws the tangent vector and
            the noise.
        sigma: standard deviation of the per-step state noise (0 disables it).

    Returns:
        ``(lam, conf)``: the mean log growth rate per step, and the number of
        conflicting edges in the argmax coloring decoded from the final state
        (0 means the graph is solved).
    """
    rng = torch.Generator(device=dev).manual_seed(seed)
    xin = g['xin'].unsqueeze(0).to(dev)
    A = g['A'].unsqueeze(0).to(dev)
    X0 = model.init_pair(xin, A)
    state = torch.zeros_like(X0)

    def step(XX):
        return model.block(XX + X0)

    tangent = torch.randn(state.shape, generator=rng, device=dev)
    tangent = tangent / (tangent.norm() + 1e-12)

    n_steps = model.n_sup * model.T
    log_growth = 0.0
    for _ in range(n_steps):
        nxt, Jv = torch.autograd.functional.jvp(step, state, tangent)
        growth = Jv.norm()
        log_growth += torch.log(growth + 1e-12).item()
        # Renormalise so the tangent tracks direction only.
        tangent = (Jv / (growth + 1e-12)).detach()
        state = nxt.detach()
        if sigma > 0:
            state = state + sigma * torch.randn(state.shape, generator=rng, device=dev)

    col = model.head(state.mean(2))[0].argmax(-1)
    same = col.unsqueeze(0) == col.unsqueeze(1)
    conf = (same & (A[0] > 0)).sum().item() // 2
    return log_growth / n_steps, conf
+
+
def _write_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON and close the file."""
    with open(path, 'w') as fh:
        json.dump(payload, fh, indent=2)


def main():
    """Train a RecPPGN on graph coloring, run the LE/PTRM diagnostics, write JSONs.

    Steps:
      1. Parse CLI flags, seed torch and NumPy, build the train/test splits and
         attach a dense adjacency ``'A'`` to every graph.
      2. Train with Adam + cosine LR schedule on ``conflict_loss`` averaged over
         the supervision outputs, keeping an EMA (decay 0.999) of the weights.
         Every 20 epochs and at the last epoch the EMA weights are scored with
         ``solve_stats`` and the best-scoring EMA snapshot is kept.
         NOTE(review): that selection is scored on the test split ``te``.
      3. With the best snapshot loaded, compute per-graph Lyapunov estimates
         without noise (LE) and ``K`` noisy roll-outs per graph (PTRM):
         pass@K, lambda-select (pick the roll-out with the smallest estimate)
         and first-roll-out success.
      4. Write ``color_``, ``le_color_`` and ``ptrm_color_`` JSON files to ``OUT``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--bs', type=int, default=16)
    ap.add_argument('--hidden', type=int, default=64)
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--sigma', type=float, default=0.2)
    ap.add_argument('--K', type=int, default=16)
    ap.add_argument('--n_graphs', type=int, default=150)
    args = ap.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Positional arguments follow make_split/featurize; presumably n=50 nodes,
    # k=3 colors, p=0.2 (matching the output file name) -- confirm against
    # diag/train_color.py.
    tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
    te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
    for g in tr + te:
        g['A'] = dense_A(g['edge_index'], g['n'])
    in_dim = tr[0]['xin'].shape[1]

    model = RecPPGN(in_dim, args.hidden, 3, grad_mode=args.grad_mode).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
    t0 = time.time()
    best = -1
    best_state = None
    for ep in range(args.epochs):
        model.train()
        for X, A in batches(tr, args.bs, True, dev):
            opt.zero_grad()
            outs = model(X, A, noise=False)
            loss = sum(conflict_loss(o, A) for o in outs) / len(outs)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            with torch.no_grad():
                for kk, v in model.state_dict().items():
                    if torch.is_floating_point(v):
                        ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001)
                    else:
                        # Non-float entries cannot be averaged; track the latest value.
                        ema[kk].copy_(v)
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            # Score the EMA weights, then restore the live training weights.
            bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
            model.load_state_dict(ema)
            sr = solve_stats(model, te, dev)
            if sr > best:
                best = sr
                best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
            model.load_state_dict(bk)
            print(f"ep{ep+1} solve={sr:.3f}", flush=True)
    if best_state is None:
        # No snapshot was ever selected (e.g. --epochs 0): fall back to the
        # current EMA weights instead of crashing on a missing state.
        model.load_state_dict(ema)
        best = solve_stats(model, te, dev)
        best_state = {kk: v.detach().cpu().clone() for kk, v in ema.items()}
    model.load_state_dict({kk: v.to(dev) for kk, v in best_state.items()})
    model.eval()

    # LE + PTRM (noise + lambda-select) on test
    lams, fails = [], []
    passk = lamsel = rand = 0
    Lr, Sr = [], []
    eval_graphs = te[:args.n_graphs]
    for gi, g in enumerate(eval_graphs):
        lam0, c0 = lyap1_and_solved(model, g, dev, seed=gi, sigma=0.0)
        lams.append(lam0)
        fails.append(int(c0 > 0))
        res = [lyap1_and_solved(model, g, dev, seed=1000 * gi + j, sigma=args.sigma)
               for j in range(args.K)]
        confs = np.array([c for _, c in res])
        rl = np.array([l for l, _ in res])
        solved = confs == 0
        passk += int(solved.any())
        lamsel += int(solved[rl.argmin()])
        rand += int(solved[0])
        Lr += rl.tolist()
        Sr += solved.tolist()
    lams, fails = np.array(lams), np.array(fails)
    s, f = lams[fails == 0], lams[fails == 1]
    # AUROC is only defined when both classes are present (and sklearn is available).
    auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
    Lr, Sr = np.array(Lr), np.array(Sr)
    pauc = (roc_auc_score(Sr.astype(int), -Lr) if roc_auc_score and Sr.any() and (~Sr).any() else float('nan'))
    n = len(eval_graphs)
    print(f"[ppgn/{args.grad_mode}] solve={best:.3f} | LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} | "
          f"PTRM det={1 - fails.mean():.3f} passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")

    # grad_mode is restricted to 'full' / '1step', so it can go straight into the name.
    base = f"ppgn_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
    com = {'conv': 'ppgn', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}
    _write_json(os.path.join(OUT, f"color_{base}.json"), {**com, 'solve_rate': best})
    _write_json(os.path.join(OUT, f"le_color_{base}.json"),
                {**com, 'auroc': float(auc), 'mean_lam': float(lams.mean()),
                 'lam_solved': (float(s.mean()) if len(s) else None),
                 'lam_unsolved': (float(f.mean()) if len(f) else None)})
    # Key the PTRM results by the sigma actually used; str(0.2) == '0.2', so the
    # default run produces the same key as before.
    _write_json(os.path.join(OUT, f"ptrm_color_{base}.json"),
                {**com, 'det': 1 - float(fails.mean()),
                 'sigmas': {str(args.sigma): {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n,
                                              'perRoll': float(Sr.mean()), 'auroc': float(pauc)}}})
    print(" wrote color_/le_color_/ptrm_color_", base)
+
+
+if __name__ == "__main__":
+ main()