"""Oracle exact-equilibrium-adjoint training from a redx pre-drift checkpoint. This deliberately trains the equilibrium objective L(z*) with an exact matrix-free implicit adjoint: F_z(z*)^T lambda = -L_z(z*) dL/dtheta = L_theta + lambda^T F_theta Block-parameter adjoints reuse asym_probe.py. The readout head keeps the same local dCE/dWh path used by lt_ep_train.py. """ import argparse import math import os import pickle import time from pathlib import Path from types import SimpleNamespace import numpy as np import torch import lt_ep_train as L from asym_probe import ( Operators, block_param_list, ce_state_grad, cos, exact_transpose_grad, flat_grad_by_param_id, norm, set_param_requires_grad, solve_exact_adjoint, ) from lt_ep_train import EQBlock, ce, relax def parse_args(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt", default="runs/redx_traj/s2000.pt") ap.add_argument("--data", default="data/tinystories_bpe") ap.add_argument("--log-file", default="runs/oracle_adjoint.log") ap.add_argument("--save", default="runs/oracle_adjoint.pt") ap.add_argument("--device", default="cuda", choices=["cuda", "cpu"]) ap.add_argument("--steps", type=int, default=1500) ap.add_argument("--B", type=int, default=24) ap.add_argument("--T", type=int, default=256) ap.add_argument("--C", type=int, default=512) ap.add_argument("--H", type=int, default=16) ap.add_argument("--Mm", type=int, default=256) ap.add_argument("--T1", type=int, default=150) ap.add_argument("--eps", type=float, default=0.1) ap.add_argument("--lr", type=float, default=6e-4) ap.add_argument("--wd", type=float, default=1e-4) ap.add_argument("--wsd", type=float, default=0.2) ap.add_argument("--warmup", type=int, default=0) ap.add_argument("--log-every", type=int, default=50) ap.add_argument("--eval-batches", type=int, default=8) ap.add_argument("--eval-B", type=int, default=32) ap.add_argument("--rho-B", type=int, default=8) ap.add_argument("--rho-steps", type=int, default=800) ap.add_argument("--res-est", type=float, default=1e-5) ap.add_argument("--t1max", type=int, default=6000) ap.add_argument("--relax-chunk", type=int, default=50) ap.add_argument("--abort-res", type=float, default=0.3) ap.add_argument("--grad-clip", type=float, default=5.0) ap.add_argument("--capx", type=float, default=3.0) ap.add_argument("--seed", type=int, default=0) ap.add_argument("--adjoint-iters", type=int, default=200) ap.add_argument("--adjoint-tol", type=float, default=1e-5) ap.add_argument("--adjoint-mu", type=float, default=1e-4) ap.add_argument("--solve-iters", type=int, default=80) ap.add_argument("--solve-tol", type=float, default=1e-5) ap.add_argument("--sanity-cos-min", type=float, default=0.999) ap.add_argument("--tf32", action="store_true") return ap.parse_args() def require_cuda(device): if device != "cuda": return if torch.cuda.is_available() and torch.cuda.device_count() > 0: torch.cuda.set_device(0) return raise SystemExit( "ERROR: CUDA unavailable for requested GPU0 run; " f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r}" ) def resolve_path(path): p = Path(path) if p.is_absolute(): return p return Path.cwd() / p def configure_globals(cfg, dev): L.dev = dev L.DD = resolve_path(cfg.data) L.vocab = pickle.load(open(L.DD / "meta.pkl", "rb"))["vocab_size"] torch.backends.cuda.matmul.allow_tf32 = bool(cfg.tf32) torch.backends.cudnn.allow_tf32 = bool(cfg.tf32) def build_block(cfg, dev): torch.manual_seed(cfg.seed) blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=1.0, attn_mode="thick") for w in blk.capw: blk.caps[id(w)] = w.detach().norm().item() * cfg.capx blk.qknorm = True blk.fnoise = 0.0 blk._cstep = None blk.navg = 1 blk.li_avg = 0 blk.track = True blk.nbrake = 0.0 ckpt_path = resolve_path(cfg.ckpt) ck = torch.load(ckpt_path, map_location=dev) with torch.no_grad(): for p, w in zip(blk.allp, ck["allp"]): p.copy_(w.to(dev)) return blk, ck, ckpt_path @torch.no_grad() def one_step_residual(blk, z, xin, eps): z1 = relax(blk, z, xin, 1, eps) return (z1 - z).norm().item() / (z.norm().item() + 1e-12) def relax_refine(blk, xin, cfg): z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps) finite_t1_res = one_step_residual(blk, z, xin, cfg.eps) res = finite_t1_res steps = cfg.T1 while steps < cfg.t1max and res > cfg.res_est: chunk = min(cfg.relax_chunk, cfg.t1max - steps) z = relax(blk, z, xin, chunk, cfg.eps) steps += chunk res = one_step_residual(blk, z, xin, cfg.eps) if not math.isfinite(res): break return z.detach(), finite_t1_res, res, steps def oracle_grad(blk, idx, y, cfg): xin0 = blk.embed(idx).detach() zstar, finite_t1_res, zstar_res, relax_steps = relax_refine(blk, xin0, cfg) set_param_requires_grad(blk, False) op = Operators(blk, zstar, xin0, cfg, mu=0.0) ell, loss_zstar = ce_state_grad(blk, zstar, y) lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg) if adj_mu != 0.0 or gmres_info != 0 or (not math.isfinite(gmres_rel)) or gmres_rel > max(10.0 * cfg.adjoint_tol, 1e-4): set_param_requires_grad(blk, True) raise RuntimeError( "exact adjoint GMRES failed " f"(rel={gmres_rel:.3e}, info={gmres_info}, iters={gmres_iters}, tikhonov_mu={adj_mu:.3e})" ) params = block_param_list(blk) block_grads = exact_transpose_grad(blk, idx, zstar, xin0, lam, params) grads = dict(block_grads) with torch.enable_grad(): (gh,) = torch.autograd.grad(ce(blk, zstar.detach(), y), blk.Wh) grads[id(blk.Wh)] = gh set_param_requires_grad(blk, True) block_flat = flat_grad_by_param_id(grads, params) gt_flat = flat_grad_by_param_id(block_grads, params) sanity_cos = cos(block_flat, gt_flat) head_norm = norm(gh.detach()).item() block_norm = norm(block_flat).item() return grads, { "loss_zstar": loss_zstar, "finite_t1_res": finite_t1_res, "zstar_res": zstar_res, "relax_steps": relax_steps, "gmres_rel": gmres_rel, "gmres_info": gmres_info, "gmres_iters": gmres_iters, "adj_mu": adj_mu, "sanity_cos": sanity_cos, "block_grad_norm": block_norm, "head_grad_norm": head_norm, } def make_optimizer_and_sched(blk, cfg): opt = torch.optim.AdamW(blk.allp, lr=cfg.lr, weight_decay=cfg.wd) def lr_lambda(step): if cfg.warmup > 0 and step < cfg.warmup: return (step + 1) / cfg.warmup decay_start = int((1.0 - cfg.wsd) * cfg.steps) if cfg.wsd > 0 else cfg.warmup if step < decay_start: return 1.0 p = (step - decay_start) / max(1, cfg.steps - decay_start) return 0.05 + 0.475 * (1.0 + math.cos(math.pi * min(1.0, p))) sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda) return opt, sched @torch.no_grad() def apply_weight_caps(blk): for p in blk.capw: pn = p.norm() cap = blk.caps[id(p)] if pn > cap: p.mul_(cap / pn) @torch.no_grad() def evaluate_ce(blk, cfg): total = 0.0 for _ in range(cfg.eval_batches): idx, y = L.get_batch("val", cfg.eval_B, cfg.T) xin = blk.embed(idx).detach() z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps) total += ce(blk, z, y).item() return total / max(1, cfg.eval_batches) @torch.no_grad() def finite_residual_on_batch(blk, idx, cfg): xin = blk.embed(idx).detach() z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps) return one_step_residual(blk, z, xin, cfg.eps) @torch.no_grad() def rho_decay_probe(blk, idx, cfg): xin = blk.embed(idx).detach() z = xin.clone() residuals = [] for _ in range(cfg.rho_steps): z2 = z + cfg.eps * blk.force(z, xin).detach() r = (z2 - z).norm().item() / (z.norm().item() + 1e-12) residuals.append(r) z = z2 if (not math.isfinite(r)) or r > 1e2: break window = [r for r in residuals if 1e-6 < r < 1e-1] or residuals[-200:] ratios = [window[i + 1] / window[i] for i in range(len(window) - 1) if window[i] > 0 and window[i + 1] > 0] rho = math.exp(sum(math.log(x) for x in ratios) / len(ratios)) if ratios else float("nan") return rho, residuals[-1] if residuals else float("nan"), len(residuals) def log_line(path, line): with open(path, "a", encoding="utf-8") as f: f.write(line + "\n") print(line, flush=True) def track(blk, cfg, fixed_idx, step, info, t0): val = evaluate_ce(blk, cfg) val_res = finite_residual_on_batch(blk, fixed_idx, cfg) rho, rho_final, rho_n = rho_decay_probe(blk, fixed_idx, cfg) lr = info.get("lr", float("nan")) line = ( f"step {step:4d}/{cfg.steps} | val CE {val:.4f} | finite_T1_res {val_res:.3e} " f"| rho800 {rho:.4f} final_res {rho_final:.2e} n={rho_n} " f"| train_res {info.get('finite_t1_res', float('nan')):.3e} " f"| zstar_res {info.get('zstar_res', float('nan')):.3e} relax {info.get('relax_steps', -1)} " f"| gmres {info.get('gmres_rel', float('nan')):.2e}/{info.get('gmres_iters', -1)} " f"| lr {lr:.3e} | {max(step, 1) / max(time.time() - t0, 1e-9):.4f} it/s" ) log_line(cfg.log_file, line) return val, val_res, rho def save_ckpt(blk, cfg, step, best): if not cfg.save: return path = resolve_path(cfg.save) path.parent.mkdir(parents=True, exist_ok=True) torch.save( {"allp": [p.detach().cpu() for p in blk.allp], "step": step, "best": best}, str(path) + ".tmp", ) os.replace(str(path) + ".tmp", path) def main(): cfg = parse_args() require_cuda(cfg.device) dev = torch.device("cuda:0" if cfg.device == "cuda" else "cpu") configure_globals(cfg, dev) cfg.log_file = str(resolve_path(cfg.log_file)) Path(cfg.log_file).parent.mkdir(parents=True, exist_ok=True) cfg_for_op = SimpleNamespace(**vars(cfg)) blk, ck, ckpt_path = build_block(cfg_for_op, dev) opt, sched = make_optimizer_and_sched(blk, cfg_for_op) torch.manual_seed(1234) fixed_idx, _ = L.get_batch("val", cfg.rho_B, cfg.T) torch.manual_seed(cfg.seed + 1) header = ( f"# oracle_adjoint_train device={dev} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r} " f"ckpt={ckpt_path} ckpt_step={ck.get('step')} ckpt_best={ck.get('best')} " f"B={cfg.B} T={cfg.T} C={cfg.C} H={cfg.H} Mm={cfg.Mm} T1={cfg.T1} " f"lr={cfg.lr} wd={cfg.wd} wsd={cfg.wsd} res_est={cfg.res_est} t1max={cfg.t1max}" ) with open(cfg.log_file, "w", encoding="utf-8") as f: f.write(header + "\n") print(header, flush=True) t0 = time.time() idx0, y0 = L.get_batch("train", cfg.B, cfg.T) grads0, info0 = oracle_grad(blk, idx0, y0, cfg_for_op) sanity = info0["sanity_cos"] sanity_line = ( f"step 0 sanity: cos(oracle_block_grad, asym_probe g_transpose)={sanity:+.6f} " f"gmres_rel={info0['gmres_rel']:.3e} gmres_iters={info0['gmres_iters']} " f"zstar_res={info0['zstar_res']:.3e} finite_T1_res={info0['finite_t1_res']:.3e}" ) log_line(cfg.log_file, sanity_line) if (not math.isfinite(sanity)) or sanity < cfg.sanity_cos_min: bug = f"STOP: step-0 oracle/asym_probe sanity cosine {sanity:+.6f} < {cfg.sanity_cos_min:.6f}" log_line(cfg.log_file, bug) raise SystemExit(3) info0["lr"] = sched.get_last_lr()[0] best, _, _ = track(blk, cfg_for_op, fixed_idx, 0, info0, t0) save_ckpt(blk, cfg_for_op, 0, best) for step in range(1, cfg.steps + 1): idx, y = L.get_batch("train", cfg.B, cfg.T) try: grads, info = oracle_grad(blk, idx, y, cfg_for_op) except RuntimeError as err: log_line(cfg.log_file, f"ABORT step {step}: {err}") break if info["finite_t1_res"] > cfg.abort_res: log_line( cfg.log_file, f"ABORT step {step}: finite_T1_res {info['finite_t1_res']:.3e} > {cfg.abort_res:.3e}", ) break if not all((g is None) or torch.isfinite(g).all() for g in grads.values()): log_line(cfg.log_file, f"ABORT step {step}: non-finite oracle gradient") break opt.zero_grad(set_to_none=True) for p in blk.allp: p.grad = grads.get(id(p)) torch.nn.utils.clip_grad_norm_(blk.allp, cfg.grad_clip) opt.step() apply_weight_caps(blk) sched.step() if step % cfg.log_every == 0: info["lr"] = sched.get_last_lr()[0] val, _, _ = track(blk, cfg_for_op, fixed_idx, step, info, t0) best = min(best, val) save_ckpt(blk, cfg_for_op, step, best) save_ckpt(blk, cfg_for_op, step if cfg.steps > 0 else 0, best) log_line(cfg.log_file, f"DONE best_val_CE={best:.4f}") if __name__ == "__main__": main()