diff options
Diffstat (limited to 'ep_run/lt_ep_train.py')
| -rw-r--r-- | ep_run/lt_ep_train.py | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/ep_run/lt_ep_train.py b/ep_run/lt_ep_train.py index 4e7b8b1..99a1811 100644 --- a/ep_run/lt_ep_train.py +++ b/ep_run/lt_ep_train.py @@ -138,7 +138,8 @@ def ce(blk, z, y): def ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, - corr_every=1, res_gate=0.0, resreg=0.0, eigreg=0.0, eig_margin=1.0): + corr_every=1, res_gate=0.0, resreg=0.0, eigreg=0.0, eig_margin=1.0, + floss=0.0, floss_q=10, floss_rho=0.995, floss_bsub=4): xin0 = blk.embed(idx).detach() zs = relax(blk, xin0.clone(), xin0, T1, eps) res = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9) @@ -229,6 +230,30 @@ def ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=0, hr=0.02, t1max=0 for p, g in zip(blk.block, grr): if g is not None: grads[id(p)] = g * lam if grads.get(id(p)) is None else grads[id(p)] + lam * g + if floss > 0: # "floss-ept": GRADED finite-horizon LE penalty (aep 'floss' ported). + # Same fundamental quantity as resreg (path contraction rate = what EP-estimator validity needs; + # broadband, state-contamination-free) but a LINEAR early signal instead of resreg's rho^T1 cliff: + # unroll q steps from z_T1 on a sub-batch WITH graph, rho_hat = mean per-step growth of the update + # delta, one-sided penalty above rho target. Ramp keys on (rho_hat - target), NOT on resT1. + zb, xb = zT[:floss_bsub].detach(), xin0[:floss_bsub].detach() + with torch.enable_grad(): + zc, ds = zb, [] + for _ in range(floss_q): + d = eps * blk.tforce(zc, xb) # deterministic thick force, graph kept through the path + ds.append(d.pow(2).sum()) + zc = zc + d + rho_hat = (ds[-1] / (ds[0] + 1e-20)) ** (0.5 / max(1, floss_q - 1)) # (||d_q||/||d_1||)^(1/(q-1)) + blk._floss_rho = float(rho_hat.detach()) # logged by the train loop + Rf = torch.relu(rho_hat - floss_rho) ** 2 + gf = torch.autograd.grad(Rf, blk.block, allow_unused=True) if float(Rf) > 0 else None + if gf is not None: + ratio = floss * min(1.0, float(rho_hat - floss_rho) / 0.005) # graded: full strength 0.5% over target + gtask = math.sqrt(sum(float((grads[id(p)] ** 2).sum()) for p in blk.block if grads.get(id(p)) is not None) + 1e-20) + gfl = math.sqrt(sum(float((g ** 2).sum()) for g in gf if g is not None) + 1e-20) + lam = ratio * gtask / gfl # cap at `ratio` of the task-grad norm (resreg convention) + for p, g in zip(blk.block, gf): + if g is not None: + grads[id(p)] = g * lam if grads.get(id(p)) is None else grads[id(p)] + lam * g if eigreg > 0: # #2 v2: TRUE leading map-eigenvalue control (aep 'spectral', soft one-sided) from eig_control import spec_penalty # (omega/numerical-abscissa version refuted 2026-07-03, eig_recheck) ge, _rho, _mu = spec_penalty(blk, zs, eps, blk.c, eigreg, eig_margin, @@ -378,6 +403,10 @@ def main(): ap.add_argument('--resreg', type=float, default=0.0) # T1-residual penalty: defend z_T1 (cap ratio vs task grad); run res_gate=0 ap.add_argument('--eigreg', type=float, default=0.0) # #2 v2: soft penalty on TRUE |lam|_lead(I+eps*J_F) — aep 'spectral' at C512 ap.add_argument('--eig_margin', type=float, default=0.995) # rho target: penalize |lam|_lead above this (<1 = contracting relaxation map) + ap.add_argument('--floss', type=float, default=0.0) # floss-ept: graded finite-horizon LE penalty (de-cliffed resreg; FINDINGS 2026-07-03) + ap.add_argument('--floss_q', type=int, default=10) # unroll horizon (steps past z_T1, with graph, sub-batch) + ap.add_argument('--floss_rho', type=float, default=0.995) # per-step contraction target (one-sided; matches eig_margin) + ap.add_argument('--floss_bsub', type=int, default=4) # sub-batch rows for the graphed unroll (memory) ap.add_argument('--diag_cos', type=int, default=0) # #1: every N steps, log cos(EP grad, exact BPTT grad) + res ap.add_argument('--fingerprint', action='store_true') # load --init_ckpt, print (res,cos,abscissa,val) fingerprint, exit ap.add_argument('--opt', choices=['adamw', 'lion', 'lionlars', 'sgdm', 'sgdsai'], default='adamw') @@ -531,7 +560,7 @@ def main(): sw = hw_swap() if hw_on else None grads, res = ep_step(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, jr, cfg.holo, cfg.hr, cfg.t1max, cfg.res_est, cfg.t2sel, cfg.corr_every, cfg.res_gate, cfg.resreg, - cfg.eigreg, cfg.eig_margin) + cfg.eigreg, cfg.eig_margin, cfg.floss, cfg.floss_q, cfg.floss_rho, cfg.floss_bsub) if sw is not None: hw_restore(sw) if cfg.jacreg > 0: # continuous controller: drive residual -> res_target (smooth) @@ -619,7 +648,8 @@ def main(): torch.save({'allp': [p.detach().cpu() for p in blk.allp], 'pema': [s.cpu() for s in pema] if pema is not None else None, 'step': step, 'best': best}, cfg.ckpt) - print(f"step {step:4d}/{cfg.steps} | val CE {v:.4f}{etag} (best {best:.4f}) | jr={jr:.1f} res={res:.1e} | {step/(time.time()-t0):.2f} it/s", flush=True) + ftag = f" rho={blk._floss_rho:.4f}" if cfg.floss > 0 and hasattr(blk, '_floss_rho') else "" + print(f"step {step:4d}/{cfg.steps} | val CE {v:.4f}{etag} (best {best:.4f}) | jr={jr:.1f} res={res:.1e}{ftag} | {step/(time.time()-t0):.2f} it/s", flush=True) save_state(step) # full-state checkpoint each log interval (Colab resume) print(f"[{cfg.mode}] DONE best val CE {best:.4f} (random baseline ln({vocab})={math.log(vocab):.3f})", flush=True) out_dir = Path('runs') |
