diff options
Diffstat (limited to 'ep_run/oracle_adjoint_train.py')
| -rw-r--r-- | ep_run/oracle_adjoint_train.py | 368 |
1 files changed, 368 insertions, 0 deletions
"""Oracle exact-equilibrium-adjoint training from a redx pre-drift checkpoint.

This deliberately trains the equilibrium objective L(z*) with an exact
matrix-free implicit adjoint:

    F_z(z*)^T lambda = -L_z(z*)
    dL/dtheta = L_theta + lambda^T F_theta

Block-parameter adjoints reuse asym_probe.py. The readout head keeps the same
local dCE/dWh path used by lt_ep_train.py.
"""

import argparse
import math
import os
import pickle
import time
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch

import lt_ep_train as L
from asym_probe import (
    Operators,
    block_param_list,
    ce_state_grad,
    cos,
    exact_transpose_grad,
    flat_grad_by_param_id,
    norm,
    set_param_requires_grad,
    solve_exact_adjoint,
)
from lt_ep_train import EQBlock, ce, relax


def parse_args():
    """Build the command-line parser and return the parsed namespace."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="runs/redx_traj/s2000.pt")
    ap.add_argument("--data", default="data/tinystories_bpe")
    ap.add_argument("--log-file", default="runs/oracle_adjoint.log")
    ap.add_argument("--save", default="runs/oracle_adjoint.pt")
    ap.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    ap.add_argument("--steps", type=int, default=1500)
    ap.add_argument("--B", type=int, default=24)
    ap.add_argument("--T", type=int, default=256)
    ap.add_argument("--C", type=int, default=512)
    ap.add_argument("--H", type=int, default=16)
    ap.add_argument("--Mm", type=int, default=256)
    ap.add_argument("--T1", type=int, default=150)
    ap.add_argument("--eps", type=float, default=0.1)
    ap.add_argument("--lr", type=float, default=6e-4)
    ap.add_argument("--wd", type=float, default=1e-4)
    ap.add_argument("--wsd", type=float, default=0.2)
    ap.add_argument("--warmup", type=int, default=0)
    ap.add_argument("--log-every", type=int, default=50)
    ap.add_argument("--eval-batches", type=int, default=8)
    ap.add_argument("--eval-B", type=int, default=32)
    ap.add_argument("--rho-B", type=int, default=8)
    ap.add_argument("--rho-steps", type=int, default=800)
    ap.add_argument("--res-est", type=float, default=1e-5)
    ap.add_argument("--t1max", type=int, default=6000)
    ap.add_argument("--relax-chunk", type=int, default=50)
    ap.add_argument("--abort-res", type=float, default=0.3)
    ap.add_argument("--grad-clip", type=float, default=5.0)
    ap.add_argument("--capx", type=float, default=3.0)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--adjoint-iters", type=int, default=200)
    ap.add_argument("--adjoint-tol", type=float, default=1e-5)
    ap.add_argument("--adjoint-mu", type=float, default=1e-4)
    ap.add_argument("--solve-iters", type=int, default=80)
    ap.add_argument("--solve-tol", type=float, default=1e-5)
    ap.add_argument("--sanity-cos-min", type=float, default=0.999)
    ap.add_argument("--tf32", action="store_true")
    return ap.parse_args()


def require_cuda(device):
    """Select CUDA device 0 when ``device == "cuda"``, or exit if unavailable.

    Does nothing for any other ``device`` value.

    Raises:
        SystemExit: if CUDA was requested but no CUDA device is usable.
    """
    if device != "cuda":
        return
    if torch.cuda.is_available() and torch.cuda.device_count() > 0:
        torch.cuda.set_device(0)
        return
    raise SystemExit(
        "ERROR: CUDA unavailable for requested GPU0 run; "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r}"
    )


def resolve_path(path):
    """Return ``path`` as a ``Path``; relative paths are anchored at the cwd."""
    p = Path(path)
    if p.is_absolute():
        return p
    return Path.cwd() / p


def configure_globals(cfg, dev):
    """Populate the module-level settings of ``lt_ep_train`` and the TF32 flags.

    Sets ``L.dev``, ``L.DD`` (resolved data directory) and ``L.vocab`` (read
    from ``meta.pkl`` in that directory), then enables or disables TF32 for
    CUDA matmul and cuDNN according to ``cfg.tf32``.
    """
    L.dev = dev
    L.DD = resolve_path(cfg.data)
    # NOTE: pickle executes arbitrary code on load; only point --data at a
    # trusted directory.
    with open(L.DD / "meta.pkl", "rb") as f:
        L.vocab = pickle.load(f)["vocab_size"]
    torch.backends.cuda.matmul.allow_tf32 = bool(cfg.tf32)
    torch.backends.cudnn.allow_tf32 = bool(cfg.tf32)


def build_block(cfg, dev):
    """Construct the ``EQBlock`` and copy checkpoint parameters into it.

    Returns:
        ``(blk, ck, ckpt_path)``: the block, the raw loaded checkpoint dict
        and the resolved checkpoint path.
    """
    torch.manual_seed(cfg.seed)
    blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=1.0, attn_mode="thick")
    # The caps are taken from the freshly constructed (seeded) weights, i.e.
    # before the checkpoint values are copied in below.
    for w in blk.capw:
        blk.caps[id(w)] = w.detach().norm().item() * cfg.capx
    blk.qknorm = True
    blk.fnoise = 0.0
    blk._cstep = None
    blk.navg = 1
    blk.li_avg = 0
    blk.track = True
    blk.nbrake = 0.0
    ckpt_path = resolve_path(cfg.ckpt)
    ck = torch.load(ckpt_path, map_location=dev)
    with torch.no_grad():
        # NOTE(review): parameters are matched purely by position, and zip
        # stops silently at the shorter list - confirm the checkpoint layout
        # matches blk.allp.
        for p, w in zip(blk.allp, ck["allp"]):
            p.copy_(w.to(dev))
    return blk, ck, ckpt_path


@torch.no_grad()
def one_step_residual(blk, z, xin, eps):
    """Return ``||relax_1(z) - z|| / ||z||`` for one extra relaxation step."""
    z1 = relax(blk, z, xin, 1, eps)
    return (z1 - z).norm().item() / (z.norm().item() + 1e-12)


def relax_refine(blk, xin, cfg):
    """Relax for ``cfg.T1`` steps, then keep relaxing until the residual is small.

    After the initial ``cfg.T1`` steps the state is relaxed further in chunks
    of at most ``cfg.relax_chunk`` steps until the one-step residual is at
    most ``cfg.res_est``, ``cfg.t1max`` total steps are used, or the residual
    becomes non-finite.

    Returns:
        ``(z, finite_t1_res, res, steps)``: the detached final state, the
        residual after the first ``cfg.T1`` steps, the final residual and the
        total number of relaxation steps.
    """
    z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
    finite_t1_res = one_step_residual(blk, z, xin, cfg.eps)
    res = finite_t1_res
    steps = cfg.T1
    while steps < cfg.t1max and res > cfg.res_est:
        chunk = min(cfg.relax_chunk, cfg.t1max - steps)
        z = relax(blk, z, xin, chunk, cfg.eps)
        steps += chunk
        res = one_step_residual(blk, z, xin, cfg.eps)
        if not math.isfinite(res):
            break
    return z.detach(), finite_t1_res, res, steps


def oracle_grad(blk, idx, y, cfg):
    """Compute the oracle gradient for one batch via the implicit adjoint.

    The state is relaxed with ``relax_refine``; the adjoint system is solved
    with ``solve_exact_adjoint`` and turned into block-parameter gradients by
    ``exact_transpose_grad`` (both from ``asym_probe``); the gradient for
    ``blk.Wh`` comes from autograd on ``ce`` at the detached state.

    Returns:
        ``(grads, info)``: ``grads`` maps ``id(param)`` to a gradient tensor,
        ``info`` is a dict of diagnostics.

    Raises:
        RuntimeError: if the adjoint solve reports a non-zero ``adj_mu`` or
            ``gmres_info``, or a relative residual that is non-finite or above
            ``max(10 * cfg.adjoint_tol, 1e-4)``.
    """
    xin0 = blk.embed(idx).detach()
    zstar, finite_t1_res, zstar_res, relax_steps = relax_refine(blk, xin0, cfg)

    set_param_requires_grad(blk, False)
    try:
        op = Operators(blk, zstar, xin0, cfg, mu=0.0)
        ell, loss_zstar = ce_state_grad(blk, zstar, y)
        lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg)
        if adj_mu != 0.0 or gmres_info != 0 or (not math.isfinite(gmres_rel)) or gmres_rel > max(10.0 * cfg.adjoint_tol, 1e-4):
            raise RuntimeError(
                "exact adjoint GMRES failed "
                f"(rel={gmres_rel:.3e}, info={gmres_info}, iters={gmres_iters}, tikhonov_mu={adj_mu:.3e})"
            )

        params = block_param_list(blk)
        block_grads = exact_transpose_grad(blk, idx, zstar, xin0, lam, params)
        grads = dict(block_grads)
        with torch.enable_grad():
            (gh,) = torch.autograd.grad(ce(blk, zstar.detach(), y), blk.Wh)
        grads[id(blk.Wh)] = gh
    finally:
        # Restore the flag on every exit path: main() catches RuntimeError and
        # keeps using blk afterwards.
        set_param_requires_grad(blk, True)

    # NOTE(review): `grads` is a copy of `block_grads` plus the head entry, so
    # unless `params` contains blk.Wh both vectors below are flattened from the
    # same tensors and the cosine is 1 by construction - confirm that this
    # check compares against an independent reference.
    block_flat = flat_grad_by_param_id(grads, params)
    gt_flat = flat_grad_by_param_id(block_grads, params)
    sanity_cos = cos(block_flat, gt_flat)
    head_norm = norm(gh.detach()).item()
    block_norm = norm(block_flat).item()
    return grads, {
        "loss_zstar": loss_zstar,
        "finite_t1_res": finite_t1_res,
        "zstar_res": zstar_res,
        "relax_steps": relax_steps,
        "gmres_rel": gmres_rel,
        "gmres_info": gmres_info,
        "gmres_iters": gmres_iters,
        "adj_mu": adj_mu,
        "sanity_cos": sanity_cos,
        "block_grad_norm": block_norm,
        "head_grad_norm": head_norm,
    }


def make_optimizer_and_sched(blk, cfg):
    """Create AdamW over ``blk.allp`` and a warmup/stable/decay ``LambdaLR``.

    The multiplier ramps linearly over ``cfg.warmup`` steps, stays at 1.0
    until the decay start, then follows a half cosine from 1.0 down to 0.05
    at ``cfg.steps``. The decay covers the last ``cfg.wsd`` fraction of
    training, or starts right after warmup when ``cfg.wsd`` is not positive.
    """
    opt = torch.optim.AdamW(blk.allp, lr=cfg.lr, weight_decay=cfg.wd)

    def lr_lambda(step):
        if cfg.warmup > 0 and step < cfg.warmup:
            return (step + 1) / cfg.warmup
        decay_start = int((1.0 - cfg.wsd) * cfg.steps) if cfg.wsd > 0 else cfg.warmup
        if step < decay_start:
            return 1.0
        p = (step - decay_start) / max(1, cfg.steps - decay_start)
        return 0.05 + 0.475 * (1.0 + math.cos(math.pi * min(1.0, p)))

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    return opt, sched


@torch.no_grad()
def apply_weight_caps(blk):
    """Rescale in place every tensor in ``blk.capw`` whose norm exceeds its cap."""
    for p in blk.capw:
        pn = p.norm()
        cap = blk.caps[id(p)]
        if pn > cap:
            p.mul_(cap / pn)


@torch.no_grad()
def evaluate_ce(blk, cfg):
    """Return the mean ``ce`` over ``cfg.eval_batches`` validation batches.

    Each batch is relaxed for ``cfg.T1`` steps starting from its embedding.
    """
    total = 0.0
    for _ in range(cfg.eval_batches):
        idx, y = L.get_batch("val", cfg.eval_B, cfg.T)
        xin = blk.embed(idx).detach()
        z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
        total += ce(blk, z, y).item()
    return total / max(1, cfg.eval_batches)


@torch.no_grad()
def finite_residual_on_batch(blk, idx, cfg):
    """Return the one-step residual after ``cfg.T1`` relaxation steps on ``idx``."""
    xin = blk.embed(idx).detach()
    z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
    return one_step_residual(blk, z, xin, cfg.eps)


@torch.no_grad()
def rho_decay_probe(blk, idx, cfg):
    """Estimate the per-step residual contraction ratio on one batch.

    Iterates ``z <- z + eps * force(z, xin)`` for up to ``cfg.rho_steps``
    steps, recording the relative step size, and stops early once it is
    non-finite or above 1e2.

    Returns:
        ``(rho, last_residual, n)``: the geometric mean of successive residual
        ratios (NaN when there are none), the last recorded residual (NaN when
        none were recorded) and the number of residuals recorded.
    """
    xin = blk.embed(idx).detach()
    z = xin.clone()
    residuals = []
    for _ in range(cfg.rho_steps):
        z2 = z + cfg.eps * blk.force(z, xin).detach()
        r = (z2 - z).norm().item() / (z.norm().item() + 1e-12)
        residuals.append(r)
        z = z2
        if (not math.isfinite(r)) or r > 1e2:
            break
    # Prefer residuals strictly inside (1e-6, 1e-1); fall back to the last 200.
    # NOTE(review): the filter is by value, so with a non-monotone residual
    # history neighbouring window entries need not be consecutive iterations.
    window = [r for r in residuals if 1e-6 < r < 1e-1] or residuals[-200:]
    ratios = [nxt / cur for cur, nxt in zip(window, window[1:]) if cur > 0 and nxt > 0]
    rho = math.exp(sum(math.log(x) for x in ratios) / len(ratios)) if ratios else float("nan")
    return rho, residuals[-1] if residuals else float("nan"), len(residuals)


def log_line(path, line):
    """Append ``line`` to the log file at ``path`` and echo it to stdout."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(line + "\n")
    print(line, flush=True)


def track(blk, cfg, fixed_idx, step, info, t0):
    """Run the periodic diagnostics, log one summary line and return metrics.

    Returns:
        ``(val, val_res, rho)``: validation CE, the finite-T1 residual on
        ``fixed_idx`` and the contraction estimate from ``rho_decay_probe``.
    """
    val = evaluate_ce(blk, cfg)
    val_res = finite_residual_on_batch(blk, fixed_idx, cfg)
    rho, rho_final, rho_n = rho_decay_probe(blk, fixed_idx, cfg)
    lr = info.get("lr", float("nan"))
    # NOTE(review): the "rho800" label is fixed text; the probe itself runs
    # cfg.rho_steps iterations.
    line = (
        f"step {step:4d}/{cfg.steps} | val CE {val:.4f} | finite_T1_res {val_res:.3e} "
        f"| rho800 {rho:.4f} final_res {rho_final:.2e} n={rho_n} "
        f"| train_res {info.get('finite_t1_res', float('nan')):.3e} "
        f"| zstar_res {info.get('zstar_res', float('nan')):.3e} relax {info.get('relax_steps', -1)} "
        f"| gmres {info.get('gmres_rel', float('nan')):.2e}/{info.get('gmres_iters', -1)} "
        f"| lr {lr:.3e} | {max(step, 1) / max(time.time() - t0, 1e-9):.4f} it/s"
    )
    log_line(cfg.log_file, line)
    return val, val_res, rho


def save_ckpt(blk, cfg, step, best):
    """Write the block parameters, ``step`` and ``best`` to ``cfg.save``.

    Does nothing when ``cfg.save`` is empty. The file is written to a
    ``.tmp`` sibling first and then moved over the target with ``os.replace``.
    """
    if not cfg.save:
        return
    path = resolve_path(cfg.save)
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {"allp": [p.detach().cpu() for p in blk.allp], "step": step, "best": best},
        str(path) + ".tmp",
    )
    os.replace(str(path) + ".tmp", path)


def main():
    """Entry point: load the checkpoint, sanity-check step 0, then train.

    Exits with status 3 when the step-0 sanity cosine is non-finite or below
    ``--sanity-cos-min``. Training stops early (logging an ABORT line) when the
    oracle gradient raises ``RuntimeError``, the finite-T1 residual exceeds
    ``--abort-res``, or a gradient is non-finite.
    """
    cfg = parse_args()
    require_cuda(cfg.device)
    dev = torch.device("cuda:0" if cfg.device == "cuda" else "cpu")
    configure_globals(cfg, dev)
    cfg.log_file = str(resolve_path(cfg.log_file))
    Path(cfg.log_file).parent.mkdir(parents=True, exist_ok=True)
    cfg_for_op = SimpleNamespace(**vars(cfg))

    blk, ck, ckpt_path = build_block(cfg_for_op, dev)
    opt, sched = make_optimizer_and_sched(blk, cfg_for_op)

    # Fixed seed so the probe batch is drawn identically on every run; the
    # training stream is then re-seeded from cfg.seed.
    torch.manual_seed(1234)
    fixed_idx, _ = L.get_batch("val", cfg.rho_B, cfg.T)
    torch.manual_seed(cfg.seed + 1)

    header = (
        f"# oracle_adjoint_train device={dev} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r} "
        f"ckpt={ckpt_path} ckpt_step={ck.get('step')} ckpt_best={ck.get('best')} "
        f"B={cfg.B} T={cfg.T} C={cfg.C} H={cfg.H} Mm={cfg.Mm} T1={cfg.T1} "
        f"lr={cfg.lr} wd={cfg.wd} wsd={cfg.wsd} res_est={cfg.res_est} t1max={cfg.t1max}"
    )
    # The checkpoint dict still holds a full copy of the parameters on `dev`;
    # only its metadata was needed, so release it before training.
    del ck
    with open(cfg.log_file, "w", encoding="utf-8") as f:
        f.write(header + "\n")
    print(header, flush=True)

    t0 = time.time()
    idx0, y0 = L.get_batch("train", cfg.B, cfg.T)
    # Only the diagnostics are needed at step 0; do not keep the gradient
    # tensors referenced for the rest of the run.
    info0 = oracle_grad(blk, idx0, y0, cfg_for_op)[1]
    sanity = info0["sanity_cos"]
    sanity_line = (
        f"step 0 sanity: cos(oracle_block_grad, asym_probe g_transpose)={sanity:+.6f} "
        f"gmres_rel={info0['gmres_rel']:.3e} gmres_iters={info0['gmres_iters']} "
        f"zstar_res={info0['zstar_res']:.3e} finite_T1_res={info0['finite_t1_res']:.3e}"
    )
    log_line(cfg.log_file, sanity_line)
    if (not math.isfinite(sanity)) or sanity < cfg.sanity_cos_min:
        bug = f"STOP: step-0 oracle/asym_probe sanity cosine {sanity:+.6f} < {cfg.sanity_cos_min:.6f}"
        log_line(cfg.log_file, bug)
        raise SystemExit(3)

    info0["lr"] = sched.get_last_lr()[0]
    best, _, _ = track(blk, cfg_for_op, fixed_idx, 0, info0, t0)
    save_ckpt(blk, cfg_for_op, 0, best)

    # Number of optimizer updates actually applied. On an ABORT the loop
    # variable is one ahead of this, so the final checkpoint uses this count.
    completed = 0
    for step in range(1, cfg.steps + 1):
        idx, y = L.get_batch("train", cfg.B, cfg.T)
        try:
            grads, info = oracle_grad(blk, idx, y, cfg_for_op)
        except RuntimeError as err:
            log_line(cfg.log_file, f"ABORT step {step}: {err}")
            break
        if info["finite_t1_res"] > cfg.abort_res:
            log_line(
                cfg.log_file,
                f"ABORT step {step}: finite_T1_res {info['finite_t1_res']:.3e} > {cfg.abort_res:.3e}",
            )
            break
        if not all((g is None) or torch.isfinite(g).all() for g in grads.values()):
            log_line(cfg.log_file, f"ABORT step {step}: non-finite oracle gradient")
            break

        opt.zero_grad(set_to_none=True)
        # Gradients are installed by hand; parameters without an entry in
        # `grads` keep grad=None.
        for p in blk.allp:
            p.grad = grads.get(id(p))
        torch.nn.utils.clip_grad_norm_(blk.allp, cfg.grad_clip)
        opt.step()
        apply_weight_caps(blk)
        sched.step()
        completed = step

        # A non-positive --log-every disables periodic tracking instead of
        # dividing by zero.
        if cfg.log_every > 0 and step % cfg.log_every == 0:
            info["lr"] = sched.get_last_lr()[0]
            val, _, _ = track(blk, cfg_for_op, fixed_idx, step, info, t0)
            best = min(best, val)
            save_ckpt(blk, cfg_for_op, step, best)

    save_ckpt(blk, cfg_for_op, completed, best)
    log_line(cfg.log_file, f"DONE best_val_CE={best:.4f}")


if __name__ == "__main__":
    main()
