summary | refs | log | tree | commit | diff
path: root/ep_run/oracle_adjoint_train.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/oracle_adjoint_train.py')
-rw-r--r-- ep_run/oracle_adjoint_train.py 368
1 files changed, 368 insertions, 0 deletions
diff --git a/ep_run/oracle_adjoint_train.py b/ep_run/oracle_adjoint_train.py
new file mode 100644
index 0000000..0503a66
--- /dev/null
+++ b/ep_run/oracle_adjoint_train.py
@@ -0,0 +1,368 @@
+"""Oracle exact-equilibrium-adjoint training from a redx pre-drift checkpoint.
+
+This deliberately trains the equilibrium objective L(z*) with an exact
+matrix-free implicit adjoint:
+
+ F_z(z*)^T lambda = -L_z(z*)
+ dL/dtheta = L_theta + lambda^T F_theta
+
+Block-parameter adjoints reuse asym_probe.py. The readout head keeps the same
+local dCE/dWh path used by lt_ep_train.py.
+"""
+
+import argparse
+import math
+import os
+import pickle
+import time
+from pathlib import Path
+from types import SimpleNamespace
+
+import numpy as np
+import torch
+
+import lt_ep_train as L
+from asym_probe import (
+ Operators,
+ block_param_list,
+ ce_state_grad,
+ cos,
+ exact_transpose_grad,
+ flat_grad_by_param_id,
+ norm,
+ set_param_requires_grad,
+ solve_exact_adjoint,
+)
+from lt_ep_train import EQBlock, ce, relax
+
+
def parse_args():
    """Parse the command-line options of the oracle-adjoint training script.

    Returns:
        argparse.Namespace: one attribute per flag; dashes in a flag name
        become underscores (``--log-file`` -> ``log_file``).
    """

    def text(default):
        return {"default": default}

    def integer(default):
        return {"type": int, "default": default}

    def real(default):
        return {"type": float, "default": default}

    # Registered strictly in this order so the --help listing is unchanged.
    option_table = (
        ("--ckpt", text("runs/redx_traj/s2000.pt")),
        ("--data", text("data/tinystories_bpe")),
        ("--log-file", text("runs/oracle_adjoint.log")),
        ("--save", text("runs/oracle_adjoint.pt")),
        ("--device", {"default": "cuda", "choices": ["cuda", "cpu"]}),
        ("--steps", integer(1500)),
        ("--B", integer(24)),
        ("--T", integer(256)),
        ("--C", integer(512)),
        ("--H", integer(16)),
        ("--Mm", integer(256)),
        ("--T1", integer(150)),
        ("--eps", real(0.1)),
        ("--lr", real(6e-4)),
        ("--wd", real(1e-4)),
        ("--wsd", real(0.2)),
        ("--warmup", integer(0)),
        ("--log-every", integer(50)),
        ("--eval-batches", integer(8)),
        ("--eval-B", integer(32)),
        ("--rho-B", integer(8)),
        ("--rho-steps", integer(800)),
        ("--res-est", real(1e-5)),
        ("--t1max", integer(6000)),
        ("--relax-chunk", integer(50)),
        ("--abort-res", real(0.3)),
        ("--grad-clip", real(5.0)),
        ("--capx", real(3.0)),
        ("--seed", integer(0)),
        ("--adjoint-iters", integer(200)),
        ("--adjoint-tol", real(1e-5)),
        ("--adjoint-mu", real(1e-4)),
        ("--solve-iters", integer(80)),
        ("--solve-tol", real(1e-5)),
        ("--sanity-cos-min", real(0.999)),
        ("--tf32", {"action": "store_true"}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
+
+
def require_cuda(device):
    """Select GPU 0 for a ``"cuda"`` run, or exit if no CUDA device is usable.

    Any other ``device`` value is a no-op.

    Raises:
        SystemExit: ``device`` is ``"cuda"`` but CUDA is unavailable or no
            device is visible; the message includes CUDA_VISIBLE_DEVICES.
    """
    if device != "cuda":
        return

    gpu_usable = torch.cuda.is_available() and torch.cuda.device_count() > 0
    if not gpu_usable:
        visible = os.environ.get('CUDA_VISIBLE_DEVICES')
        raise SystemExit(
            "ERROR: CUDA unavailable for requested GPU0 run; "
            f"CUDA_VISIBLE_DEVICES={visible!r}"
        )

    torch.cuda.set_device(0)
+
+
def resolve_path(path):
    """Return ``path`` as a ``Path``, anchoring relative paths at the CWD.

    Absolute inputs are returned unchanged (as ``Path`` objects); relative
    inputs are joined onto ``Path.cwd()``. No symlink or ``..`` resolution
    is performed.
    """
    candidate = Path(path)
    return candidate if candidate.is_absolute() else Path.cwd() / candidate
+
+
def configure_globals(cfg, dev):
    """Configure the module-level state of ``lt_ep_train`` for this run.

    Writes ``L.dev``, ``L.DD`` (resolved data directory) and ``L.vocab``
    (the ``"vocab_size"`` entry of ``<data>/meta.pkl``), and sets both TF32
    backend switches from ``cfg.tf32``.

    Args:
        cfg: options object providing ``data`` and ``tf32``.
        dev: device object stored into ``L.dev``.
    """
    L.dev = dev
    L.DD = resolve_path(cfg.data)
    # NOTE: unpickling can execute arbitrary code; only point --data at a
    # trusted dataset directory.
    # The context manager closes the handle promptly instead of leaving it
    # to the garbage collector.
    with open(L.DD / "meta.pkl", "rb") as meta_file:
        L.vocab = pickle.load(meta_file)["vocab_size"]
    allow_tf32 = bool(cfg.tf32)
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    torch.backends.cudnn.allow_tf32 = allow_tf32
+
+
def build_block(cfg, dev):
    """Construct the EQBlock, set its run flags, and load checkpoint weights.

    Order matters: the RNG is seeded before construction, and the weight
    caps are derived from the freshly constructed weights (norm * capx)
    before the checkpoint overwrites them.

    Args:
        cfg: options providing seed, C, H, Mm, T, capx and ckpt.
        dev: device used as ``map_location`` and copy target.

    Returns:
        (blk, ck, ckpt_path): the block, the loaded checkpoint object, and
        the resolved checkpoint path.
    """
    torch.manual_seed(cfg.seed)
    blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=1.0, attn_mode="thick")

    # Per-weight norm ceilings, keyed by parameter identity.
    for weight in blk.capw:
        blk.caps[id(weight)] = weight.detach().norm().item() * cfg.capx

    run_flags = (
        ("qknorm", True),
        ("fnoise", 0.0),
        ("_cstep", None),
        ("navg", 1),
        ("li_avg", 0),
        ("track", True),
        ("nbrake", 0.0),
    )
    for attr, value in run_flags:
        setattr(blk, attr, value)

    ckpt_path = resolve_path(cfg.ckpt)
    ck = torch.load(ckpt_path, map_location=dev)
    with torch.no_grad():
        # zip() stops at the shorter sequence, so a length mismatch between
        # blk.allp and ck["allp"] is not reported here.
        for param, saved in zip(blk.allp, ck["allp"]):
            param.copy_(saved.to(dev))
    return blk, ck, ckpt_path
+
+
@torch.no_grad()
def one_step_residual(blk, z, xin, eps):
    """Relative change of ``z`` under one ``relax`` step.

    Returns ``||relax(z) - z|| / (||z|| + 1e-12)`` as a Python float; the
    small constant guards against division by zero.
    """
    stepped = relax(blk, z, xin, 1, eps)
    change = (stepped - z).norm().item()
    scale = z.norm().item() + 1e-12
    return change / scale
+
+
def relax_refine(blk, xin, cfg):
    """Relax for ``cfg.T1`` steps, then keep relaxing until the residual is small.

    After the initial ``T1`` steps, further chunks of at most
    ``cfg.relax_chunk`` steps are applied while the total stays below
    ``cfg.t1max`` and the one-step residual exceeds ``cfg.res_est``. The
    loop also stops as soon as the residual becomes non-finite.

    Returns:
        (z, finite_t1_res, res, steps): the detached final state, the
        residual measured right after the first ``T1`` steps, the last
        measured residual, and the total number of relaxation steps.
    """
    state = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
    finite_t1_res = one_step_residual(blk, state, xin, cfg.eps)

    residual = finite_t1_res
    total_steps = cfg.T1
    while total_steps < cfg.t1max:
        # Written as "not >" so a NaN residual also ends the loop.
        if not residual > cfg.res_est:
            break
        n_more = min(cfg.relax_chunk, cfg.t1max - total_steps)
        state = relax(blk, state, xin, n_more, cfg.eps)
        total_steps += n_more
        residual = one_step_residual(blk, state, xin, cfg.eps)
        if not math.isfinite(residual):
            break

    return state.detach(), finite_t1_res, residual, total_steps
+
+
def oracle_grad(blk, idx, y, cfg):
    """Compute the equilibrium-adjoint gradient for one batch.

    Steps, in order:
      1. Embed ``idx`` and relax to a refined state ``zstar``.
      2. With parameter gradients switched off, solve the adjoint system via
         ``solve_exact_adjoint`` and reject the solve unless it used no
         Tikhonov term, reported ``info == 0`` and reached a finite relative
         residual no larger than ``max(10 * cfg.adjoint_tol, 1e-4)``.
      3. Obtain block-parameter gradients from ``exact_transpose_grad`` and
         the readout-head gradient from autograd of ``ce`` w.r.t. ``blk.Wh``.

    Args:
        blk: the equilibrium block.
        idx: input token batch passed to ``blk.embed``.
        y: targets passed to the cross-entropy helpers.
        cfg: options (relaxation and adjoint settings).

    Returns:
        (grads, info): ``grads`` maps ``id(param)`` to its gradient tensor;
        ``info`` is a dict of scalar diagnostics for logging.

    Raises:
        RuntimeError: the adjoint solve did not meet the acceptance test.
    """
    xin0 = blk.embed(idx).detach()
    zstar, finite_t1_res, zstar_res, relax_steps = relax_refine(blk, xin0, cfg)

    set_param_requires_grad(blk, False)
    try:
        op = Operators(blk, zstar, xin0, cfg, mu=0.0)
        ell, loss_zstar = ce_state_grad(blk, zstar, y)
        lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg)

        rel_limit = max(10.0 * cfg.adjoint_tol, 1e-4)
        solve_ok = (
            adj_mu == 0.0
            and gmres_info == 0
            and math.isfinite(gmres_rel)
            and gmres_rel <= rel_limit
        )
        if not solve_ok:
            raise RuntimeError(
                "exact adjoint GMRES failed "
                f"(rel={gmres_rel:.3e}, info={gmres_info}, iters={gmres_iters}, tikhonov_mu={adj_mu:.3e})"
            )

        params = block_param_list(blk)
        block_grads = exact_transpose_grad(blk, idx, zstar, xin0, lam, params)
        grads = dict(block_grads)
        with torch.enable_grad():
            (gh,) = torch.autograd.grad(ce(blk, zstar.detach(), y), blk.Wh)
        grads[id(blk.Wh)] = gh
    finally:
        # Restore trainability on every exit path, not only on the GMRES
        # failure path, so an unexpected exception cannot leave the block
        # with gradients disabled.
        set_param_requires_grad(blk, True)

    block_flat = flat_grad_by_param_id(grads, params)
    gt_flat = flat_grad_by_param_id(block_grads, params)
    # NOTE(review): ``grads`` is a copy of ``block_grads`` plus the Wh entry,
    # so if ``params`` does not contain Wh both vectors are built from the
    # same tensors and this cosine cannot detect a wrong gradient - confirm
    # against asym_probe.block_param_list.
    sanity_cos = cos(block_flat, gt_flat)
    head_norm = norm(gh.detach()).item()
    block_norm = norm(block_flat).item()
    return grads, {
        "loss_zstar": loss_zstar,
        "finite_t1_res": finite_t1_res,
        "zstar_res": zstar_res,
        "relax_steps": relax_steps,
        "gmres_rel": gmres_rel,
        "gmres_info": gmres_info,
        "gmres_iters": gmres_iters,
        "adj_mu": adj_mu,
        "sanity_cos": sanity_cos,
        "block_grad_norm": block_norm,
        "head_grad_norm": head_norm,
    }
+
+
def make_optimizer_and_sched(blk, cfg):
    """Create the AdamW optimizer and its warmup/stable/decay LR schedule.

    The schedule multiplier is:
      * linear ``(step + 1) / warmup`` while ``step < cfg.warmup``;
      * ``1.0`` until the decay start, which is
        ``int((1 - cfg.wsd) * cfg.steps)`` when ``cfg.wsd > 0`` and
        ``cfg.warmup`` otherwise;
      * a half-cosine from ``1.0`` down to ``0.05`` over the remaining
        steps, held at ``0.05`` afterwards.

    Returns:
        (opt, sched): the ``AdamW`` instance over ``blk.allp`` and the
        ``LambdaLR`` scheduler wrapping it.
    """
    floor = 0.05
    half_span = 0.475

    def multiplier(step):
        in_warmup = cfg.warmup > 0 and step < cfg.warmup
        if in_warmup:
            return (step + 1) / cfg.warmup

        if cfg.wsd > 0:
            decay_start = int((1.0 - cfg.wsd) * cfg.steps)
        else:
            decay_start = cfg.warmup
        if step < decay_start:
            return 1.0

        progress = (step - decay_start) / max(1, cfg.steps - decay_start)
        clipped = min(1.0, progress)
        return floor + half_span * (1.0 + math.cos(math.pi * clipped))

    opt = torch.optim.AdamW(blk.allp, lr=cfg.lr, weight_decay=cfg.wd)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, multiplier)
    return opt, sched
+
+
@torch.no_grad()
def apply_weight_caps(blk):
    """Rescale, in place, every capped weight whose norm exceeds its cap.

    For each tensor in ``blk.capw`` the cap is looked up in ``blk.caps`` by
    ``id(tensor)``; tensors at or below the cap are left untouched.
    """
    for weight in blk.capw:
        limit = blk.caps[id(weight)]
        magnitude = weight.norm()
        # "not >" (rather than "<=") keeps a NaN norm from triggering a rescale.
        if not magnitude > limit:
            continue
        weight.mul_(limit / magnitude)
+
+
@torch.no_grad()
def evaluate_ce(blk, cfg):
    """Average cross-entropy over ``cfg.eval_batches`` validation batches.

    Each batch is drawn with ``L.get_batch("val", ...)``, relaxed for
    ``cfg.T1`` steps from the embedding, and scored with ``ce``. The sum is
    divided by ``max(1, cfg.eval_batches)``, so zero batches yields 0.0.
    """

    def batch_loss():
        idx, y = L.get_batch("val", cfg.eval_B, cfg.T)
        xin = blk.embed(idx).detach()
        settled = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
        return ce(blk, settled, y).item()

    running = sum((batch_loss() for _ in range(cfg.eval_batches)), 0.0)
    return running / max(1, cfg.eval_batches)
+
+
@torch.no_grad()
def finite_residual_on_batch(blk, idx, cfg):
    """One-step residual after exactly ``cfg.T1`` relaxation steps on ``idx``.

    The state starts from a copy of the (detached) embedding of ``idx``.
    """
    embedded = blk.embed(idx).detach()
    settled = relax(blk, embedded.clone(), embedded, cfg.T1, cfg.eps)
    return one_step_residual(blk, settled, embedded, cfg.eps)
+
+
@torch.no_grad()
def rho_decay_probe(blk, idx, cfg):
    """Estimate the per-step residual decay rate of the explicit iteration.

    Iterates ``z <- z + cfg.eps * blk.force(z, xin)`` from the embedding for
    up to ``cfg.rho_steps`` steps, recording the relative step size each
    time and stopping early once it is non-finite or above 1e2. The rate is
    the geometric mean of consecutive ratios among the residuals lying
    strictly between 1e-6 and 1e-1 (or among the last 200 residuals when
    none fall in that band).

    Returns:
        (rho, last_residual, n_steps): ``rho`` and ``last_residual`` are
        NaN when they cannot be computed.
    """
    xin = blk.embed(idx).detach()
    state = xin.clone()
    history = []
    for _ in range(cfg.rho_steps):
        advanced = state + cfg.eps * blk.force(state, xin).detach()
        rel = (advanced - state).norm().item() / (state.norm().item() + 1e-12)
        history.append(rel)
        state = advanced
        diverged = rel > 1e2 or not math.isfinite(rel)
        if diverged:
            break

    in_band = [r for r in history if 1e-6 < r < 1e-1]
    window = in_band if in_band else history[-200:]
    ratios = [
        nxt / cur
        for cur, nxt in zip(window, window[1:])
        if cur > 0 and nxt > 0
    ]
    if ratios:
        rho = math.exp(sum(math.log(q) for q in ratios) / len(ratios))
    else:
        rho = float("nan")
    last = history[-1] if history else float("nan")
    return rho, last, len(history)
+
+
def log_line(path, line):
    """Append ``line`` (plus a newline) to the UTF-8 log at ``path`` and echo it.

    The echo goes to stdout and is flushed immediately.
    """
    with open(path, "a", encoding="utf-8") as log_file:
        print(line, file=log_file)
    print(line, flush=True)
+
+
def track(blk, cfg, fixed_idx, step, info, t0):
    """Run the periodic diagnostics and write one progress line to the log.

    Measures validation CE, the finite-T1 residual and the decay-rate probe
    on ``fixed_idx``, combines them with the training-side diagnostics in
    ``info`` (missing keys print as nan / -1), and logs the line via
    ``log_line``.

    Returns:
        (val, val_res, rho): validation CE, finite-T1 residual on the fixed
        batch, and the estimated decay rate.
    """
    missing = float("nan")

    val = evaluate_ce(blk, cfg)
    val_res = finite_residual_on_batch(blk, fixed_idx, cfg)
    rho, rho_final, rho_n = rho_decay_probe(blk, fixed_idx, cfg)

    lr = info.get("lr", missing)
    train_res = info.get("finite_t1_res", missing)
    zstar_res = info.get("zstar_res", missing)
    relax_steps = info.get("relax_steps", -1)
    gmres_rel = info.get("gmres_rel", missing)
    gmres_iters = info.get("gmres_iters", -1)
    # Average throughput since t0; step 0 is counted as one iteration.
    rate = max(step, 1) / max(time.time() - t0, 1e-9)

    line = (
        f"step {step:4d}/{cfg.steps} | val CE {val:.4f} | finite_T1_res {val_res:.3e} "
        f"| rho800 {rho:.4f} final_res {rho_final:.2e} n={rho_n} "
        f"| train_res {train_res:.3e} "
        f"| zstar_res {zstar_res:.3e} relax {relax_steps} "
        f"| gmres {gmres_rel:.2e}/{gmres_iters} "
        f"| lr {lr:.3e} | {rate:.4f} it/s"
    )
    log_line(cfg.log_file, line)
    return val, val_res, rho
+
+
def save_ckpt(blk, cfg, step, best):
    """Write a checkpoint to ``cfg.save`` via a temp file and rename.

    Does nothing when ``cfg.save`` is empty. The payload holds CPU copies
    of ``blk.allp`` under ``"allp"`` plus ``"step"`` and ``"best"``. The
    data is first saved to ``<path>.tmp`` and then moved over the target
    with ``os.replace`` so a partially written file never takes its place.
    """
    if cfg.save:
        target = resolve_path(cfg.save)
        target.parent.mkdir(parents=True, exist_ok=True)
        staging = str(target) + ".tmp"
        payload = {
            "allp": [p.detach().cpu() for p in blk.allp],
            "step": step,
            "best": best,
        }
        torch.save(payload, staging)
        os.replace(staging, target)
+
+
def main():
    """Entry point: load the checkpoint, sanity-check, then train.

    Flow:
      1. Parse options, select the device, configure ``lt_ep_train`` globals
         and build the block, optimizer and scheduler.
      2. Draw a fixed validation batch for the periodic probes and start a
         fresh log file with a header line.
      3. Compute one gradient at step 0 and exit with status 3 if its sanity
         cosine is non-finite or below ``--sanity-cos-min``.
      4. Train for ``--steps`` steps, aborting (without raising) on an
         adjoint failure, an excessive finite-T1 residual, or a non-finite
         gradient; log and checkpoint every ``--log-every`` steps.
      5. Save a final checkpoint tagged with the last step whose update was
         actually applied.
    """
    cfg = parse_args()
    require_cuda(cfg.device)
    dev = torch.device("cuda:0" if cfg.device == "cuda" else "cpu")
    configure_globals(cfg, dev)
    cfg.log_file = str(resolve_path(cfg.log_file))
    Path(cfg.log_file).parent.mkdir(parents=True, exist_ok=True)
    cfg_for_op = SimpleNamespace(**vars(cfg))

    blk, ck, ckpt_path = build_block(cfg_for_op, dev)
    opt, sched = make_optimizer_and_sched(blk, cfg_for_op)

    # Fixed seed for the probe batch, then a seed derived from --seed for the
    # training stream (presumably get_batch samples via torch's RNG - verify
    # against lt_ep_train).
    torch.manual_seed(1234)
    fixed_idx, _ = L.get_batch("val", cfg.rho_B, cfg.T)
    torch.manual_seed(cfg.seed + 1)

    header = (
        f"# oracle_adjoint_train device={dev} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r} "
        f"ckpt={ckpt_path} ckpt_step={ck.get('step')} ckpt_best={ck.get('best')} "
        f"B={cfg.B} T={cfg.T} C={cfg.C} H={cfg.H} Mm={cfg.Mm} T1={cfg.T1} "
        f"lr={cfg.lr} wd={cfg.wd} wsd={cfg.wsd} res_est={cfg.res_est} t1max={cfg.t1max}"
    )
    # Mode "w" truncates any previous log at this path.
    with open(cfg.log_file, "w", encoding="utf-8") as f:
        f.write(header + "\n")
    print(header, flush=True)

    t0 = time.time()
    idx0, y0 = L.get_batch("train", cfg.B, cfg.T)
    # Only the diagnostics are needed at step 0; the gradient dict is dropped
    # right away instead of being kept alive for the whole run.
    _, info0 = oracle_grad(blk, idx0, y0, cfg_for_op)
    sanity = info0["sanity_cos"]
    sanity_line = (
        f"step 0 sanity: cos(oracle_block_grad, asym_probe g_transpose)={sanity:+.6f} "
        f"gmres_rel={info0['gmres_rel']:.3e} gmres_iters={info0['gmres_iters']} "
        f"zstar_res={info0['zstar_res']:.3e} finite_T1_res={info0['finite_t1_res']:.3e}"
    )
    log_line(cfg.log_file, sanity_line)
    if (not math.isfinite(sanity)) or sanity < cfg.sanity_cos_min:
        bug = f"STOP: step-0 oracle/asym_probe sanity cosine {sanity:+.6f} < {cfg.sanity_cos_min:.6f}"
        log_line(cfg.log_file, bug)
        raise SystemExit(3)

    info0["lr"] = sched.get_last_lr()[0]
    best, _, _ = track(blk, cfg_for_op, fixed_idx, 0, info0, t0)
    save_ckpt(blk, cfg_for_op, 0, best)

    # Number of optimizer updates applied so far. Used for the final
    # checkpoint so an aborted step is not recorded as completed.
    completed = 0
    for step in range(1, cfg.steps + 1):
        idx, y = L.get_batch("train", cfg.B, cfg.T)
        try:
            grads, info = oracle_grad(blk, idx, y, cfg_for_op)
        except RuntimeError as err:
            log_line(cfg.log_file, f"ABORT step {step}: {err}")
            break
        if info["finite_t1_res"] > cfg.abort_res:
            log_line(
                cfg.log_file,
                f"ABORT step {step}: finite_T1_res {info['finite_t1_res']:.3e} > {cfg.abort_res:.3e}",
            )
            break
        if not all((g is None) or torch.isfinite(g).all() for g in grads.values()):
            log_line(cfg.log_file, f"ABORT step {step}: non-finite oracle gradient")
            break

        opt.zero_grad(set_to_none=True)
        # Gradients are keyed by parameter identity; parameters without an
        # entry get grad=None.
        for p in blk.allp:
            p.grad = grads.get(id(p))
        torch.nn.utils.clip_grad_norm_(blk.allp, cfg.grad_clip)
        opt.step()
        apply_weight_caps(blk)
        sched.step()
        completed = step

        if step % cfg.log_every == 0:
            info["lr"] = sched.get_last_lr()[0]
            val, _, _ = track(blk, cfg_for_op, fixed_idx, step, info, t0)
            best = min(best, val)
            save_ckpt(blk, cfg_for_op, step, best)

    save_ckpt(blk, cfg_for_op, completed, best)
    log_line(cfg.log_file, f"DONE best_val_CE={best:.4f}")
+
+
+if __name__ == "__main__":
+ main()