summaryrefslogtreecommitdiff
path: root/diag/cin_color.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/cin_color.py')
-rw-r--r--diag/cin_color.py144
1 files changed, 144 insertions, 0 deletions
diff --git a/diag/cin_color.py b/diag/cin_color.py
new file mode 100644
index 0000000..324215f
--- /dev/null
+++ b/diag/cin_color.py
@@ -0,0 +1,144 @@
+"""H4: Recursive CIN-lite (topological / cell-complex) backbone on graph 3-coloring.
+
+Augment each graph with ring 2-cells from the cycle basis: add one hypernode per basis cycle,
+connected to its member nodes. Messages flow node->ring->node (topological message passing over
+rings). Run the shared recursive GIN on the augmented (nodes + ring-cells) graph; decode colors
+on the ORIGINAL nodes only. Self-contained: train (EMA, best solve) + LE + PTRM; writes
+color_/le_/ptrm_ JSON (conv='cin') for diag/aggregate.py.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/cin_color.py --grad_mode full --seed 0
+"""
+import argparse, json, os, time
+import numpy as np
+import networkx as nx
+import torch
+import torch.nn.functional as F
+from torch_geometric.data import Data, Batch
+from diag.train_color import make_split, featurize, RecGINColor, lyap1, OUT
+try:
+ from sklearn.metrics import roc_auc_score
+except Exception:
+ roc_auc_score = None
+
+N = 50
+
+
+def dense_A(edge_index, n):
+ A = torch.zeros(n, n)
+ if edge_index.shape[1]:
+ A[edge_index[0], edge_index[1]] = 1.0
+ return torch.maximum(A, A.t())
+
+
+def augment(g):
+ n = g['n']; ei = g['edge_index']
+ G = nx.Graph(); G.add_nodes_from(range(n))
+ if ei.shape[1]:
+ G.add_edges_from(ei.t().tolist())
+ rings = nx.cycle_basis(G)
+ R = len(rings)
+ src, dst = ei[0].tolist(), ei[1].tolist()
+ for r, cyc in enumerate(rings):
+ rn = n + r
+ for node in cyc:
+ src += [node, rn]; dst += [rn, node]
+ aug_ei = torch.tensor([src, dst], dtype=torch.long) if src else torch.zeros((2, 0), dtype=torch.long)
+ d = g['rfeat'].shape[1]
+ nf = torch.cat([g['rfeat'], torch.tensor([[1.0, 0.0]]).repeat(n, 1)], dim=1)
+ rf = torch.cat([torch.zeros(R, d), torch.tensor([[0.0, 1.0]]).repeat(R, 1)], dim=1) if R else torch.zeros(0, d + 2)
+ return Data(x=torch.cat([nf, rf], dim=0), edge_index=aug_ei, num_nodes=n + R)
+
+
+def conf_of(logits, A):
+ col = logits.argmax(-1)
+ return int(((col.unsqueeze(0) == col.unsqueeze(1)) & (A > 0)).sum().item() // 2)
+
+
+@torch.no_grad()
+def solve_rate(model, graphs, dev, sample=300):
+ model.eval(); solved = 0
+ for g in graphs[:sample]:
+ d = g['aug'].to(dev)
+ lg = model(d.x, d.edge_index)[-1][:N]
+ solved += int(conf_of(lg, g['A'].to(dev)) == 0)
+ return solved / len(graphs[:sample])
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+ ap.add_argument('--epochs', type=int, default=150); ap.add_argument('--M', type=int, default=16)
+ ap.add_argument('--seed', type=int, default=0); ap.add_argument('--sigma', type=float, default=0.2)
+ ap.add_argument('--K', type=int, default=16); ap.add_argument('--n_graphs', type=int, default=150)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ rng = np.random.default_rng(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ tr = featurize(make_split('train', 50, 3, 0.2, 8, 2000, 0), 'none', 16)
+ te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), 'none', 16)
+ for g in tr + te:
+ g['A'] = dense_A(g['edge_index'], g['n']); g['aug'] = augment(g)
+ in_dim = tr[0]['rfeat'].shape[1] + 2
+ model = RecGINColor(in_dim, 128, 3, grad_mode=args.grad_mode, conv='gin').to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+
+ t0 = time.time(); best = -1; best_state = None
+ for ep in range(args.epochs):
+ model.train(); order = rng.permutation(len(tr))
+ for s0 in range(0, len(order) - args.M, args.M):
+ sel = order[s0:s0 + args.M]
+ b = Batch.from_data_list([tr[i]['aug'] for i in sel]).to(dev)
+ opt.zero_grad()
+ logits = model(b.x, b.edge_index, b.batch)[-1]
+ orig = torch.stack([logits[b.ptr[gi]:b.ptr[gi] + N] for gi in range(len(sel))]) # [M,N,k]
+ A = torch.stack([tr[i]['A'] for i in sel]).to(dev)
+ p = F.softmax(orig, -1)
+ loss = (torch.einsum('bik,bjk->bij', p, p) * A).sum() / (A.sum() + 1e-9)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ with torch.no_grad():
+ for kk, v in model.state_dict().items():
+ ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) if torch.is_floating_point(v) else ema[kk].copy_(v)
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ bk = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
+ model.load_state_dict(ema); sr = solve_rate(model, te, dev)
+ if sr > best:
+ best = sr; best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
+ model.load_state_dict(bk)
+ print(f"ep{ep+1} solve={sr:.3f}", flush=True)
+ model.load_state_dict({kk: best_state[kk].to(dev) for kk in best_state}); model.eval()
+
+ nstep = model.n_sup * model.T
+ lams, fails = [], []; passk = lamsel = rand = 0; Lr, Sr = [], []
+ for gi, g in enumerate(te[:args.n_graphs]):
+ d = g['aug'].to(dev); A = g['A'].to(dev)
+ lam0 = lyap1(model, d.x, d.edge_index, nstep, dev, seed=gi)
+ c0 = conf_of(model(d.x, d.edge_index)[-1][:N], A)
+ lams.append(lam0); fails.append(int(c0 > 0))
+ confs, rl = [], []
+ for j in range(args.K):
+ confs.append(conf_of(model(d.x, d.edge_index, noise=True)[-1][:N], A))
+ rl.append(lyap1(model, d.x, d.edge_index, nstep, dev, seed=1000 * gi + j))
+ confs, rl = np.array(confs), np.array(rl); sv = confs == 0
+ passk += int(sv.any()); lamsel += int(sv[rl.argmin()]); rand += int(sv[0])
+ Lr += rl.tolist(); Sr += sv.tolist()
+ lams, fails = np.array(lams), np.array(fails); s, f = lams[fails == 0], lams[fails == 1]
+ auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan'))
+ n = len(te[:args.n_graphs])
+ print(f"[cin/{args.grad_mode}] solve={best:.3f} LE AUROC={auc:.3f} mean_lam={lams.mean():+.3f} "
+ f"passK={passk/n:.3f} lamsel={lamsel/n:.3f} ({time.time()-t0:.0f}s)")
+ base = f"cin_{args.grad_mode}_none_n50_k3_p0.2_T3_ns3_s{args.seed}"
+ com = {'conv': 'cin', 'pe': 'none', 'grad_mode': args.grad_mode, 'contract': False, 'seed': args.seed}
+ json.dump({**com, 'solve_rate': best}, open(os.path.join(OUT, f"color_{base}.json"), 'w'))
+ json.dump({**com, 'auroc': float(auc), 'mean_lam': float(lams.mean())}, open(os.path.join(OUT, f"le_color_{base}.json"), 'w'))
+ json.dump({**com, 'det': 1 - float(fails.mean()),
+ 'sigmas': {'0.2': {'passk': passk / n, 'lamsel': lamsel / n, 'random': rand / n}}},
+ open(os.path.join(OUT, f"ptrm_color_{base}.json"), 'w'))
+ print(" wrote", base)
+
+
+if __name__ == "__main__":
+ main()