summaryrefslogtreecommitdiff
path: root/ep_run/lt_ep_train.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 18:21:21 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 18:21:21 -0500
commit6e78420da6e613964d93da06156b556e1a91caef (patch)
tree9329b0bc134fbc0627a80e6fd095651fcf9e4975 /ep_run/lt_ep_train.py
parentbcec9560cf5c9b113e9381a52d1a941daa8865f2 (diff)
floss-ept: graded finite-horizon LE penalty (--floss) + three-arm from-scratch ablation queueHEADmaster
- ep_step: floss block after resreg — unroll q=10 steps past z_T1 on a sub-batch WITH graph, rho_hat = mean per-step delta growth, one-sided relu(rho_hat - 0.995)^2, ramp keyed on (rho_hat - target) NOT resT1 (de-cliffed resreg: same fundamental path-LE quantity, linear early signal), capped at floss fraction of task-grad norm (resreg convention). - smoke: below-target = untouched (cos 1.0000); force-fire = finite grads, capped perturbation (cos 0.9803). - runs/abl3_queue.sh (runner live): waits for free GPU slots (0/1/3, GPU2 excluded), launches abl_floss (floss-only) / abl_resreg (resreg-only, never cleanly run) / abl_pair (proven 2.09 recipe, control) with identical remaining flags + seed. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/lt_ep_train.py')
-rw-r--r--ep_run/lt_ep_train.py36
1 files changed, 33 insertions, 3 deletions
diff --git a/ep_run/lt_ep_train.py b/ep_run/lt_ep_train.py
index 4e7b8b1..99a1811 100644
--- a/ep_run/lt_ep_train.py
+++ b/ep_run/lt_ep_train.py
@@ -138,7 +138,8 @@ def ce(blk, z, y):
def ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0,
- corr_every=1, res_gate=0.0, resreg=0.0, eigreg=0.0, eig_margin=1.0):
+ corr_every=1, res_gate=0.0, resreg=0.0, eigreg=0.0, eig_margin=1.0,
+ floss=0.0, floss_q=10, floss_rho=0.995, floss_bsub=4):
xin0 = blk.embed(idx).detach()
zs = relax(blk, xin0.clone(), xin0, T1, eps)
res = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9)
@@ -229,6 +230,30 @@ def ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=0, hr=0.02, t1max=0
for p, g in zip(blk.block, grr):
if g is not None:
grads[id(p)] = g * lam if grads.get(id(p)) is None else grads[id(p)] + lam * g
+ if floss > 0: # "floss-ept": GRADED finite-horizon LE penalty (aep 'floss' ported).
+ # Same fundamental quantity as resreg (path contraction rate = what EP-estimator validity needs;
+ # broadband, state-contamination-free) but a LINEAR early signal instead of resreg's rho^T1 cliff:
+ # unroll q steps from z_T1 on a sub-batch WITH graph, rho_hat = mean per-step growth of the update
+ # delta, one-sided penalty above rho target. Ramp keys on (rho_hat - target), NOT on resT1.
+ zb, xb = zT[:floss_bsub].detach(), xin0[:floss_bsub].detach()
+ with torch.enable_grad():
+ zc, ds = zb, []
+ for _ in range(floss_q):
+ d = eps * blk.tforce(zc, xb) # deterministic thick force, graph kept through the path
+ ds.append(d.pow(2).sum())
+ zc = zc + d
+ rho_hat = (ds[-1] / (ds[0] + 1e-20)) ** (0.5 / max(1, floss_q - 1)) # (||d_q||/||d_1||)^(1/(q-1))
+ blk._floss_rho = float(rho_hat.detach()) # logged by the train loop
+ Rf = torch.relu(rho_hat - floss_rho) ** 2
+ gf = torch.autograd.grad(Rf, blk.block, allow_unused=True) if float(Rf) > 0 else None
+ if gf is not None:
+ ratio = floss * min(1.0, float(rho_hat - floss_rho) / 0.005) # graded: full strength 0.5% over target
+ gtask = math.sqrt(sum(float((grads[id(p)] ** 2).sum()) for p in blk.block if grads.get(id(p)) is not None) + 1e-20)
+ gfl = math.sqrt(sum(float((g ** 2).sum()) for g in gf if g is not None) + 1e-20)
+ lam = ratio * gtask / gfl # cap at `ratio` of the task-grad norm (resreg convention)
+ for p, g in zip(blk.block, gf):
+ if g is not None:
+ grads[id(p)] = g * lam if grads.get(id(p)) is None else grads[id(p)] + lam * g
if eigreg > 0: # #2 v2: TRUE leading map-eigenvalue control (aep 'spectral', soft one-sided)
from eig_control import spec_penalty # (omega/numerical-abscissa version refuted 2026-07-03, eig_recheck)
ge, _rho, _mu = spec_penalty(blk, zs, eps, blk.c, eigreg, eig_margin,
@@ -378,6 +403,10 @@ def main():
ap.add_argument('--resreg', type=float, default=0.0) # T1-residual penalty: defend z_T1 (cap ratio vs task grad); run res_gate=0
ap.add_argument('--eigreg', type=float, default=0.0) # #2 v2: soft penalty on TRUE |lam|_lead(I+eps*J_F) — aep 'spectral' at C512
ap.add_argument('--eig_margin', type=float, default=0.995) # rho target: penalize |lam|_lead above this (<1 = contracting relaxation map)
+ ap.add_argument('--floss', type=float, default=0.0) # floss-ept: graded finite-horizon LE penalty (de-cliffed resreg; FINDINGS 2026-07-03)
+ ap.add_argument('--floss_q', type=int, default=10) # unroll horizon (steps past z_T1, with graph, sub-batch)
+ ap.add_argument('--floss_rho', type=float, default=0.995) # per-step contraction target (one-sided; matches eig_margin)
+ ap.add_argument('--floss_bsub', type=int, default=4) # sub-batch rows for the graphed unroll (memory)
ap.add_argument('--diag_cos', type=int, default=0) # #1: every N steps, log cos(EP grad, exact BPTT grad) + res
ap.add_argument('--fingerprint', action='store_true') # load --init_ckpt, print (res,cos,abscissa,val) fingerprint, exit
ap.add_argument('--opt', choices=['adamw', 'lion', 'lionlars', 'sgdm', 'sgdsai'], default='adamw')
@@ -531,7 +560,7 @@ def main():
sw = hw_swap() if hw_on else None
grads, res = ep_step(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, jr, cfg.holo, cfg.hr,
cfg.t1max, cfg.res_est, cfg.t2sel, cfg.corr_every, cfg.res_gate, cfg.resreg,
- cfg.eigreg, cfg.eig_margin)
+ cfg.eigreg, cfg.eig_margin, cfg.floss, cfg.floss_q, cfg.floss_rho, cfg.floss_bsub)
if sw is not None:
hw_restore(sw)
if cfg.jacreg > 0: # continuous controller: drive residual -> res_target (smooth)
@@ -619,7 +648,8 @@ def main():
torch.save({'allp': [p.detach().cpu() for p in blk.allp],
'pema': [s.cpu() for s in pema] if pema is not None else None,
'step': step, 'best': best}, cfg.ckpt)
- print(f"step {step:4d}/{cfg.steps} | val CE {v:.4f}{etag} (best {best:.4f}) | jr={jr:.1f} res={res:.1e} | {step/(time.time()-t0):.2f} it/s", flush=True)
+ ftag = f" rho={blk._floss_rho:.4f}" if cfg.floss > 0 and hasattr(blk, '_floss_rho') else ""
+ print(f"step {step:4d}/{cfg.steps} | val CE {v:.4f}{etag} (best {best:.4f}) | jr={jr:.1f} res={res:.1e}{ftag} | {step/(time.time()-t0):.2f} it/s", flush=True)
save_state(step) # full-state checkpoint each log interval (Colab resume)
print(f"[{cfg.mode}] DONE best val CE {best:.4f} (random baseline ln({vocab})={math.log(vocab):.3f})", flush=True)
out_dir = Path('runs')