diff options
Diffstat (limited to 'diag/train_rec.py')
| -rw-r--r-- | diag/train_rec.py | 491 |
1 files changed, 491 insertions, 0 deletions
diff --git a/diag/train_rec.py b/diag/train_rec.py new file mode 100644 index 0000000..9db7eb1 --- /dev/null +++ b/diag/train_rec.py @@ -0,0 +1,491 @@ +"""Step-2: RRoG/TRM-on-GNN for ZINC ring-counting. + +The graph is encoded once with a GIN encoder. A shared edge-free node-wise compute block then +refines hidden state over n_sup*T recurrent steps (TRM-style: carry latent detached between +deep-supervision steps). --grad_mode controls the LAST supervision step's recursion: + full : backprop through all T inner recursions (TRM) + 1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient) + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1 +""" +import argparse, json, os, time +import numpy as np +import torch +import torch.nn as nn +from torch_geometric.loader import DataLoader +from torch_geometric.data import Batch, Data +from torch_geometric.nn import ( + APPNP, + ARMAConv, + ChebConv, + FiLMConv, + GATv2Conv, + GCNConv, + GENConv, + GINEConv, + GINConv, + GraphConv, + MFConv, + PNAConv, + ResGatedGraphConv, + SAGEConv, + SGConv, + TAGConv, + TransformerConv, + global_add_pool, +) +from torch_geometric.utils import degree +from diag.train_cycle import prepare + +PROJECT_ROOT = os.environ.get( + 'RROG_ROOT', + os.path.abspath(os.path.join(os.path.dirname(__file__), '..')), +) +OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs')) +SUPPORTED_VIEWS = [ + 'gin', 'gine', 'gcn', 'graphsage', 'gatv2', 'graphconv', 'transformer', 'pna', + 'gen', 'film', 'resgated', 'tag', 'sgc', 'cheb', 'arma', 'mf', 'appnp', +] + + +def data_list(recs): + return [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2), + num_nodes=r['x'].numel()) for r in recs] + + +def loader(recs, bs, shuffle, drop_last=False): + data = recs if recs and isinstance(recs[0], Data) else data_list(recs) + return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last) + + +def degree_histogram(data): + max_degree = 0 + degs = [] + for graph in data: + deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long) + degs.append(deg) + if deg.numel(): + max_degree = max(max_degree, int(deg.max().item())) + hist = torch.zeros(max_degree + 1, dtype=torch.long) + for deg in degs: + hist += torch.bincount(deg, minlength=hist.numel()) + return hist + + +def make_view_layer(view, hidden, deg): + if view == 'gin': + return GINConv(nn.Sequential( + nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True) + if view == 'gine': + return GINEConv(nn.Sequential( + nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), + train_eps=True, edge_dim=hidden) + if view == 'gcn': + return GCNConv(hidden, hidden) + if view == 'graphsage': + return SAGEConv(hidden, hidden) + if view == 'gatv2': + return GATv2Conv(hidden, hidden, heads=4, concat=False) + if view == 'graphconv': + return GraphConv(hidden, hidden) + if view == 'transformer': + return TransformerConv(hidden, hidden, heads=4, concat=False) + if view == 'pna': + if deg is None: + raise ValueError('PNA view requires a training-set degree histogram') + return PNAConv( + hidden, hidden, + aggregators=['mean', 'min', 'max', 'std'], + scalers=['identity', 'amplification', 'attenuation'], + deg=deg, + ) + if view == 'gen': + return GENConv(hidden, hidden) + if view == 'film': + return FiLMConv(hidden, hidden) + if view == 'resgated': + return ResGatedGraphConv(hidden, hidden) + if view == 'tag': + return TAGConv(hidden, hidden, K=3) + if view == 'sgc': + return SGConv(hidden, hidden, K=2, cached=False) + if view == 'cheb': + return ChebConv(hidden, hidden, K=3) + if view == 'arma': + return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2) + if view == 'mf': + return MFConv(hidden, hidden) + if view == 'appnp': + return APPNP(K=5, alpha=0.1) + raise ValueError(f'unsupported view: {view}') + + +class RecGIN(nn.Module): + def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, + grad_mode='full', agg_layers=5, compute_layers=None, view='gin', deg=None): + super().__init__() + self.view = view + self.agg_layers = agg_layers + self.compute_layers = compute_layers or inner + self.emb = nn.Embedding(n_atom, hidden) + self.edge_emb = nn.Embedding(1, hidden) if view == 'gine' else None + self.agg_convs = nn.ModuleList() + for _ in range(agg_layers): + self.agg_convs.append(make_view_layer(view, hidden, deg)) + self.agg_bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(agg_layers)]) + core = [] + d = hidden + for _ in range(self.compute_layers - 1): + core += [nn.Linear(d, hidden), nn.GELU()] + d = hidden + core.append(nn.Linear(d, hidden)) + self.core_norm = nn.LayerNorm(hidden) + self.core = nn.Sequential(*core) + nn.init.zeros_(self.core[-1].weight) + nn.init.zeros_(self.core[-1].bias) + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2)) + self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1)) + with torch.no_grad(): + self.qhead[-1].weight.zero_() + self.qhead[-1].bias.fill_(-5.0) + self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode + + def aggregate(self, x, ei): + h = self.emb(x) + for conv, bn in zip(self.agg_convs, self.agg_bns): + if self.view == 'gine': + edge_attr = self.edge_emb(torch.zeros(ei.size(1), dtype=torch.long, device=ei.device)) + h = bn(conv(h, ei, edge_attr)).relu() + else: + h = bn(conv(h, ei)).relu() + return h + + def core_step(self, combined, state): + """Shared TRM compute core. Deliberately edge-free.""" + return state + self.core(self.core_norm(combined)) + + def _z_step(self, y, z, ctx, noise): + z = self.core_step(ctx + y + z, z) + if noise and self.sigma > 0: + z = z + self.sigma * torch.randn_like(z) + return z + + def _y_step(self, y, z, noise): + y = self.core_step(y + z, y) + if noise and self.sigma > 0: + y = y + self.sigma * torch.randn_like(y) + return y + + def recurse(self, y, z, ctx, noise, one_step=False): + if self.T == 0: + return y, z + if one_step: # HRM 1-step gradient + with torch.no_grad(): + for _ in range(self.T - 1): + z = self._z_step(y, z, ctx, noise) + z = z.detach() + z = self._z_step(y, z, ctx, noise) # only last inner carries grad + y = self._y_step(y, z, noise) + return y, z + for _ in range(self.T): # TRM full recursion + z = self._z_step(y, z, ctx, noise) + y = self._y_step(y, z, noise) + return y, z + + def predict(self, y, batch): + pooled = global_add_pool(y, batch) + return self.head(pooled), self.qhead(pooled).view(-1) + + def forward_trace(self, x, ei, batch, steps, noise=False): + ctx = self.aggregate(x, ei) + y = ctx + z = torch.zeros_like(ctx) + preds, q_logits = [], [] + for s in range(steps): + y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step')) + pred, q = self.predict(y, batch) + preds.append(pred) + q_logits.append(q) + if s < steps - 1: + y, z = y.detach(), z.detach() + return preds, q_logits + + def forward(self, x, ei, batch, noise=False): + ctx = self.aggregate(x, ei) + y = ctx + z = torch.zeros_like(ctx) + preds = [] + for s in range(self.n_sup): + if s < self.n_sup - 1: + with torch.no_grad(): + y, z = self.recurse(y, z, ctx, noise) + y, z = y.detach(), z.detach() + else: + y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step')) + pred, _ = self.predict(y, batch) + preds.append(pred) + _, q = self.predict(y, batch) + return preds, q + + +@torch.no_grad() +def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'): + model.eval() + ysd_d, ymu_d = ysd.to(dev), ymu.to(dev) + ae = torch.zeros(2); ae_or = torch.zeros(2); n = 0 + for b in ld: + b = b.to(dev) + if K == 1: + preds, _ = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0) + chosen = oracle = preds[-1] + else: + P, Q = [], [] + for _ in range(K): + preds, q = model(b.x, b.edge_index, b.batch, noise=True) + P.append(preds[-1]); Q.append(q) + P = torch.stack(P); Q = torch.stack(Q) + ar = torch.arange(P.size(1), device=dev) + chosen = P[Q.argmax(0), ar] if select == 'bestq' else P.mean(0) + oracle = P[(P - b.y.unsqueeze(0)).abs().sum(-1).argmin(0), ar] + ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + ae_or += ((oracle * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + n += b.num_graphs + return (ae / n).tolist(), (ae_or / n).tolist() + + +@torch.no_grad() +def evaluate_trace(model, ld, dev, ymu, ysd, steps, adaptive=False): + model.eval() + ysd_d, ymu_d = ysd.to(dev), ymu.to(dev) + ae = torch.zeros(2) + n = 0 + step_sum = 0.0 + for b in ld: + b = b.to(dev) + preds, q_logits = model.forward_trace(b.x, b.edge_index, b.batch, steps, noise=False) + P = torch.stack(preds, dim=0) + if adaptive: + Q = torch.stack(q_logits, dim=0) + halted = Q > 0 + any_halt = halted.any(dim=0) + first_halt = halted.to(torch.int64).argmax(dim=0) + fallback = torch.full_like(first_halt, steps - 1) + idx = torch.where(any_halt, first_halt, fallback) + chosen = P[idx, torch.arange(P.size(1), device=dev)] + step_sum += (idx.to(torch.float32) + 1).sum().item() + else: + chosen = P[-1] + step_sum += steps * b.num_graphs + ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + n += b.num_graphs + return (ae / n).tolist(), step_sum / max(n, 1) + + +def _split_nodes(t, ptr): + return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)] + + +def act_train_step(model, state, replacement_batch, opt, dev, args): + replacement = replacement_batch.to_data_list() + batch_size = len(replacement) + if state is None: + state = { + 'graphs': [None for _ in range(batch_size)], + 'y': [None for _ in range(batch_size)], + 'z': [None for _ in range(batch_size)], + 'steps': torch.zeros(batch_size, dtype=torch.long, device=dev), + 'halted': torch.ones(batch_size, dtype=torch.bool, device=dev), + } + + halted_cpu = state['halted'].detach().cpu().tolist() + for i, halted in enumerate(halted_cpu): + if halted: + state['graphs'][i] = replacement[i] + + b = Batch.from_data_list(state['graphs']).to(dev) + ctx = model.aggregate(b.x, b.edge_index) + ptr = b.ptr + y_parts, z_parts = [], [] + for i in range(batch_size): + start, end = ptr[i].item(), ptr[i + 1].item() + if halted_cpu[i] or state['y'][i] is None: + y_parts.append(ctx[start:end]) + z_parts.append(torch.zeros_like(ctx[start:end])) + else: + y_parts.append(state['y'][i].to(dev)) + z_parts.append(state['z'][i].to(dev)) + y = torch.cat(y_parts, dim=0) + z = torch.cat(z_parts, dim=0) + + opt.zero_grad() + y, z = model.recurse(y, z, ctx, noise=False, one_step=(model.grad_mode == '1step')) + pred, q = model.predict(y, b.batch) + per_graph_err = (pred - b.y).abs().mean(1) + pred_loss = per_graph_err.mean() + with torch.no_grad(): + if args.halt_target == 'binary': + halt_target = (per_graph_err <= args.halt_norm_threshold).to(q.dtype) + else: + halt_target = torch.sigmoid((args.halt_norm_threshold - per_graph_err) / args.halt_temp) + q_loss = nn.functional.binary_cross_entropy_with_logits(q, halt_target) + loss = pred_loss + 0.5 * args.lam_q * q_loss + y_det, z_det = y.detach(), z.detach() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + + state['y'] = _split_nodes(y_det, ptr) + state['z'] = _split_nodes(z_det, ptr) + with torch.no_grad(): + was_halted = state['halted'] + steps = torch.where(was_halted, torch.zeros_like(state['steps']), state['steps']) + 1 + halted = (steps >= args.halt_max_steps) | (q.detach() > 0) + if args.halt_exploration_prob > 0 and args.halt_max_steps > 1: + explore = torch.rand_like(q) < args.halt_exploration_prob + min_steps = torch.where( + explore, + torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev), + torch.zeros_like(steps), + ) + halted = halted & (steps >= min_steps) + state['steps'] = steps + state['halted'] = halted + + return state, { + 'loss': float(loss.detach().cpu()), + 'pred_loss': float(pred_loss.detach().cpu()), + 'q_loss': float(q_loss.detach().cpu()), + 'halted_frac': float(state['halted'].to(torch.float32).mean().detach().cpu()), + 'steps': float(state['steps'].to(torch.float32).mean().detach().cpu()), + } + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') + ap.add_argument('--sigma', type=float, default=0.0) + ap.add_argument('--K', type=int, default=1) + ap.add_argument('--select', choices=['none', 'bestq'], default='bestq') + ap.add_argument('--T', type=int, default=3) + ap.add_argument('--n_sup', type=int, default=3) + ap.add_argument('--hidden', type=int, default=128) + ap.add_argument('--agg_layers', type=int, default=5) + ap.add_argument('--compute_layers', type=int, default=2) + ap.add_argument('--view', choices=SUPPORTED_VIEWS, default='gin') + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--bs', type=int, default=128) + ap.add_argument('--lam_q', type=float, default=1.0) + ap.add_argument('--act', action='store_true', + help='train all recurrent depths up to halt_max_steps and train qhead as a halt head') + ap.add_argument('--halt_max_steps', type=int, default=8) + ap.add_argument('--halt_norm_threshold', type=float, default=0.30) + ap.add_argument('--halt_temp', type=float, default=0.10) + ap.add_argument('--halt_target', choices=['soft', 'binary'], default='soft') + ap.add_argument('--halt_exploration_prob', type=float, default=0.1) + ap.add_argument('--loss_mode', choices=['last', 'trace'], default='trace') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--device', default='auto') + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if args.device == 'auto' and torch.cuda.is_available() else ( + 'cpu' if args.device == 'auto' else args.device) + os.makedirs(OUT, exist_ok=True) + + tr, va, te = prepare('train'), prepare('val'), prepare('test') + n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1 + Ytr = torch.stack([r['y'] for r in tr]); ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8 + for recs in (tr, va, te): + for r in recs: + r['y'] = (r['y'] - ymu) / ysd + train_data = data_list(tr) + trl = loader(train_data, args.bs, True, drop_last=True) + val, tel = loader(va, 256, False), loader(te, 256, False) + + deg = degree_histogram(train_data) if args.view == 'pna' else None + model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode, + agg_layers=args.agg_layers, compute_layers=args.compute_layers, + view=args.view, deg=deg).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + l1 = nn.L1Loss() + act_steps = max(1, args.halt_max_steps) + + t0 = time.time(); best_val = 9e9; best = {}; best_state = None; act_state = None + for ep in range(args.epochs): + model.train() + act_metrics = [] + for b in trl: + if args.act: + act_state, metrics = act_train_step(model, act_state, b, opt, dev, args) + act_metrics.append(metrics) + else: + b = b.to(dev); opt.zero_grad() + if args.loss_mode == 'trace': + preds, q_logits = model.forward_trace( + b.x, b.edge_index, b.batch, args.n_sup, noise=model.sigma > 0) + q = q_logits[-1] + else: + preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0) + loss = sum(l1(p, b.y) for p in preds) / len(preds) + with torch.no_grad(): + tq = -(preds[-1] - b.y).abs().mean(1) + loss = loss + args.lam_q * nn.functional.mse_loss(q, tq) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + if args.act: + vm, _ = evaluate_trace(model, val, dev, ymu, ysd, act_steps, adaptive=False) + else: + vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select) + if sum(vm) < best_val: + best_val = sum(vm) + if args.act: + tem, fixed_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=False) + tea, adaptive_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=True) + best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, + 'test_mae_adaptive': tea, 'fixed_steps': fixed_steps, + 'adaptive_steps': adaptive_steps} + else: + tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select) + best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo} + best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + if args.act and act_metrics: + hm = sum(m['halted_frac'] for m in act_metrics) / len(act_metrics) + sm = sum(m['steps'] for m in act_metrics) / len(act_metrics) + print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]} halt={hm:.2f} train_steps={sm:.2f}", flush=True) + else: + print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True) + + act_tag = f"_actfull{act_steps}_{args.halt_target}{args.halt_norm_threshold:g}_e{args.epochs}" if args.act else "" + loss_tag = f"_{args.loss_mode}" if (not args.act and args.loss_mode != 'last') else "" + view_tag = f"_{args.view}" if args.view != 'gin' else "" + tag = f"rec_rrog{view_tag}_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}{loss_tag}{act_tag}_s{args.seed}" + rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1), + 'dev': dev, 'arch': 'rrog_once_agg_node_compute', 'y_std_raw': ysd.tolist(), **best} + if args.act: + print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} " + f"adaptive={[round(x,3) for x in best.get('test_mae_adaptive')]} " + f"steps={best.get('adaptive_steps'):.2f}/{best.get('fixed_steps'):.2f} " + f"@ep{best.get('ep')} ({rep['sec']}s)") + else: + print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} " + f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)") + with open(os.path.join(OUT, f"{tag}.json"), 'w') as f: + json.dump(rep, f, indent=2) + torch.save({'state': best_state or model.state_dict(), + 'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup, + 'sigma': args.sigma, 'grad_mode': args.grad_mode, + 'agg_layers': args.agg_layers, 'compute_layers': args.compute_layers, + 'view': args.view, + 'loss_mode': args.loss_mode, + 'act': args.act, 'act_impl': 'persistent_recycle' if args.act else 'none', + 'halt_max_steps': act_steps, + 'halt_exploration_prob': args.halt_exploration_prob, + 'arch': 'rrog_once_agg_node_compute'}, + 'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt")) + print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt")) + + +if __name__ == "__main__": + main() |
