summary | refs | log | tree | commit | diff
path: root/diag/train_color.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/train_color.py')
-rw-r--r-- diag/train_color.py 347
1 files changed, 347 insertions, 0 deletions
diff --git a/diag/train_color.py b/diag/train_color.py
new file mode 100644
index 0000000..36f8496
--- /dev/null
+++ b/diag/train_color.py
@@ -0,0 +1,347 @@
+"""Recursive (TRM-ish) GNN graph 3-coloring with swappable BACKBONE for the RRoG roadmap.
+
+--conv gin|gcn|sage|gat|pna|gps : message-passing operator (gps = GraphGPS local MPNN + global
+ attention = TRM's original transformer backbone, on the graph).
+--pe none|rwse|gsn|sub|lappe|all : input structural features (random sym-break [+ encoding]);
+ 'all' concatenates rwse + gsn + sub (lappe is not part of 'all').
+--contract : reverse-flossing lambda-penalty during training (force contraction; roadmap #4).
+--grad_mode full|1step : TRM full recursion vs HRM 1-step gradient.
+Self-supervised conflict/Potts loss; success = zero-conflict; EMA; deep supervision.
+Modes: --mode train (saves ckpt + JSON) / --mode le (Lyapunov-exponent diagnostic on a saved --ckpt).
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import GINConv, GCNConv, SAGEConv, GATConv, GPSConv, PNAConv
+from torch_geometric.utils import degree as _pyg_degree
+
# GraphGPS attention goes through scaled_dot_product_attention. Only the MATH kernel is
# left enabled so that torch.autograd.functional.jvp (LE diagnostic / PTRM rollouts) has a
# double-backward-able implementation; the flash and mem-efficient kernels are switched off.
_SDP_SWITCHES = (
    ('enable_flash_sdp', False),
    ('enable_mem_efficient_sdp', False),
    ('enable_math_sdp', True),
)
for _f, _enabled in _SDP_SWITCHES:
    try:
        getattr(torch.backends.cuda, _f)(_enabled)
    except Exception:
        # Best effort: a failing or missing toggle is deliberately ignored.
        pass

# Directory that receives checkpoints and JSON reports.
OUT = '/home/yurenh2/rrog/runs'
# Directory holding the cached, generated graph splits.
CACHE = '/home/yurenh2/rrog/data/color_cache'
+
+
def gen(n, k, p, r, seed):
    """Sample one random k-partite graph together with random node features.

    Every node is assigned a part in ``[0, k)``; each unordered pair of nodes
    lying in different parts becomes an edge with probability ``p``. Edges are
    stored in both directions. All randomness comes from a NumPy generator
    seeded with ``seed``.

    Returns a dict with keys ``'n'`` (node count), ``'edge_index'`` (long
    tensor with two rows, one column per directed edge) and ``'rfeat'``
    (float tensor of shape ``(n, r)`` drawn from a standard normal).
    """
    rng = np.random.default_rng(seed)
    parts = rng.integers(0, k, n)
    heads, tails = [], []
    for u in range(n):
        for v in range(u + 1, n):
            if parts[u] == parts[v]:
                # Same part: never connected, and no random draw is consumed.
                continue
            if rng.random() < p:
                heads.extend((u, v))
                tails.extend((v, u))
    if heads:
        edge_index = torch.tensor([heads, tails], dtype=torch.long)
    else:
        edge_index = torch.zeros((2, 0), dtype=torch.long)
    rfeat = torch.tensor(rng.standard_normal((n, r)), dtype=torch.float)
    return {'n': n, 'edge_index': edge_index, 'rfeat': rfeat}
+
+
def make_split(split, n, k, p, r, count, seed0):
    """Return ``count`` generated graphs for ``split``, cached on disk under CACHE.

    The cache file name encodes ``split``, ``n``, ``k``, ``p`` and ``r`` only.
    Because ``count`` is not part of the name, a cached file holding a
    different number of graphs is treated as stale: the split is regenerated
    and the file is overwritten. ``seed0`` is not recorded in the cache, so a
    cache written with another ``seed0`` cannot be detected here.

    Graph ``i`` of a freshly generated split is ``gen(n, k, p, r, seed0 + i)``.
    """
    os.makedirs(CACHE, exist_ok=True)
    fp = os.path.join(CACHE, f"{split}_n{n}_k{k}_p{p}_r{r}.pt")
    if os.path.exists(fp):
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load cache files that were written by this script.
        cached = torch.load(fp, weights_only=False)
        if len(cached) == count:
            return cached
    data = [gen(n, k, p, r, seed0 + i) for i in range(count)]
    torch.save(data, fp)
    return data
+
+
+def _adj(edge_index, n):
+ A = np.zeros((n, n), dtype=np.float64)
+ ei = edge_index.numpy()
+ if ei.shape[1]:
+ A[ei[0], ei[1]] = 1.0
+ return np.maximum(A, A.T)
+
+
def rwse(edge_index, n, K):
    """Random-walk structural encoding: return probabilities for 1..K steps.

    Builds the row-normalized transition matrix ``P`` from the dense adjacency
    (rows of zero-degree nodes are divided by 1 and therefore stay all-zero)
    and returns a float32 tensor of shape ``(n, K)`` whose column ``j`` is the
    diagonal of ``P ** (j + 1)``.
    """
    adj = _adj(edge_index, n)
    degrees = adj.sum(axis=1)
    safe_degrees = np.where(degrees > 0, degrees, 1.0)
    transition = adj / safe_degrees[:, None]
    encoding = np.zeros((n, K), dtype=np.float32)
    power = np.eye(n)
    for step in range(K):
        power = power @ transition
        encoding[:, step] = power.diagonal()
    return torch.from_numpy(encoding)
+
+
def gsn_feats(edge_index, n):
    """Per-node substructure counts: log1p(triangles) and log1p(wedges).

    Triangles are ``diag(A^3) / 2`` and wedges are ``deg * (deg - 1) / 2`` on
    the dense symmetric adjacency ``A``. Returns a float tensor of shape
    ``(n, 2)``.
    """
    adj = _adj(edge_index, n)
    degrees = adj.sum(axis=1)
    triangles = np.diagonal(adj @ adj @ adj) / 2.0
    wedges = degrees * (degrees - 1) / 2.0
    columns = (np.log1p(triangles), np.log1p(wedges))
    return torch.tensor(np.column_stack(columns), dtype=torch.float)
+
+
def sub_feats(edge_index, n):
    """Per-node features: log1p(degree), log1p(#2-hop nodes), clustering coeff.

    The 2-hop count is the number of other nodes reachable by a length-2 walk
    that are not direct neighbours. The clustering coefficient is
    ``triangles / (deg * (deg - 1) / 2)`` and is defined as 0 for nodes with
    degree <= 1. Returns a float tensor of shape ``(n, 3)``.
    """
    A = _adj(edge_index, n)
    deg = A.sum(1)
    A2 = A @ A
    # Mask of pairs joined by a length-2 walk but not by an edge, excluding self-pairs.
    M = (A2 > 0).astype(np.float64) * (A == 0).astype(np.float64)
    np.fill_diagonal(M, 0.0)
    d2 = M.sum(1)
    tri = (A @ A2).diagonal() / 2.0
    pairs = deg * (deg - 1) / 2.0
    # Masked division: nodes with degree <= 1 keep the pre-filled 0 and the
    # division by zero is never evaluated (no RuntimeWarning, no nan/inf).
    clus = np.divide(tri, pairs, out=np.zeros_like(tri), where=deg > 1)
    return torch.tensor(np.stack([np.log1p(deg), np.log1p(d2), clus], axis=1), dtype=torch.float)
+
+
def lappe_feats(edge_index, n, kpe=8):
    """Laplacian positional encoding from the symmetric normalized Laplacian.

    Computes ``L = I - D^{-1/2} A D^{-1/2}`` (isolated nodes get a zero scaling
    factor), takes the eigenvectors returned by ``np.linalg.eigh`` and keeps
    columns ``1 .. kpe`` (the first column is skipped). If fewer than ``kpe``
    columns are available, the result is zero-padded on the right. Eigenvector
    signs are used as returned by ``eigh``. Returns a float tensor of shape
    ``(n, kpe)``.
    """
    A = _adj(edge_index, n)
    deg = A.sum(1)
    # Masked 1/sqrt(deg): isolated nodes keep 0 without evaluating 1/0.
    di = np.divide(1.0, np.sqrt(deg), out=np.zeros_like(deg), where=deg > 0)
    L = np.eye(n) - di[:, None] * A * di[None, :]
    _, V = np.linalg.eigh(L)
    pe = V[:, 1:kpe + 1]
    if pe.shape[1] < kpe:
        pe = np.pad(pe, ((0, 0), (0, kpe - pe.shape[1])))
    return torch.tensor(pe, dtype=torch.float)
+
+
def featurize(graphs, pe, rwse_k):
    """Attach the model input ``'xin'`` to every graph dict, in place.

    ``'xin'`` is the graph's random features ``'rfeat'``, optionally followed
    (along dim 1) by the structural encoding selected by ``pe``:
    ``'rwse'``, ``'gsn'``, ``'sub'``, ``'lappe'`` or ``'all'`` (rwse + gsn +
    sub). Any other value of ``pe`` leaves ``'xin'`` equal to ``'rfeat'``.
    Returns the same ``graphs`` object.
    """
    # Builders are wrapped in lambdas so the encoders are only looked up when used.
    builders = {
        'rwse': lambda ei, n: rwse(ei, n, rwse_k),
        'gsn': lambda ei, n: gsn_feats(ei, n),
        'sub': lambda ei, n: sub_feats(ei, n),
        'lappe': lambda ei, n: lappe_feats(ei, n),
        'all': lambda ei, n: torch.cat(
            [rwse(ei, n, rwse_k), gsn_feats(ei, n), sub_feats(ei, n)], dim=1),
    }
    build = builders.get(pe)
    for graph in graphs:
        if build is None:
            graph['xin'] = graph['rfeat']
            continue
        extra = build(graph['edge_index'], graph['n'])
        graph['xin'] = torch.cat([graph['rfeat'], extra], dim=1)
    return graphs
+
+
def deg_hist(graphs):
    """Histogram of node degrees over all graphs (as passed to PNAConv's ``deg``).

    Per-graph degrees are counted from the second row of ``edge_index``.
    Returns a long tensor ``h`` of length ``max_degree + 1`` where ``h[d]`` is
    the number of nodes with degree ``d``; an empty ``graphs`` gives ``[0]``.
    """
    per_graph = [
        _pyg_degree(g['edge_index'][1], g['n'], dtype=torch.long) for g in graphs
    ]
    max_deg = max((int(d.max()) for d in per_graph if d.numel()), default=0)
    hist = torch.zeros(max_deg + 1, dtype=torch.long)
    for degrees in per_graph:
        hist += torch.bincount(degrees, minlength=max_deg + 1)
    return hist
+
+
def make_conv(conv, hidden, deg=None):
    """Create one ``hidden -> hidden`` message-passing layer of the given kind.

    ``conv`` is one of ``'gin'``, ``'gcn'``, ``'sage'``, ``'gat'``, ``'pna'``
    or ``'gps'`` (a GPSConv wrapping a GIN layer). ``deg`` is only forwarded
    to PNAConv.

    Raises:
        ValueError: if ``conv`` is not one of the supported names.
    """
    def gin_layer():
        mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
        return GINConv(mlp, train_eps=True)

    factories = {
        'gin': gin_layer,
        'gcn': lambda: GCNConv(hidden, hidden),
        'sage': lambda: SAGEConv(hidden, hidden),
        'gat': lambda: GATConv(hidden, hidden, heads=4, concat=False, add_self_loops=True),
        'pna': lambda: PNAConv(hidden, hidden, aggregators=['mean', 'min', 'max', 'std'],
                               scalers=['identity', 'amplification', 'attenuation'],
                               deg=deg, towers=1),
        'gps': lambda: GPSConv(hidden, gin_layer(), heads=4),
    }
    factory = factories.get(conv)
    if factory is None:
        raise ValueError(conv)
    return factory()
+
+
class RecGINColor(nn.Module):
    """Recursive GNN for graph colouring with a swappable message-passing layer.

    A latent state ``z`` (initialised to zeros) is refined by repeatedly
    applying the same block to ``z + h0``, where ``h0`` is the linear embedding
    of the input features. ``forward`` performs ``n_sup`` supervision rounds
    of ``T`` block applications each, emits colour logits after every round
    and detaches the state between rounds.
    """

    def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full', sigma=0.0, conv='gin', deg=None):
        super().__init__()
        self.conv_type = conv
        self.lin_in = nn.Linear(in_dim, hidden)
        conv_layers = [make_conv(conv, hidden, deg) for _ in range(inner)]
        norm_layers = [nn.BatchNorm1d(hidden) for _ in range(inner)]
        self.convs = nn.ModuleList(conv_layers)
        self.bns = nn.ModuleList(norm_layers)
        self.head = nn.Linear(hidden, k)
        self.T = T
        self.n_sup = n_sup
        self.grad_mode = grad_mode
        self.sigma = sigma

    def block(self, z, ei, batch=None):
        """Apply every (conv, batch-norm, ReLU) stage once and return the new state."""
        needs_batch = self.conv_type == 'gps'
        if needs_batch and batch is None:
            # No batch vector given: treat all nodes as belonging to graph 0.
            batch = z.new_zeros(z.size(0), dtype=torch.long)
        for layer, norm in zip(self.convs, self.bns):
            if needs_batch:
                z = layer(z, ei, batch)
            else:
                z = layer(z, ei)
            z = torch.relu(norm(z))
        return z

    def _inner(self, z, h0, ei, noise, batch):
        """One recursion step: block on ``z + h0``, plus Gaussian noise if enabled."""
        out = self.block(z + h0, ei, batch)
        if noise and self.sigma > 0:
            out = out + self.sigma * torch.randn_like(out)
        return out

    def recurse(self, z, h0, ei, noise, batch, one_step=False):
        """Run ``T`` recursion steps; with ``one_step`` only the last one is differentiated."""
        if not one_step:
            for _ in range(self.T):
                z = self._inner(z, h0, ei, noise, batch)
            return z
        # 1-step gradient: the first T - 1 steps run without building a graph.
        with torch.no_grad():
            for _ in range(self.T - 1):
                z = self._inner(z, h0, ei, noise, batch)
        return self._inner(z.detach(), h0, ei, noise, batch)

    def forward(self, xin, ei, batch=None, noise=False):
        """Return a list of ``n_sup`` logit tensors, one per supervision round."""
        h0 = self.lin_in(xin)
        one_step = self.grad_mode == '1step'
        state = torch.zeros_like(h0)
        outs = []
        for _ in range(self.n_sup):
            state = self.recurse(state, h0, ei, noise, batch, one_step=one_step)
            outs.append(self.head(state))
            # Deep supervision: gradients do not flow across rounds.
            state = state.detach()
        return outs
+
+
def conflict_loss(logits, ei):
    """Soft conflict loss: mean over edges of P(both endpoints share a colour).

    ``logits`` holds one row of colour logits per node; ``ei`` has two rows of
    node indices, one column per (directed) edge. For an empty edge set the
    loss is 0 (no edge can be in conflict) instead of the NaN that a mean over
    zero elements would give; the zero stays connected to ``logits`` so that
    ``backward()`` still works.
    """
    if ei.numel() == 0:
        return logits.sum() * 0.0
    p = F.softmax(logits, dim=-1)
    return (p[ei[0]] * p[ei[1]]).sum(-1).mean()
+
+
@torch.no_grad()
def solve_stats(model, recs, dev, sample=None):
    """Evaluate hard colourings: return ``(solve_rate, mean_conflicts)``.

    The model is put in eval mode. For each record (the first ``sample``
    records when ``sample`` is truthy, otherwise all) the argmax of the last
    output is taken as the colouring; conflicts are the directed edges whose
    endpoints share a colour, halved because edges are stored in both
    directions. A graph counts as solved when it has zero conflicts.
    """
    model.eval()
    subset = recs[:sample] if sample else recs
    conflicts = []
    for rec in subset:
        edges = rec['edge_index'].to(dev)
        colours = model(rec['xin'].to(dev), edges)[-1].argmax(-1)
        clashes = colours[edges[0]] == colours[edges[1]]
        conflicts.append(clashes.sum().item() // 2)
    total = len(conflicts)
    n_solved = sum(1 for c in conflicts if c == 0)
    return n_solved / total, sum(conflicts) / total
+
+
def lyap1(model, xin, ei, n_steps, dev, seed=0):
    """Estimate the leading Lyapunov exponent of the map ``z -> block(z + h0)``.

    Starting from ``z = 0`` and a seeded random unit tangent vector, each step
    pushes the tangent through the Jacobian (via ``jvp``), records the log of
    its norm and renormalises it. Returns the average log growth over
    ``n_steps`` steps. ``h0`` is the detached input embedding; the block is
    called without a batch vector.
    """
    rng = torch.Generator(device=dev).manual_seed(seed)
    h0 = model.lin_in(xin).detach()
    state = torch.zeros_like(h0)
    tangent = torch.randn(h0.shape, generator=rng, device=dev)
    tangent = tangent / (tangent.norm() + 1e-12)

    def step_fn(zz):
        return model.block(zz + h0, ei)

    log_growth = 0.0
    for _ in range(n_steps):
        next_state, pushed = torch.autograd.functional.jvp(step_fn, state, tangent)
        growth = pushed.norm()
        log_growth += torch.log(growth + 1e-12).item()
        state = next_state.detach()
        tangent = (pushed / (growth + 1e-12)).detach()
    return log_growth / n_steps
+
+
def run_le(model, recs, dev, n_steps, n_graphs=300):
    """Relate per-graph Lyapunov-exponent estimates to solve failures.

    For each of the first ``n_graphs`` records, marks the graph as failed when
    the argmax colouring of the last output has any conflicting edge, and
    estimates its exponent with ``lyap1`` (seeded by the record index). Prints
    a one-line summary and returns a dict with the failure rate, AUROC of the
    exponent as a failure predictor (NaN when scikit-learn is unavailable or
    only one class is present), the mean-exponent gap and per-group means.
    """
    try:
        from sklearn.metrics import roc_auc_score
    except Exception:
        roc_auc_score = None
    model.eval()
    lam_list, fail_list = [], []
    for idx, rec in enumerate(recs[:n_graphs]):
        ei = rec['edge_index'].to(dev)
        xin = rec['xin'].to(dev)
        with torch.no_grad():
            col = model(xin, ei)[-1].argmax(-1)
            n_clash = (col[ei[0]] == col[ei[1]]).sum().item()
        fail_list.append(int(n_clash > 0))
        lam_list.append(lyap1(model, xin, ei, n_steps, dev, seed=idx))
    lams = np.array(lam_list)
    fails = np.array(fail_list)
    s = lams[fails == 0]
    f = lams[fails == 1]

    nan = float('nan')
    both_groups = len(s) and len(f)
    if roc_auc_score and both_groups:
        auc = roc_auc_score(fails, lams)
    else:
        auc = nan
    sep = (f.mean() - s.mean()) if both_groups else nan
    solved_mean = s.mean() if len(s) else nan
    unsolved_mean = f.mean() if len(f) else nan
    print(f"[{model.conv_type}/{model.grad_mode}] LE n={len(lams)} fail={fails.mean():.2f} | "
          f"SOLVED {solved_mean:+.3f} UNSOLVED {unsolved_mean:+.3f}"
          f" sep={sep:+.3f} AUROC={auc:.3f} mean_lam={lams.mean():+.3f}")
    return {
        'n': int(len(lams)),
        'fail_rate': float(fails.mean()),
        'auroc': float(auc),
        'sep': float(sep),
        'lam_solved': (float(s.mean()) if len(s) else None),
        'lam_unsolved': (float(f.mean()) if len(f) else None),
        'mean_lam': float(lams.mean()),
    }
+
+
def lyap_penalty(model, x, ei, batch, target=-0.5):
    """Differentiable penalty pulling the one-step log expansion towards ``target``.

    The reference state is obtained by running ``model.recurse`` from zeros
    without gradients. A random unit direction is pushed through the Jacobian
    of ``z -> block(z + h0)`` at that state (``jvp`` with ``create_graph=True``
    so the result can be back-propagated), and the squared difference between
    the log of the resulting norm and ``target`` is returned.
    """
    h0 = model.lin_in(x)
    with torch.no_grad():
        z_ref = model.recurse(torch.zeros_like(h0), h0.detach(), ei, False, batch)
    direction = torch.randn_like(z_ref)
    direction = direction / (direction.norm() + 1e-12)

    def one_step(zz):
        return model.block(zz + h0, ei, batch)

    _, pushed = torch.autograd.functional.jvp(one_step, z_ref, direction, create_graph=True)
    log_gain = torch.log(pushed.norm() + 1e-12)
    return (log_gain - target) ** 2
+
+
def main():
    """Command-line entry point: train a model or run the LE diagnostic.

    ``--mode train`` builds/loads the cached train and test splits, trains
    with the conflict loss (plus the contraction penalty when ``--contract``
    is set), keeps an EMA of the weights, evaluates the EMA weights every 20
    epochs and at the last epoch, and writes a JSON report and a checkpoint
    of the best EMA weights under OUT.

    ``--mode le`` loads the checkpoint given by ``--ckpt``, rebuilds the model
    from its stored config, runs ``run_le`` on the test split and writes
    ``le_<name>.json`` under OUT. ``--ckpt`` is required in this mode; a
    missing value is reported as a command-line error.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--mode', choices=['train', 'le'], default='train')
    ap.add_argument('--conv', choices=['gin', 'gcn', 'sage', 'gat', 'pna', 'gps'], default='gin')
    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
    ap.add_argument('--pe', choices=['none', 'rwse', 'gsn', 'sub', 'lappe', 'all'], default='none')
    ap.add_argument('--contract', action='store_true')
    ap.add_argument('--rwse_k', type=int, default=16)
    ap.add_argument('--ckpt', default=None)
    ap.add_argument('--n', type=int, default=50)
    ap.add_argument('--k', type=int, default=3)
    ap.add_argument('--p', type=float, default=0.2)
    ap.add_argument('--r', type=int, default=8)
    ap.add_argument('--hidden', type=int, default=128)
    ap.add_argument('--T', type=int, default=3)
    ap.add_argument('--n_sup', type=int, default=3)
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--bs', type=int, default=32)
    ap.add_argument('--seed', type=int, default=0)
    args = ap.parse_args()
    if args.mode == 'le' and not args.ckpt:
        # Fail early with a usage message instead of crashing inside torch.load(None).
        ap.error('--mode le requires --ckpt')
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(OUT, exist_ok=True)

    if args.mode == 'le':
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from a trusted source.
        ck = torch.load(args.ckpt, weights_only=False)
        c = ck['cfg']
        # The test split is rebuilt from the command-line n/k/p/r; the
        # featurization settings come from the checkpoint config.
        te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000),
                       c.get('pe', 'none'), c.get('rwse_k', 16))
        deg = torch.tensor(c['deg']) if c.get('deg') else None
        model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
                            grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev)
        model.load_state_dict(ck['state'])
        model.eval()
        res = run_le(model, te, dev, c['n_sup'] * c['T'])
        base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
        with open(os.path.join(OUT, f"le_{base}.json"), 'w') as fjs:
            json.dump({'conv': c.get('conv', 'gin'), 'grad_mode': c['grad_mode'], 'pe': c.get('pe', 'none'),
                       'contract': c.get('contract', False), 'seed': c.get('seed'), **res}, fjs, indent=2)
        return

    te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), args.pe, args.rwse_k)
    tr = featurize(make_split('train', args.n, args.k, args.p, args.r, 2000, 0), args.pe, args.rwse_k)
    in_dim = tr[0]['xin'].shape[1]
    data = [Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr]
    trl = DataLoader(data, batch_size=args.bs, shuffle=True, drop_last=True)
    # The degree histogram is only needed by PNAConv.
    deg = deg_hist(tr) if args.conv == 'pna' else None
    model = RecGINColor(in_dim, args.hidden, args.k, args.T, args.n_sup,
                        grad_mode=args.grad_mode, conv=args.conv, deg=deg).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)

    # EMA copy of every state_dict entry, updated after each optimizer step.
    ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
    t0 = time.time()
    best_solve = -1
    best = {}
    best_state = None
    for ep in range(args.epochs):
        model.train()
        for b in trl:
            b = b.to(dev)
            opt.zero_grad()
            outs = model(b.x, b.edge_index, b.batch, noise=False)
            # Deep supervision: average the conflict loss over all supervision rounds.
            loss = sum(conflict_loss(o, b.edge_index) for o in outs) / len(outs)
            if args.contract:
                loss = loss + lyap_penalty(model, b.x, b.edge_index, b.batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            with torch.no_grad():
                for kk, v in model.state_dict().items():
                    if torch.is_floating_point(v):
                        ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001)
                    else:
                        # Non-float entries (e.g. integer counters) are copied as-is.
                        ema[kk].copy_(v.detach())
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            # Evaluate with the EMA weights, then restore the live weights.
            backup = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
            model.load_state_dict(ema)
            sr, mc = solve_stats(model, te, dev, sample=300)
            if sr > best_solve:
                best_solve = sr
                best = {'ep': ep + 1, 'solve_rate': round(sr, 4), 'mean_conflicts': round(mc, 3)}
                best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
            model.load_state_dict(backup)
            print(f"ep{ep+1} solve_rate={sr:.3f} mean_conflicts={mc:.2f}", flush=True)

    sfx = ('_ctr' if args.contract else '')
    tag = f"color_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}"
    rep = {'task': 'graph3coloring', 'tag': tag, **vars(args), 'in_dim': in_dim,
           'sec': round(time.time() - t0, 1), **best}
    print(f"[{tag}] best solve_rate={best.get('solve_rate')} @ep{best.get('ep')} ({rep['sec']}s)")
    with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
        json.dump(rep, f, indent=2)
    torch.save({'state': best_state, 'cfg': {'in_dim': in_dim, 'hidden': args.hidden, 'k': args.k,
                                             'T': args.T, 'n_sup': args.n_sup, 'grad_mode': args.grad_mode,
                                             'pe': args.pe, 'rwse_k': args.rwse_k, 'contract': args.contract,
                                             'conv': args.conv, 'seed': args.seed,
                                             'deg': (deg.tolist() if deg is not None else None)}},
               os.path.join(OUT, f"ckpt_{tag}.pt"))
    print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))


if __name__ == "__main__":
    main()