"""Step-2: Recursive (TRM-ish) GNN on ZINC ring-counting + optional PTRM noise/selection.

Recurrent shared-weight GIN block, deep-supervised over ``n_sup`` steps (TRM-style:
the latent is carried, detached, between supervision steps). ``--grad_mode`` controls
the recursion of the LAST supervision step:

    full  : backprop through all T inner recursions (TRM)
    1step : backprop only through the last inner recursion; the first T-1 are
            detached (HRM 1-step gradient)

Optional per-step Gaussian noise (``--sigma``) plus K stochastic rollouts selected by
a value head (best-Q@K) for the PTRM experiments. Saves a checkpoint for the LE
diagnostic (diag/lyap.py).

Run:
    PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1
"""
import argparse
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_add_pool

from diag.train_cycle import prepare

# Output directory for the JSON reports and checkpoints written by main().
OUT = '/home/yurenh2/rrog/runs'


def loader(recs, bs, shuffle, drop_last=False):
    """Wrap prepared records in a PyG ``DataLoader``.

    Each record provides ``'x'`` (integer atom-type indices, later fed to an
    ``nn.Embedding``; the node count is taken as ``r['x'].numel()``),
    ``'edge_index'`` and ``'y'``. ``y`` is reshaped to ``(1, 2)`` so that batching
    stacks targets into ``(num_graphs, 2)``.
    """
    data = [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2),
                 num_nodes=r['x'].numel())
            for r in recs]
    return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last)


class RecGIN(nn.Module):
    """Weight-tied recurrent GIN with deep supervision and a scalar value head.

    Args:
        n_atom: size of the atom-type vocabulary of the input embedding.
        hidden: latent width.
        T: inner recursions per supervision step.
        n_sup: number of supervision steps; one prediction is emitted per step.
        sigma: std of the Gaussian noise added after every inner recursion when
            ``forward(..., noise=True)`` is used; values <= 0 disable noise.
        inner: number of GINConv + BatchNorm layers in the shared block.
        grad_mode: ``'1step'`` truncates backprop of the final supervision step to
            its last inner recursion; any other value backprops through all T.
    """

    def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, grad_mode='full'):
        super().__init__()
        self.emb = nn.Embedding(n_atom, hidden)
        self.convs = nn.ModuleList([
            GINConv(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(),
                                  nn.Linear(hidden, hidden)),
                    train_eps=True)
            for _ in range(inner)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)])
        # Graph-level regression head (2 targets) and value head (1 score per graph).
        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
        self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode

    def block(self, z, ei):
        """Apply the shared stack of GINConv -> BatchNorm -> ReLU layers."""
        for conv, bn in zip(self.convs, self.bns):
            z = bn(conv(z, ei)).relu()
        return z

    def _inner(self, z, h0, ei, noise):
        """One inner recursion: re-inject the input embedding, run the block, maybe add noise."""
        z = self.block(z + h0, ei)
        if noise and self.sigma > 0:
            z = z + self.sigma * torch.randn_like(z)
        return z

    def recurse(self, z, h0, ei, noise, one_step=False):
        """Run T inner recursions starting from latent ``z``.

        With ``one_step=True`` the first T-1 recursions run without autograd and only
        the last one is differentiable (HRM 1-step gradient); otherwise gradients flow
        through all T recursions (TRM full recursion).
        """
        if one_step:  # HRM 1-step gradient
            with torch.no_grad():
                for _ in range(self.T - 1):
                    z = self._inner(z, h0, ei, noise)
            z = z.detach()
            return self._inner(z, h0, ei, noise)  # only the last inner step carries grad
        for _ in range(self.T):  # TRM full recursion
            z = self._inner(z, h0, ei, noise)
        return z

    def forward(self, x, ei, batch, noise=False):
        """Return ``(preds, q)``.

        ``preds`` is a list of ``n_sup`` tensors of shape ``(num_graphs, 2)``, one per
        supervision step. ``q`` has shape ``(num_graphs,)`` and is computed from the
        final latent.

        Only the last supervision step's recursion is differentiable; earlier steps
        run under ``no_grad`` and their latent is detached. The heads are still
        applied with autograd enabled, so the head is trained by every step.
        """
        h0 = self.emb(x)
        z = torch.zeros_like(h0)
        preds = []
        for s in range(self.n_sup):
            if s < self.n_sup - 1:
                with torch.no_grad():
                    z = self.recurse(z, h0, ei, noise)
                z = z.detach()
            else:
                z = self.recurse(z, h0, ei, noise, one_step=(self.grad_mode == '1step'))
            preds.append(self.head(global_add_pool(z, batch)))
        q = self.qhead(global_add_pool(z, batch)).view(-1)
        return preds, q


@torch.no_grad()
def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'):
    """Compute per-target MAE in raw (de-standardized) units over a loader.

    With ``K == 1`` a single forward pass is used (noisy iff ``model.sigma > 0``) and
    the oracle equals the prediction. With ``K > 1``, K noisy rollouts are drawn. The
    reported prediction is then either the rollout with the highest value-head score
    (``select='bestq'``) or the rollout mean (any other ``select``). The oracle picks,
    per graph, the rollout with the smallest summed standardized L1 error.

    Returns:
        ``(mae, oracle_mae)``: two-element lists of floats, one per target column.

    Note: leaves ``model`` in eval mode.
    """
    model.eval()
    ysd_d, ymu_d = ysd.to(dev), ymu.to(dev)
    ae = torch.zeros(2)
    ae_or = torch.zeros(2)
    n = 0
    for b in ld:
        b = b.to(dev)
        if K == 1:
            preds, _ = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
            chosen = oracle = preds[-1]
        else:
            P, Q = [], []
            for _ in range(K):
                preds, q = model(b.x, b.edge_index, b.batch, noise=True)
                P.append(preds[-1])
                Q.append(q)
            P = torch.stack(P)  # (K, num_graphs, 2)
            Q = torch.stack(Q)  # (K, num_graphs)
            ar = torch.arange(P.size(1), device=dev)
            chosen = P[Q.argmax(0), ar] if select == 'bestq' else P.mean(0)
            oracle = P[(P - b.y.unsqueeze(0)).abs().sum(-1).argmin(0), ar]
        ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
        ae_or += ((oracle * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
        n += b.num_graphs
    return (ae / n).tolist(), (ae_or / n).tolist()


def _parse_args():
    """Parse command-line arguments and reject values that cannot train sensibly."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
    ap.add_argument('--sigma', type=float, default=0.0)
    ap.add_argument('--K', type=int, default=1)
    ap.add_argument('--select', choices=['none', 'bestq'], default='bestq')
    ap.add_argument('--T', type=int, default=3)
    ap.add_argument('--n_sup', type=int, default=3)
    ap.add_argument('--hidden', type=int, default=128)
    ap.add_argument('--epochs', type=int, default=200)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--bs', type=int, default=128)
    ap.add_argument('--lam_q', type=float, default=1.0)
    ap.add_argument('--seed', type=int, default=0)
    args = ap.parse_args()
    # K=0 crashes in evaluate (empty rollout stack) and n_sup=0 crashes in the loss
    # (empty prediction list). T=0 silently differs between grad modes, because
    # '1step' still runs one inner recursion.
    for name in ('K', 'T', 'n_sup'):
        if getattr(args, name) < 1:
            ap.error(f"--{name} must be >= 1")
    return args


def _round3(vals):
    """Round a list of floats to 3 decimals for logging; pass ``None`` through."""
    return None if vals is None else [round(x, 3) for x in vals]


def main():
    """Train RecGIN, track the best validation checkpoint, write a JSON report and checkpoint."""
    args = _parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(OUT, exist_ok=True)

    tr, va, te = prepare('train'), prepare('val'), prepare('test')
    n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
    # Standardize targets with train-split statistics (mutates the records in place).
    # evaluate() maps predictions back to raw units with the same ymu/ysd.
    Ytr = torch.stack([r['y'] for r in tr])
    ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
    for recs in (tr, va, te):
        for r in recs:
            r['y'] = (r['y'] - ymu) / ysd
    trl = loader(tr, args.bs, True, drop_last=True)
    val, tel = loader(va, 256, False), loader(te, 256, False)

    model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma,
                   grad_mode=args.grad_mode).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    l1 = nn.L1Loss()

    t0 = time.time()
    best_val = float('inf')
    best = {}
    best_state = None
    for ep in range(args.epochs):
        model.train()  # evaluate() leaves the model in eval mode
        for b in trl:
            b = b.to(dev)
            opt.zero_grad()
            preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
            # Deep supervision: average the L1 loss over every supervision step.
            loss = sum(l1(p, b.y) for p in preds) / len(preds)
            # Value-head target: negative per-graph mean abs error of the final
            # prediction, computed without gradient.
            with torch.no_grad():
                tq = -(preds[-1] - b.y).abs().mean(1)
            loss = loss + args.lam_q * nn.functional.mse_loss(q, tq)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select)
            # A NaN validation score never compares smaller, so it is never selected.
            if sum(vm) < best_val:
                best_val = sum(vm)
                tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select)
                best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo}
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            print(f"ep{ep+1} val_mae={_round3(vm)}", flush=True)

    tag = (f"rec_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}"
           f"_T{args.T}_ns{args.n_sup}_s{args.seed}")
    rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args),
           'sec': round(time.time() - t0, 1), 'dev': dev, 'y_std_raw': ysd.tolist(), **best}
    if not best:
        # No validation pass was recorded (no evaluation ran, or every score was
        # NaN). Report it and still save the final weights instead of crashing.
        print(f"[{tag}] WARNING: no best checkpoint recorded; saving final weights", flush=True)
    print(f"[{tag}] test_mae={_round3(best.get('test_mae'))} "
          f"oracle@K={_round3(best.get('test_mae_oracle'))} @ep{best.get('ep')} ({rep['sec']}s)")
    with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
        json.dump(rep, f, indent=2)
    ckpt_path = os.path.join(OUT, f"ckpt_{tag}.pt")
    torch.save({'state': best_state if best_state is not None else model.state_dict(),
                'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T,
                        'n_sup': args.n_sup, 'sigma': args.sigma,
                        'grad_mode': args.grad_mode},
                'ymu': ymu, 'ysd': ysd},
               ckpt_path)
    print(" wrote", ckpt_path)


if __name__ == "__main__":
    main()