diff options
Diffstat (limited to 'diag/train_color.py')
| -rw-r--r-- | diag/train_color.py | 347 |
1 files changed, 347 insertions, 0 deletions
diff --git a/diag/train_color.py b/diag/train_color.py new file mode 100644 index 0000000..36f8496 --- /dev/null +++ b/diag/train_color.py @@ -0,0 +1,347 @@ +"""Recursive (TRM-ish) GNN graph 3-coloring with swappable BACKBONE for the RRoG roadmap. + +--conv gin|gcn|sage|gat|gps : message-passing operator (gps = GraphGPS local MPNN + global + attention = TRM's original transformer backbone, on the graph). +--pe none|rwse|gsn|sub|lappe|all : input structural features (random sym-break [+ encoding]). +--contract : reverse-flossing lambda-penalty during training (force contraction; roadmap #4). +--grad_mode full|1step : TRM full recursion vs HRM 1-step gradient. +Self-supervised conflict/Potts loss; success = zero-conflict; EMA; deep supervision. +Modes: --mode train (saves ckpt + JSON) / --mode le. +""" +import argparse, json, os, time +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.data import Data +from torch_geometric.loader import DataLoader +from torch_geometric.nn import GINConv, GCNConv, SAGEConv, GATConv, GPSConv, PNAConv +from torch_geometric.utils import degree as _pyg_degree + +# GPS uses scaled_dot_product_attention; force the MATH kernel so torch.autograd.functional.jvp +# (LE diagnostic / PTRM rollouts) has a double-backward-able implementation. +for _f in ('enable_flash_sdp', 'enable_mem_efficient_sdp', 'enable_math_sdp'): + try: + getattr(torch.backends.cuda, _f)(_f == 'enable_math_sdp') + except Exception: + pass + +OUT = '/home/yurenh2/rrog/runs' +CACHE = '/home/yurenh2/rrog/data/color_cache' + + +def gen(n, k, p, r, seed): + rng = np.random.default_rng(seed) + part = rng.integers(0, k, n) + src, dst = [], [] + for i in range(n): + for j in range(i + 1, n): + if part[i] != part[j] and rng.random() < p: + src += [i, j]; dst += [j, i] + ei = torch.tensor([src, dst], dtype=torch.long) if src else torch.zeros((2, 0), dtype=torch.long) + rf = torch.tensor(rng.standard_normal((n, r)), dtype=torch.float) + return {'n': n, 'edge_index': ei, 'rfeat': rf} + + +def make_split(split, n, k, p, r, count, seed0): + os.makedirs(CACHE, exist_ok=True) + fp = os.path.join(CACHE, f"{split}_n{n}_k{k}_p{p}_r{r}.pt") + if os.path.exists(fp): + return torch.load(fp, weights_only=False) + data = [gen(n, k, p, r, seed0 + i) for i in range(count)] + torch.save(data, fp) + return data + + +def _adj(edge_index, n): + A = np.zeros((n, n), dtype=np.float64) + ei = edge_index.numpy() + if ei.shape[1]: + A[ei[0], ei[1]] = 1.0 + return np.maximum(A, A.T) + + +def rwse(edge_index, n, K): + A = _adj(edge_index, n); deg = A.sum(1) + P = A / np.where(deg > 0, deg, 1.0)[:, None] + out = np.zeros((n, K), dtype=np.float32); M = np.eye(n) + for j in range(K): + M = M @ P; out[:, j] = np.diag(M) + return torch.from_numpy(out) + + +def gsn_feats(edge_index, n): + A = _adj(edge_index, n); deg = A.sum(1) + tri = (A @ A @ A).diagonal() / 2.0 + wedge = deg * (deg - 1) / 2.0 + return torch.tensor(np.stack([np.log1p(tri), np.log1p(wedge)], axis=1), dtype=torch.float) + + +def sub_feats(edge_index, n): + A = _adj(edge_index, n); deg = A.sum(1); A2 = A @ A + M = (A2 > 0).astype(np.float64) * (A == 0).astype(np.float64); np.fill_diagonal(M, 0.0) + d2 = M.sum(1) + tri = (A @ A2).diagonal() / 2.0 + clus = np.where(deg > 1, tri / (deg * (deg - 1) / 2.0), 0.0) + return torch.tensor(np.stack([np.log1p(deg), np.log1p(d2), clus], axis=1), dtype=torch.float) + + +def lappe_feats(edge_index, n, kpe=8): + A = _adj(edge_index, n); deg = A.sum(1) + di = np.where(deg > 0, 1.0 / np.sqrt(deg), 0.0) + L = np.eye(n) - di[:, None] * A * di[None, :] + _, V = np.linalg.eigh(L) + pe = V[:, 1:kpe + 1] + if pe.shape[1] < kpe: + pe = np.pad(pe, ((0, 0), (0, kpe - pe.shape[1]))) + return torch.tensor(pe, dtype=torch.float) + + +def featurize(graphs, pe, rwse_k): + def feat(g): + ei, n = g['edge_index'], g['n'] + if pe == 'rwse': return rwse(ei, n, rwse_k) + if pe == 'gsn': return gsn_feats(ei, n) + if pe == 'sub': return sub_feats(ei, n) + if pe == 'lappe': return lappe_feats(ei, n) + if pe == 'all': return torch.cat([rwse(ei, n, rwse_k), gsn_feats(ei, n), sub_feats(ei, n)], dim=1) + return None + for g in graphs: + e = feat(g) + g['xin'] = torch.cat([g['rfeat'], e], dim=1) if e is not None else g['rfeat'] + return graphs + + +def deg_hist(graphs): + md = 0; ds = [] + for g in graphs: + d = _pyg_degree(g['edge_index'][1], g['n'], dtype=torch.long) + md = max(md, int(d.max()) if d.numel() else 0); ds.append(d) + h = torch.zeros(md + 1, dtype=torch.long) + for d in ds: + h += torch.bincount(d, minlength=md + 1) + return h + + +def make_conv(conv, hidden, deg=None): + if conv == 'gin': + return GINConv(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True) + if conv == 'gcn': + return GCNConv(hidden, hidden) + if conv == 'sage': + return SAGEConv(hidden, hidden) + if conv == 'gat': + return GATConv(hidden, hidden, heads=4, concat=False, add_self_loops=True) + if conv == 'pna': + return PNAConv(hidden, hidden, aggregators=['mean', 'min', 'max', 'std'], + scalers=['identity', 'amplification', 'attenuation'], deg=deg, towers=1) + if conv == 'gps': + local = GINConv(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True) + return GPSConv(hidden, local, heads=4) + raise ValueError(conv) + + +class RecGINColor(nn.Module): + def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full', sigma=0.0, conv='gin', deg=None): + super().__init__() + self.conv_type = conv + self.lin_in = nn.Linear(in_dim, hidden) + self.convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(inner)]) + self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)]) + self.head = nn.Linear(hidden, k) + self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma + + def block(self, z, ei, batch=None): + if self.conv_type == 'gps' and batch is None: + batch = z.new_zeros(z.size(0), dtype=torch.long) + for conv, bn in zip(self.convs, self.bns): + z = conv(z, ei, batch) if self.conv_type == 'gps' else conv(z, ei) + z = bn(z).relu() + return z + + def _inner(self, z, h0, ei, noise, batch): + z = self.block(z + h0, ei, batch) + if noise and self.sigma > 0: + z = z + self.sigma * torch.randn_like(z) + return z + + def recurse(self, z, h0, ei, noise, batch, one_step=False): + if one_step: + with torch.no_grad(): + for _ in range(self.T - 1): + z = self._inner(z, h0, ei, noise, batch) + z = z.detach() + return self._inner(z, h0, ei, noise, batch) + for _ in range(self.T): + z = self._inner(z, h0, ei, noise, batch) + return z + + def forward(self, xin, ei, batch=None, noise=False): + h0 = self.lin_in(xin) + z = torch.zeros_like(h0) + outs = [] + for s in range(self.n_sup): + z = self.recurse(z, h0, ei, noise, batch, one_step=(self.grad_mode == '1step')) + outs.append(self.head(z)) + z = z.detach() + return outs + + +def conflict_loss(logits, ei): + p = F.softmax(logits, dim=-1) + return (p[ei[0]] * p[ei[1]]).sum(-1).mean() + + +@torch.no_grad() +def solve_stats(model, recs, dev, sample=None): + model.eval() + solved = 0; conf = 0.0; tot = 0 + for r in (recs[:sample] if sample else recs): + ei = r['edge_index'].to(dev) + col = model(r['xin'].to(dev), ei)[-1].argmax(-1) + c = (col[ei[0]] == col[ei[1]]).sum().item() // 2 + solved += int(c == 0); conf += c; tot += 1 + return solved / tot, conf / tot + + +def lyap1(model, xin, ei, n_steps, dev, seed=0): + g = torch.Generator(device=dev).manual_seed(seed) + h0 = model.lin_in(xin).detach() + z = torch.zeros_like(h0) + v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12) + def step_fn(zz): + return model.block(zz + h0, ei) + lam = 0.0 + for _ in range(n_steps): + z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v) + z = z_next.detach(); nv = Jv.norm() + lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach() + return lam / n_steps + + +def run_le(model, recs, dev, n_steps, n_graphs=300): + try: + from sklearn.metrics import roc_auc_score + except Exception: + roc_auc_score = None + model.eval() + lams, fails = [], [] + for i, r in enumerate(recs[:n_graphs]): + ei = r['edge_index'].to(dev); xin = r['xin'].to(dev) + with torch.no_grad(): + col = model(xin, ei)[-1].argmax(-1) + c = (col[ei[0]] == col[ei[1]]).sum().item() + fails.append(int(c > 0)) + lams.append(lyap1(model, xin, ei, n_steps, dev, seed=i)) + lams, fails = np.array(lams), np.array(fails) + s, f = lams[fails == 0], lams[fails == 1] + auc = (roc_auc_score(fails, lams) if roc_auc_score and len(s) and len(f) else float('nan')) + sep = (f.mean() - s.mean()) if len(s) and len(f) else float('nan') + print(f"[{model.conv_type}/{model.grad_mode}] LE n={len(lams)} fail={fails.mean():.2f} | " + f"SOLVED {s.mean() if len(s) else float('nan'):+.3f} UNSOLVED {f.mean() if len(f) else float('nan'):+.3f}" + f" sep={sep:+.3f} AUROC={auc:.3f} mean_lam={lams.mean():+.3f}") + return {'n': int(len(lams)), 'fail_rate': float(fails.mean()), 'auroc': float(auc), 'sep': float(sep), + 'lam_solved': (float(s.mean()) if len(s) else None), + 'lam_unsolved': (float(f.mean()) if len(f) else None), 'mean_lam': float(lams.mean())} + + +def lyap_penalty(model, x, ei, batch, target=-0.5): + h0 = model.lin_in(x) + with torch.no_grad(): + zr = model.recurse(torch.zeros_like(h0), h0.detach(), ei, False, batch) + v = torch.randn_like(zr); v = v / (v.norm() + 1e-12) + _, Jv = torch.autograd.functional.jvp(lambda zz: model.block(zz + h0, ei, batch), zr, v, create_graph=True) + return (torch.log(Jv.norm() + 1e-12) - target) ** 2 + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--mode', choices=['train', 'le'], default='train') + ap.add_argument('--conv', choices=['gin', 'gcn', 'sage', 'gat', 'pna', 'gps'], default='gin') + ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') + ap.add_argument('--pe', choices=['none', 'rwse', 'gsn', 'sub', 'lappe', 'all'], default='none') + ap.add_argument('--contract', action='store_true') + ap.add_argument('--rwse_k', type=int, default=16) + ap.add_argument('--ckpt', default=None) + ap.add_argument('--n', type=int, default=50); ap.add_argument('--k', type=int, default=3) + ap.add_argument('--p', type=float, default=0.2); ap.add_argument('--r', type=int, default=8) + ap.add_argument('--hidden', type=int, default=128); ap.add_argument('--T', type=int, default=3) + ap.add_argument('--n_sup', type=int, default=3); ap.add_argument('--epochs', type=int, default=150) + ap.add_argument('--lr', type=float, default=1e-3); ap.add_argument('--bs', type=int, default=32) + ap.add_argument('--seed', type=int, default=0) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + os.makedirs(OUT, exist_ok=True) + + if args.mode == 'le': + ck = torch.load(args.ckpt, weights_only=False); c = ck['cfg'] + te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), + c.get('pe', 'none'), c.get('rwse_k', 16)) + deg = torch.tensor(c['deg']) if c.get('deg') else None + model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'], + grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev) + model.load_state_dict(ck['state']); model.eval() + res = run_le(model, te, dev, c['n_sup'] * c['T']) + base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '') + with open(os.path.join(OUT, f"le_{base}.json"), 'w') as fjs: + json.dump({'conv': c.get('conv', 'gin'), 'grad_mode': c['grad_mode'], 'pe': c.get('pe', 'none'), + 'contract': c.get('contract', False), 'seed': c.get('seed'), **res}, fjs, indent=2) + return + + te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), args.pe, args.rwse_k) + tr = featurize(make_split('train', args.n, args.k, args.p, args.r, 2000, 0), args.pe, args.rwse_k) + in_dim = tr[0]['xin'].shape[1] + data = [Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr] + trl = DataLoader(data, batch_size=args.bs, shuffle=True, drop_last=True) + deg = deg_hist(tr) if args.conv == 'pna' else None + model = RecGINColor(in_dim, args.hidden, args.k, args.T, args.n_sup, + grad_mode=args.grad_mode, conv=args.conv, deg=deg).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + + ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + t0 = time.time(); best_solve = -1; best = {}; best_state = None + for ep in range(args.epochs): + model.train() + for b in trl: + b = b.to(dev); opt.zero_grad() + outs = model(b.x, b.edge_index, b.batch, noise=False) + loss = sum(conflict_loss(o, b.edge_index) for o in outs) / len(outs) + if args.contract: + loss = loss + lyap_penalty(model, b.x, b.edge_index, b.batch) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + with torch.no_grad(): + for kk, v in model.state_dict().items(): + if torch.is_floating_point(v): + ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001) + else: + ema[kk].copy_(v.detach()) + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + backup = {kk: v.detach().clone() for kk, v in model.state_dict().items()} + model.load_state_dict(ema) + sr, mc = solve_stats(model, te, dev, sample=300) + if sr > best_solve: + best_solve = sr; best = {'ep': ep + 1, 'solve_rate': round(sr, 4), 'mean_conflicts': round(mc, 3)} + best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema} + model.load_state_dict(backup) + print(f"ep{ep+1} solve_rate={sr:.3f} mean_conflicts={mc:.2f}", flush=True) + + sfx = ('_ctr' if args.contract else '') + tag = f"color_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}" + rep = {'task': 'graph3coloring', 'tag': tag, **vars(args), 'in_dim': in_dim, + 'sec': round(time.time() - t0, 1), **best} + print(f"[{tag}] best solve_rate={best.get('solve_rate')} @ep{best.get('ep')} ({rep['sec']}s)") + with open(os.path.join(OUT, f"{tag}.json"), 'w') as f: + json.dump(rep, f, indent=2) + torch.save({'state': best_state, 'cfg': {'in_dim': in_dim, 'hidden': args.hidden, 'k': args.k, + 'T': args.T, 'n_sup': args.n_sup, 'grad_mode': args.grad_mode, 'pe': args.pe, + 'rwse_k': args.rwse_k, 'contract': args.contract, 'conv': args.conv, 'seed': args.seed, + 'deg': (deg.tolist() if deg is not None else None)}}, + os.path.join(OUT, f"ckpt_{tag}.pt")) + print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt")) + + +if __name__ == "__main__": + main() |
