summaryrefslogtreecommitdiff
path: root/diag/train_color.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/train_color.py')
-rw-r--r--diag/train_color.py142
1 files changed, 101 insertions, 41 deletions
diff --git a/diag/train_color.py b/diag/train_color.py
index 36f8496..b9d0c23 100644
--- a/diag/train_color.py
+++ b/diag/train_color.py
@@ -1,7 +1,11 @@
-"""Recursive (TRM-ish) GNN graph 3-coloring with swappable BACKBONE for the RRoG roadmap.
+"""RRoG/TRM-on-GNN graph 3-coloring.
---conv gin|gcn|sage|gat|gps : message-passing operator (gps = GraphGPS local MPNN + global
- attention = TRM's original transformer backbone, on the graph).
+The graph is encoded once into a fixed per-node context. Recursion then refines hidden state
+with a shared compute block that never reads edge_index. This is the RRoG split:
+the GNN encoder supplies the view/context x, TRM-style recurrence supplies computation.
+
+--conv gin|gcn|sage|gat|gps : message-passing operator used only by the one-shot encoder
+ (gps = GraphGPS local MPNN + global attention).
--pe none|rwse|gsn|sub|lappe|all : input structural features (random sym-break [+ encoding]).
--contract : reverse-flossing lambda-penalty during training (force contraction; roadmap #4).
--grad_mode full|1step : TRM full recursion vs HRM 1-step gradient.
@@ -142,48 +146,83 @@ def make_conv(conv, hidden, deg=None):
class RecGINColor(nn.Module):
- def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full', sigma=0.0, conv='gin', deg=None):
+ def __init__(self, in_dim, hidden, k, T=3, n_sup=3, inner=2, grad_mode='full',
+ sigma=0.0, conv='gin', deg=None, agg_layers=4, compute_layers=None,
+ compute='trm', attn_heads=4):
super().__init__()
self.conv_type = conv
+ self.agg_layers = agg_layers
+ self.compute_layers = compute_layers or inner
+ self.compute = compute
+ self.attn_heads = attn_heads
self.lin_in = nn.Linear(in_dim, hidden)
- self.convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(inner)])
- self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)])
+ self.agg_convs = nn.ModuleList([make_conv(conv, hidden, deg) for _ in range(agg_layers)])
+ self.agg_bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(agg_layers)])
+ if compute not in ('trm',):
+ raise ValueError(compute)
+ core = []
+ d = hidden
+ for _ in range(self.compute_layers - 1):
+ core += [nn.Linear(d, hidden), nn.GELU()]
+ d = hidden
+ core.append(nn.Linear(d, hidden))
+ self.core_norm = nn.LayerNorm(hidden)
+ self.core = nn.Sequential(*core)
+ nn.init.zeros_(self.core[-1].weight)
+ nn.init.zeros_(self.core[-1].bias)
self.head = nn.Linear(hidden, k)
self.T, self.n_sup, self.grad_mode, self.sigma = T, n_sup, grad_mode, sigma
- def block(self, z, ei, batch=None):
+ def aggregate(self, xin, ei, batch=None):
if self.conv_type == 'gps' and batch is None:
- batch = z.new_zeros(z.size(0), dtype=torch.long)
- for conv, bn in zip(self.convs, self.bns):
- z = conv(z, ei, batch) if self.conv_type == 'gps' else conv(z, ei)
- z = bn(z).relu()
- return z
-
- def _inner(self, z, h0, ei, noise, batch):
- z = self.block(z + h0, ei, batch)
+ batch = xin.new_zeros(xin.size(0), dtype=torch.long)
+ h = self.lin_in(xin)
+ for conv, bn in zip(self.agg_convs, self.agg_bns):
+ h = conv(h, ei, batch) if self.conv_type == 'gps' else conv(h, ei)
+ h = bn(h).relu()
+ return h
+
+ def core_step(self, combined, state):
+ """Shared TRM compute core. Deliberately edge-free."""
+ return state + self.core(self.core_norm(combined))
+
+ def _z_step(self, y, z, ctx, noise):
+ z = self.core_step(ctx + y + z, z)
if noise and self.sigma > 0:
z = z + self.sigma * torch.randn_like(z)
return z
- def recurse(self, z, h0, ei, noise, batch, one_step=False):
+ def _y_step(self, y, z, noise):
+ y = self.core_step(y + z, y)
+ if noise and self.sigma > 0:
+ y = y + self.sigma * torch.randn_like(y)
+ return y
+
+ def recurse(self, y, z, ctx, noise, one_step=False):
+ if self.T == 0:
+ return y, z
if one_step:
with torch.no_grad():
for _ in range(self.T - 1):
- z = self._inner(z, h0, ei, noise, batch)
+ z = self._z_step(y, z, ctx, noise)
z = z.detach()
- return self._inner(z, h0, ei, noise, batch)
+ z = self._z_step(y, z, ctx, noise)
+ y = self._y_step(y, z, noise)
+ return y, z
for _ in range(self.T):
- z = self._inner(z, h0, ei, noise, batch)
- return z
+ z = self._z_step(y, z, ctx, noise)
+ y = self._y_step(y, z, noise)
+ return y, z
def forward(self, xin, ei, batch=None, noise=False):
- h0 = self.lin_in(xin)
- z = torch.zeros_like(h0)
+ ctx = self.aggregate(xin, ei, batch)
+ y = ctx
+ z = torch.zeros_like(ctx)
outs = []
for s in range(self.n_sup):
- z = self.recurse(z, h0, ei, noise, batch, one_step=(self.grad_mode == '1step'))
- outs.append(self.head(z))
- z = z.detach()
+ y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step'))
+ outs.append(self.head(y))
+ y, z = y.detach(), z.detach()
return outs
@@ -206,15 +245,17 @@ def solve_stats(model, recs, dev, sample=None):
def lyap1(model, xin, ei, n_steps, dev, seed=0):
g = torch.Generator(device=dev).manual_seed(seed)
- h0 = model.lin_in(xin).detach()
- z = torch.zeros_like(h0)
- v = torch.randn(h0.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
- def step_fn(zz):
- return model.block(zz + h0, ei)
+ ctx = model.aggregate(xin, ei).detach()
+ state = torch.cat([ctx, torch.zeros_like(ctx)], dim=-1).detach()
+ v = torch.randn(state.shape, generator=g, device=dev); v = v / (v.norm() + 1e-12)
+ def step_fn(ss):
+ y, z = ss.chunk(2, dim=-1)
+ y, z = model.recurse(y, z, ctx, noise=False)
+ return torch.cat([y, z], dim=-1)
lam = 0.0
for _ in range(n_steps):
- z_next, Jv = torch.autograd.functional.jvp(step_fn, z, v)
- z = z_next.detach(); nv = Jv.norm()
+ state_next, Jv = torch.autograd.functional.jvp(step_fn, state, v)
+ state = state_next.detach(); nv = Jv.norm()
lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach()
return lam / n_steps
@@ -246,11 +287,16 @@ def run_le(model, recs, dev, n_steps, n_graphs=300):
def lyap_penalty(model, x, ei, batch, target=-0.5):
- h0 = model.lin_in(x)
+ ctx = model.aggregate(x, ei, batch)
with torch.no_grad():
- zr = model.recurse(torch.zeros_like(h0), h0.detach(), ei, False, batch)
- v = torch.randn_like(zr); v = v / (v.norm() + 1e-12)
- _, Jv = torch.autograd.functional.jvp(lambda zz: model.block(zz + h0, ei, batch), zr, v, create_graph=True)
+ yr, zr = model.recurse(ctx.detach(), torch.zeros_like(ctx).detach(), ctx.detach(), False)
+ state = torch.cat([yr, zr], dim=-1)
+ v = torch.randn_like(state); v = v / (v.norm() + 1e-12)
+ def step_fn(ss):
+ y, z = ss.chunk(2, dim=-1)
+ y, z = model.recurse(y, z, ctx, noise=False)
+ return torch.cat([y, z], dim=-1)
+ _, Jv = torch.autograd.functional.jvp(step_fn, state, v, create_graph=True)
return (torch.log(Jv.norm() + 1e-12) - target) ** 2
@@ -267,6 +313,10 @@ def main():
ap.add_argument('--p', type=float, default=0.2); ap.add_argument('--r', type=int, default=8)
ap.add_argument('--hidden', type=int, default=128); ap.add_argument('--T', type=int, default=3)
ap.add_argument('--n_sup', type=int, default=3); ap.add_argument('--epochs', type=int, default=150)
+ ap.add_argument('--agg_layers', type=int, default=4)
+ ap.add_argument('--compute_layers', type=int, default=2)
+ ap.add_argument('--compute', choices=['trm'], default='trm')
+ ap.add_argument('--attn_heads', type=int, default=4)
ap.add_argument('--lr', type=float, default=1e-3); ap.add_argument('--bs', type=int, default=32)
ap.add_argument('--seed', type=int, default=0)
args = ap.parse_args()
@@ -280,13 +330,18 @@ def main():
c.get('pe', 'none'), c.get('rwse_k', 16))
deg = torch.tensor(c['deg']) if c.get('deg') else None
model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
- grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev)
+ grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg,
+ agg_layers=c.get('agg_layers', 1),
+ compute_layers=c.get('compute_layers', 2),
+ compute=(c.get('compute') if c.get('compute') == 'trm' else 'trm'),
+ attn_heads=c.get('attn_heads', 4)).to(dev)
model.load_state_dict(ck['state']); model.eval()
res = run_le(model, te, dev, c['n_sup'] * c['T'])
base = os.path.basename(args.ckpt).replace('ckpt_', '').replace('.pt', '')
with open(os.path.join(OUT, f"le_{base}.json"), 'w') as fjs:
json.dump({'conv': c.get('conv', 'gin'), 'grad_mode': c['grad_mode'], 'pe': c.get('pe', 'none'),
- 'contract': c.get('contract', False), 'seed': c.get('seed'), **res}, fjs, indent=2)
+ 'contract': c.get('contract', False), 'seed': c.get('seed'),
+ 'arch': c.get('arch', 'legacy'), **res}, fjs, indent=2)
return
te = featurize(make_split('test', args.n, args.k, args.p, args.r, 500, 100000), args.pe, args.rwse_k)
@@ -296,7 +351,9 @@ def main():
trl = DataLoader(data, batch_size=args.bs, shuffle=True, drop_last=True)
deg = deg_hist(tr) if args.conv == 'pna' else None
model = RecGINColor(in_dim, args.hidden, args.k, args.T, args.n_sup,
- grad_mode=args.grad_mode, conv=args.conv, deg=deg).to(dev)
+ grad_mode=args.grad_mode, conv=args.conv, deg=deg,
+ agg_layers=args.agg_layers, compute_layers=args.compute_layers,
+ compute=args.compute, attn_heads=args.attn_heads).to(dev)
opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
@@ -329,15 +386,18 @@ def main():
print(f"ep{ep+1} solve_rate={sr:.3f} mean_conflicts={mc:.2f}", flush=True)
sfx = ('_ctr' if args.contract else '')
- tag = f"color_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}"
+ tag = f"color_rrog_{args.compute}_{args.conv}_{args.grad_mode}_{args.pe}{sfx}_n{args.n}_k{args.k}_p{args.p}_T{args.T}_ns{args.n_sup}_s{args.seed}"
rep = {'task': 'graph3coloring', 'tag': tag, **vars(args), 'in_dim': in_dim,
- 'sec': round(time.time() - t0, 1), **best}
+ 'arch': 'rrog_once_agg_hidden_compute', 'sec': round(time.time() - t0, 1), **best}
print(f"[{tag}] best solve_rate={best.get('solve_rate')} @ep{best.get('ep')} ({rep['sec']}s)")
with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
json.dump(rep, f, indent=2)
torch.save({'state': best_state, 'cfg': {'in_dim': in_dim, 'hidden': args.hidden, 'k': args.k,
'T': args.T, 'n_sup': args.n_sup, 'grad_mode': args.grad_mode, 'pe': args.pe,
'rwse_k': args.rwse_k, 'contract': args.contract, 'conv': args.conv, 'seed': args.seed,
+ 'agg_layers': args.agg_layers, 'compute_layers': args.compute_layers,
+ 'compute': args.compute, 'attn_heads': args.attn_heads,
+ 'arch': 'rrog_once_agg_hidden_compute',
'deg': (deg.tolist() if deg is not None else None)}},
os.path.join(OUT, f"ckpt_{tag}.pt"))
print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))