"""RRoG/TRM-on-GNN graph 3-coloring. The graph is encoded once into a fixed per-node context. Recursion then refines hidden state with a shared compute block that never reads edge_index. This is the RRoG split: the GNN encoder supplies the view/context x, TRM-style recurrence supplies computation. --conv gin|gcn|sage|gat|gps : message-passing operator used only by the one-shot encoder (gps = GraphGPS local MPNN + global attention). --pe none|rwse|gsn|sub|lappe|all : input structural features (random sym-break [+ encoding]). --contract : reverse-flossing lambda-penalty during training (force contraction; roadmap #4). --grad_mode full|1step : TRM full recursion vs HRM 1-step gradient. Self-supervised conflict/Potts loss; success = zero-conflict; EMA; deep supervision. Modes: --mode train (saves ckpt + JSON) / --mode le. """ import argparse, json, os, time import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.data import Data from torch_geometric.loader import DataLoader from torch_geometric.nn import GINConv, GCNConv, SAGEConv, GATConv, GPSConv, PNAConv from torch_geometric.utils import degree as _pyg_degree # GPS uses scaled_dot_product_attention; force the MATH kernel so torch.autograd.functional.jvp # (LE diagnostic / PTRM rollouts) has a double-backward-able implementation. 
# Enable only the math SDP kernel and disable the flash / memory-efficient ones.
# Best-effort by design: a failure here leaves the backend defaults in place
# rather than aborting the import.
for _f in ('enable_flash_sdp', 'enable_mem_efficient_sdp', 'enable_math_sdp'):
    try:
        getattr(torch.backends.cuda, _f)(_f == 'enable_math_sdp')
    except Exception:
        pass

OUT = '/home/yurenh2/rrog/runs'
CACHE = '/home/yurenh2/rrog/data/color_cache'


def gen(n, k, p, r, seed):
    """Sample one random graph with a planted k-partition.

    Each node gets a random part in ``[0, k)``. Every pair of nodes in
    *different* parts is connected with probability ``p``; pairs in the same
    part are never connected. Edges are stored in both directions.

    Returns a dict with ``'n'``, ``'edge_index'`` (2 x E long tensor) and
    ``'rfeat'`` (n x r float tensor of standard-normal features).
    """
    rng = np.random.default_rng(seed)
    part = rng.integers(0, k, n)
    src, dst = [], []
    for i in range(n):
        for j in range(i + 1, n):
            # Short-circuit order matters for reproducibility: rng.random()
            # is only drawn for cross-part pairs.
            if part[i] != part[j] and rng.random() < p:
                src += [i, j]
                dst += [j, i]
    if src:
        ei = torch.tensor([src, dst], dtype=torch.long)
    else:
        ei = torch.zeros((2, 0), dtype=torch.long)
    rf = torch.tensor(rng.standard_normal((n, r)), dtype=torch.float)
    return {'n': n, 'edge_index': ei, 'rfeat': rf}


def make_split(split, n, k, p, r, count, seed0):
    """Return ``count`` graphs for ``split``, using an on-disk cache in ``CACHE``.

    Graph ``i`` is generated with seed ``seed0 + i``. The cache file name
    encodes ``split, n, k, p, r`` only, so a cached file whose length differs
    from ``count`` is treated as stale: it is regenerated and overwritten
    instead of being silently returned. (``seed0`` cannot be verified from the
    cached data.)
    """
    os.makedirs(CACHE, exist_ok=True)
    fp = os.path.join(CACHE, f"{split}_n{n}_k{k}_p{p}_r{r}.pt")
    if os.path.exists(fp):
        # NOTE: weights_only=False unpickles arbitrary objects; acceptable only
        # because this file is produced locally by this same function.
        cached = torch.load(fp, weights_only=False)
        if len(cached) == count:
            return cached
    data = [gen(n, k, p, r, seed0 + i) for i in range(count)]
    torch.save(data, fp)
    return data


def _adj(edge_index, n):
    """Dense symmetric 0/1 adjacency matrix (n x n, float64) from ``edge_index``."""
    A = np.zeros((n, n), dtype=np.float64)
    ei = edge_index.numpy()
    if ei.shape[1]:
        A[ei[0], ei[1]] = 1.0
    # Symmetrize in case only one direction of an edge was listed.
    return np.maximum(A, A.T)


def rwse(edge_index, n, K):
    """Random-walk structural encoding: diag(P^j) for j = 1..K.

    ``P`` is the row-normalized adjacency; rows of isolated nodes are left
    all-zero. Returns an (n, K) float32 tensor.
    """
    A = _adj(edge_index, n)
    deg = A.sum(1)
    P = A / np.where(deg > 0, deg, 1.0)[:, None]
    out = np.zeros((n, K), dtype=np.float32)
    M = np.eye(n)
    for j in range(K):
        M = M @ P
        out[:, j] = np.diag(M)
    return torch.from_numpy(out)


def gsn_feats(edge_index, n):
    """Per-node substructure counts: log1p(#triangles) and log1p(#wedges).

    Triangles through a node are ``diag(A^3) / 2``; wedges centred at a node
    are ``deg * (deg - 1) / 2``. Returns an (n, 2) float tensor.
    """
    A = _adj(edge_index, n)
    deg = A.sum(1)
    tri = (A @ A @ A).diagonal() / 2.0
    wedge = deg * (deg - 1) / 2.0
    feats = np.stack([np.log1p(tri), np.log1p(wedge)], axis=1)
    return torch.tensor(feats, dtype=torch.float)


def sub_feats(edge_index, n):
    """Per-node local features: log1p(degree), log1p(#two-hop nodes), clustering.

    The two-hop count is the number of non-adjacent, distinct nodes reachable
    by a length-2 walk. The clustering coefficient is triangles divided by
    ``deg * (deg - 1) / 2`` and is defined as 0 for nodes with degree <= 1.
    Returns an (n, 3) float tensor.
    """
    A = _adj(edge_index, n)
    deg = A.sum(1)
    A2 = A @ A
    M = (A2 > 0).astype(np.float64) * (A == 0).astype(np.float64)
    np.fill_diagonal(M, 0.0)
    d2 = M.sum(1)
    tri = (A @ A2).diagonal() / 2.0
    pairs = deg * (deg - 1) / 2.0
    # Masked divide: same values as np.where(deg > 1, tri / pairs, 0) but
    # without evaluating 0/0 (and warning) for low-degree nodes.
    clus = np.divide(tri, pairs, out=np.zeros_like(tri), where=pairs > 0)
    feats = np.stack([np.log1p(deg), np.log1p(d2), clus], axis=1)
    return torch.tensor(feats, dtype=torch.float)


def lappe_feats(edge_index, n, kpe=8):
    """Laplacian positional encoding from the symmetric-normalized Laplacian.

    Takes the eigenvectors at indices ``1..kpe`` of ``numpy.linalg.eigh`` (the
    first one is skipped) and zero-pads the columns when the graph has fewer
    than ``kpe + 1`` nodes. Returns an (n, kpe) float tensor.
    """
    A = _adj(edge_index, n)
    deg = A.sum(1)
    # deg^{-1/2}, with 0 for isolated nodes; the masked divide avoids 1/0.
    di = np.divide(1.0, np.sqrt(deg), out=np.zeros_like(deg), where=deg > 0)
    L = np.eye(n) - di[:, None] * A * di[None, :]
    _, V = np.linalg.eigh(L)
    pe = V[:, 1:kpe + 1]
    if pe.shape[1] < kpe:
        pe = np.pad(pe, ((0, 0), (0, kpe - pe.shape[1])))
    return torch.tensor(pe, dtype=torch.float)


def featurize(graphs, pe, rwse_k):
    """Attach the model input ``g['xin']`` to every graph dict, in place.

    ``xin`` is ``rfeat`` concatenated with the encoding selected by ``pe``:
    ``'rwse'``, ``'gsn'``, ``'sub'``, ``'lappe'``, or ``'all'`` (rwse + gsn +
    sub). Any other value yields ``rfeat`` alone. Returns the same list.
    """
    def feat(g):
        ei, n = g['edge_index'], g['n']
        if pe == 'rwse':
            return rwse(ei, n, rwse_k)
        if pe == 'gsn':
            return gsn_feats(ei, n)
        if pe == 'sub':
            return sub_feats(ei, n)
        if pe == 'lappe':
            return lappe_feats(ei, n)
        if pe == 'all':
            return torch.cat([rwse(ei, n, rwse_k), gsn_feats(ei, n), sub_feats(ei, n)], dim=1)
        return None

    for g in graphs:
        e = feat(g)
        g['xin'] = torch.cat([g['rfeat'], e], dim=1) if e is not None else g['rfeat']
    return graphs


def deg_hist(graphs):
    """Histogram of in-degrees over all graphs (long tensor, index = degree).

    Passed as the ``deg`` argument of ``PNAConv`` by ``make_conv``.
    """
    md = 0
    ds = []
    for g in graphs:
        d = _pyg_degree(g['edge_index'][1], g['n'], dtype=torch.long)
        md = max(md, int(d.max()) if d.numel() else 0)
        ds.append(d)
    h = torch.zeros(md + 1, dtype=torch.long)
    for d in ds:
        h += torch.bincount(d, minlength=md + 1)
    return h


def make_conv(conv, hidden, deg=None):
    """Build one hidden->hidden message-passing layer of the requested type.

    ``deg`` is only used by ``'pna'``. Raises ``ValueError`` for an unknown
    ``conv`` name.
    """
    if conv == 'gin':
        mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
        return GINConv(mlp, train_eps=True)
    if conv == 'gcn':
        return GCNConv(hidden, hidden)
    if conv == 'sage':
        return SAGEConv(hidden, hidden)
    if conv == 'gat':
        return GATConv(hidden, hidden, heads=4, concat=False, add_self_loops=True)
    if conv == 'pna':
        return PNAConv(hidden, hidden, aggregators=['mean', 'min', 'max', 'std'],
                       scalers=['identity', 'amplification', 'attenuation'], deg=deg, towers=1)
    if conv == 'gps':
        local = GINConv(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)),
                        train_eps=True)
        return GPSConv(hidden, local, heads=4)
    raise ValueError(conv)


class RecGINColor(nn.Module):
    """One-shot GNN encoder followed by an edge-free recurrent compute core.

    ``aggregate`` runs ``agg_layers`` message-passing layers once to produce a
    per-node context. ``recurse`` then refines two hidden states ``(y, z)``
    with a shared residual MLP (``core``) that never sees ``edge_index``.
    ``forward`` repeats the recursion ``n_sup`` times and emits one set of
    logits per repetition (deep supervision), detaching the state in between.

    NOTE(review): ``attn_heads`` is stored on the module but is not forwarded
    to ``make_conv``, which hard-codes ``heads=4`` for gat/gps.
    """

    def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full', sigma=0.0,
                 conv='gin', deg=None, agg_layers=4, compute_layers=None, compute='trm',
                 attn_heads=4):
        super().__init__()
        self.conv_type = conv
        self.agg_layers = agg_layers
        # A falsy compute_layers (None or 0) falls back to `inner`.
        self.compute_layers = compute_layers or inner
        self.compute = compute
        self.attn_heads = attn_heads
        self.lin_in = nn.Linear(in_dim, hidden)
        self.agg_convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(agg_layers)])
        self.agg_bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(agg_layers)])
        if compute not in ('trm',):
            raise ValueError(compute)
        core = []
        for _ in range(self.compute_layers - 1):
            core += [nn.Linear(hidden, hidden), nn.GELU()]
        core.append(nn.Linear(hidden, hidden))
        self.core_norm = nn.LayerNorm(hidden)
        self.core = nn.Sequential(*core)
        # Zero-init the last layer so the residual core starts as the identity.
        nn.init.zeros_(self.core[-1].weight)
        nn.init.zeros_(self.core[-1].bias)
        self.head = nn.Linear(hidden, k)
        self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma

    def aggregate(self, xin, ei, batch=None):
        """Encode the graph once: linear input map, then conv -> BatchNorm -> ReLU layers."""
        is_gps = self.conv_type == 'gps'
        if is_gps and batch is None:
            # GPS needs a batch vector; treat all nodes as a single graph.
            batch = xin.new_zeros(xin.size(0), dtype=torch.long)
        h = self.lin_in(xin)
        for conv, bn in zip(self.agg_convs, self.agg_bns):
            h = conv(h, ei, batch) if is_gps else conv(h, ei)
            h = bn(h).relu()
        return h

    def core_step(self, combined, state):
        """Shared TRM compute core. Deliberately edge-free."""
        return state + self.core(self.core_norm(combined))

    def _z_step(self, y, z, ctx, noise):
        """Update latent ``z`` from ``ctx + y + z``; optionally add Gaussian noise."""
        z = self.core_step(ctx + y + z, z)
        if noise and self.sigma > 0:
            z = z + self.sigma * torch.randn_like(z)
        return z

    def _y_step(self, y, z, noise):
        """Update answer state ``y`` from ``y + z``; optionally add Gaussian noise."""
        y = self.core_step(y + z, y)
        if noise and self.sigma > 0:
            y = y + self.sigma * torch.randn_like(y)
        return y

    def recurse(self, y, z, ctx, noise, one_step=False):
        """Run one recursion block and return the new ``(y, z)``.

        Default: ``T`` alternating (z-step, y-step) pairs, all differentiable.
        ``one_step=True``: ``T - 1`` z-steps under ``no_grad``, then a single
        differentiable z-step and y-step.

        NOTE(review): the two branches are different update schedules (the
        one-step branch performs a single y-step in total), and the Lyapunov
        helpers in this file always call the default branch - confirm this is
        intended.
        """
        if self.T == 0:
            return y, z
        if one_step:
            with torch.no_grad():
                for _ in range(self.T - 1):
                    z = self._z_step(y, z, ctx, noise)
            z = z.detach()
            z = self._z_step(y, z, ctx, noise)
            y = self._y_step(y, z, noise)
            return y, z
        for _ in range(self.T):
            z = self._z_step(y, z, ctx, noise)
            y = self._y_step(y, z, noise)
        return y, z

    def forward(self, xin, ei, batch=None, noise=False):
        """Return a list of ``n_sup`` logit tensors, one per supervision step."""
        ctx = self.aggregate(xin, ei, batch)
        y = ctx
        z = torch.zeros_like(ctx)
        one_step = self.grad_mode == '1step'
        outs = []
        for _ in range(self.n_sup):
            y, z = self.recurse(y, z, ctx, noise, one_step=one_step)
            outs.append(self.head(y))
            # Truncate gradients between supervision steps.
            y, z = y.detach(), z.detach()
        return outs


def conflict_loss(logits, ei):
    """Mean over edges of the probability that both endpoints pick the same color.

    Returns a zero that is still attached to ``logits`` when the graph has no
    edges (the mean over an empty edge set would otherwise be NaN).
    """
    if ei.size(1) == 0:
        return logits.sum() * 0.0
    p = F.softmax(logits, dim=-1)
    return (p[ei[0]] * p[ei[1]]).sum(-1).mean()


@torch.no_grad()
def solve_stats(model, recs, dev, sample=None):
    """Evaluate argmax colorings; return ``(solve_rate, mean_conflicts)``.

    Puts ``model`` in eval mode and uses the last supervision output. Only the
    first ``sample`` records are used when ``sample`` is truthy. Conflicts are
    counted once per undirected edge (``edge_index`` lists both directions).
    Returns ``(nan, nan)`` when there is nothing to evaluate.
    """
    model.eval()
    subset = recs[:sample] if sample else recs
    solved, conf, tot = 0, 0.0, 0
    for r in subset:
        ei = r['edge_index'].to(dev)
        col = model(r['xin'].to(dev), ei)[-1].argmax(-1)
        c = (col[ei[0]] == col[ei[1]]).sum().item() // 2
        solved += int(c == 0)
        conf += c
        tot += 1
    if tot == 0:
        return float('nan'), float('nan')
    return solved / tot, conf / tot


def lyap1(model, xin, ei, n_steps, dev, seed=0):
    """Estimate the largest Lyapunov exponent of the ``(y, z)`` recursion map.

    Iterates ``model.recurse`` on the concatenated state for ``n_steps`` steps,
    pushing a random unit tangent vector through the Jacobian with
    ``torch.autograd.functional.jvp`` and renormalizing it each step. Returns
    the average log growth rate per step.
    """
    g = torch.Generator(device=dev).manual_seed(seed)
    ctx = model.aggregate(xin, ei).detach()
    state = torch.cat([ctx, torch.zeros_like(ctx)], dim=-1).detach()
    v = torch.randn(state.shape, generator=g, device=dev)
    v = v / (v.norm() + 1e-12)

    def step_fn(ss):
        y, z = ss.chunk(2, dim=-1)
        y, z = model.recurse(y, z, ctx, noise=False)
        return torch.cat([y, z], dim=-1)

    lam = 0.0
    for _ in range(n_steps):
        state_next, Jv = torch.autograd.functional.jvp(step_fn, state, v)
        state = state_next.detach()
        nv = Jv.norm()
        lam += torch.log(nv + 1e-12).item()
        v = (Jv / (nv + 1e-12)).detach()
    return lam / n_steps


def run_le(model, recs, dev, n_steps, n_graphs=300):
    """Compare Lyapunov exponents of solved vs. unsolved graphs.

    For each of the first ``n_graphs`` records, records whether the argmax
    coloring has any conflict and the ``lyap1`` estimate. Prints a summary and
    returns it as a dict. AUROC needs scikit-learn and both classes present;
    otherwise it is NaN.
    """
    try:
        from sklearn.metrics import roc_auc_score
    except Exception:
        # scikit-learn is optional; AUROC is reported as NaN without it.
        roc_auc_score = None
    model.eval()
    lams, fails = [], []
    for i, r in enumerate(recs[:n_graphs]):
        ei = r['edge_index'].to(dev)
        xin = r['xin'].to(dev)
        with torch.no_grad():
            col = model(xin, ei)[-1].argmax(-1)
        # Each conflict is counted twice here; only "any conflict" matters.
        c = (col[ei[0]] == col[ei[1]]).sum().item()
        fails.append(int(c > 0))
        lams.append(lyap1(model, xin, ei, n_steps, dev, seed=i))
    lams, fails = np.array(lams), np.array(fails)
    ok_lams, bad_lams = lams[fails == 0], lams[fails == 1]
    both = bool(len(ok_lams)) and bool(len(bad_lams))
    auc = roc_auc_score(fails, lams) if roc_auc_score and both else float('nan')
    sep = (bad_lams.mean() - ok_lams.mean()) if both else float('nan')
    print(f"[{model.conv_type}/{model.grad_mode}] LE n={len(lams)} fail={fails.mean():.2f} | "
          f"SOLVED {ok_lams.mean() if len(ok_lams) else float('nan'):+.3f} UNSOLVED {bad_lams.mean() if len(bad_lams) else float('nan'):+.3f}"
          f" sep={sep:+.3f} AUROC={auc:.3f} mean_lam={lams.mean():+.3f}")
    return {'n': int(len(lams)), 'fail_rate': float(fails.mean()), 'auroc': float(auc),
            'sep': float(sep),
            'lam_solved': (float(ok_lams.mean()) if len(ok_lams) else None),
            'lam_unsolved': (float(bad_lams.mean()) if len(bad_lams) else None),
            'mean_lam': float(lams.mean())}


def lyap_penalty(model, x, ei, batch, target=-0.5):
    """Differentiable penalty ``(log ||J v|| - target)^2`` for one random unit ``v``.

    ``J`` is the Jacobian of one ``model.recurse`` block, evaluated at the
    state reached after one recursion from the initial state (computed without
    gradients). The JVP is built with ``create_graph=True`` so the penalty can
    be backpropagated into the model parameters.
    """
    ctx = model.aggregate(x, ei, batch)
    with torch.no_grad():
        yr, zr = model.recurse(ctx.detach(), torch.zeros_like(ctx).detach(), ctx.detach(), False)
    state = torch.cat([yr, zr], dim=-1)
    v = torch.randn_like(state)
    v = v / (v.norm() + 1e-12)

    def step_fn(ss):
        y, z = ss.chunk(2, dim=-1)
        y, z = model.recurse(y, z, ctx, noise=False)
        return torch.cat([y, z], dim=-1)

    _, Jv = torch.autograd.functional.jvp(step_fn, state, v, create_graph=True)
    return (torch.log(Jv.norm() + 1e-12) - target) ** 2


def main():
    """CLI entry point: train a model (``--mode train``) or run the LE diagnostic (``--mode le``)."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--mode', choices=['train', 'le'], default='train')
    ap.add_argument('--conv', choices=['gin', 'gcn', 'sage', 'gat', 'pna', 'gps'], default='gin')
    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
    ap.add_argument('--pe', choices=['none', 'rwse', 'gsn', 'sub', 'lappe', 'all'], default='none')
    ap.add_argument('--contract', action='store_true')
    ap.add_argument('--rwse_k', type=int, default=16)
    ap.add_argument('--ckpt', default=None)
    ap.add_argument('--n', type=int, default=50)
    ap.add_argument('--k', type=int, default=3)
    ap.add_argument('--p', type=float, default=0.2)
    ap.add_argument('--r', type=int, default=8)
    ap.add_argument('--hidden', type=int, default=128)
    ap.add_argument('--T', type=int, default=3)
    ap.add_argument('--n_sup', type=int, default=3)
    ap.add_argument('--epochs', type=int, default=150)
    ap.add_argument('--agg_layers', type=int, default=4)
    ap.add_argument('--compute_layers', type=int, default=2)
    ap.add_argument('--compute', choices=['trm'], default='trm')
    ap.add_argument('--attn_heads', type=int, default=4)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--bs', type=int, default=32)
    ap.add_argument('--seed', type=int, default=0)
    args = ap.parse_args()
    if args.mode == 'le' and not args.ckpt:
        # Fail with a usage message instead of an obscure torch.load(None) error.
        ap.error('--mode le requires --ckpt')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(OUT, exist_ok=True)

    if args.mode == 'le':
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints written by this script.
        ck = torch.load(args.ckpt, weights_only=False)
        c = ck['cfg']
        te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000),
                       c.get('pe', 'none'), c.get('rwse_k', 16))
        deg = torch.tensor(c['deg']) if c.get('deg') else None
        model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
                            grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg,
                            agg_layers=c.get('agg_layers', 1),
                            compute_layers=c.get('compute_layers', 2),
                            compute='trm',
                            attn_heads=c.get('attn_heads', 4)).to(dev)
        model.load_state_dict(ck['state'])
        model.eval()
        res = run_le(model, te, dev, c['n_sup'] * c['T'])
        base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
        with open(os.path.join(OUT, f"le_{base}.json"), 'w') as fjs:
            json.dump({'conv': c.get('conv', 'gin'), 'grad_mode': c['grad_mode'],
                       'pe': c.get('pe', 'none'), 'contract': c.get('contract', False),
                       'seed': c.get('seed'), 'arch': c.get('arch', 'legacy'), **res},
                      fjs, indent=2)
        return

    te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), args.pe, args.rwse_k)
    tr = featurize(make_split('train', args.n, args.k, args.p, args.r, 2000, 0), args.pe, args.rwse_k)
    in_dim = tr[0]['xin'].shape[1]
    data = [Data(x=r['xin'], edge_index=r['edge_index'], num_nodes=r['n']) for r in tr]
    trl = DataLoader(data, batch_size=args.bs, shuffle=True, drop_last=True)
    deg = deg_hist(tr) if args.conv == 'pna' else None
    model = RecGINColor(in_dim, args.hidden, args.k, args.T, args.n_sup, grad_mode=args.grad_mode,
                        conv=args.conv, deg=deg, agg_layers=args.agg_layers,
                        compute_layers=args.compute_layers, compute=args.compute,
                        attn_heads=args.attn_heads).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    ema = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
    t0 = time.time()
    best_solve = -1
    best = {}
    best_state = None

    for ep in range(args.epochs):
        model.train()
        for b in trl:
            b = b.to(dev)
            opt.zero_grad()
            outs = model(b.x, b.edge_index, b.batch, noise=False)
            loss = sum(conflict_loss(o, b.edge_index) for o in outs) / len(outs)
            if args.contract:
                loss = loss + lyap_penalty(model, b.x, b.edge_index, b.batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            with torch.no_grad():
                for kk, v in model.state_dict().items():
                    if torch.is_floating_point(v):
                        ema[kk].mul_(0.999).add_(v.detach(), alpha=0.001)
                    else:
                        # Integer buffers are copied rather than averaged.
                        ema[kk].copy_(v.detach())
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            # Evaluate with the EMA weights, then restore the live weights.
            backup = {kk: v.detach().clone() for kk, v in model.state_dict().items()}
            model.load_state_dict(ema)
            sr, mc = solve_stats(model, te, dev, sample=300)
            if sr > best_solve:
                best_solve = sr
                best = {'ep': ep + 1, 'solve_rate': round(sr, 4), 'mean_conflicts': round(mc, 3)}
                best_state = {kk: ema[kk].detach().cpu().clone() for kk in ema}
            model.load_state_dict(backup)
            print(f"ep{ep+1} solve_rate={sr:.3f} mean_conflicts={mc:.2f}", flush=True)

    if best_state is None:
        # No evaluation selected a best state (e.g. --epochs 0): save the
        # current EMA weights so the checkpoint stays loadable.
        best_state = {kk: v.detach().cpu().clone() for kk, v in ema.items()}

    sfx = ('_ctr' if args.contract else '')
    tag = f"color_rrog_{args.compute}_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}"
    rep = {'task': 'graph3coloring', 'tag': tag, **vars(args), 'in_dim': in_dim,
           'arch': 'rrog_once_agg_hidden_compute', 'sec': round(time.time() - t0, 1), **best}
    print(f"[{tag}] best solve_rate={best.get('solve_rate')} @ep{best.get('ep')} ({rep['sec']}s)")
    with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
        json.dump(rep, f, indent=2)
    cfg = {'in_dim': in_dim, 'hidden': args.hidden, 'k': args.k, 'T': args.T,
           'n_sup': args.n_sup, 'grad_mode': args.grad_mode, 'pe': args.pe,
           'rwse_k': args.rwse_k, 'contract': args.contract, 'conv': args.conv,
           'seed': args.seed, 'agg_layers': args.agg_layers,
           'compute_layers': args.compute_layers, 'compute': args.compute,
           'attn_heads': args.attn_heads, 'arch': 'rrog_once_agg_hidden_compute',
           'deg': (deg.tolist() if deg is not None else None)}
    torch.save({'state': best_state, 'cfg': cfg}, os.path.join(OUT, f"ckpt_{tag}.pt"))
    print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))


if __name__ == "__main__":
    main()