| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run')
110 files changed, 16378 insertions, 0 deletions
diff --git a/ep_run/CODEX_VERDICT.md b/ep_run/CODEX_VERDICT.md new file mode 100644 index 0000000..7d29c39 --- /dev/null +++ b/ep_run/CODEX_VERDICT.md @@ -0,0 +1,151 @@ +# CODEX VERDICT: EP below-2.10 divergence + +## Ruling + +Verdict: (b) STRUCTURAL. + +Converging the EP adjoint phase is necessary for a correct equilibrium-gradient estimate, but it is not sufficient to make this training problem behave like BPTT. The exact equilibrium gradient differentiates `L(z*)`. BPTT differentiates the actual deployed computation `L(z_T)` with `T=150`. Those are different objectives whenever convergence is not effectively complete. The missing term is the finite-horizon residual/contraction term. No `t2sel` or `hr` knob can add that term to the exact fixed-point gradient. + +## Fact-set check + +1. Correct in substance. `runs/bptt_clean.log` reaches best val CE `1.8277` and keeps the finite-`T1` residual small, around `4e-4` to `1e-3` late in training. The failure is EP-specific in the comparable EP logs. The `rho~0.982` value is referenced by the rho prober header and scripts, but the requested log set does not contain the full `spec_bifurcation.py` output, so the CE/residual part is directly verified and the exact rho number is not independently reprinted in the available logs. + +2. Correct. `runs/t2_sweep.log` shows `cos(g_EP,g_transpose)` rising from `0.742413` at `t2sel=10` to `0.998194` at `t2sel=160`. `runs/hr_ceiling_sweep.log` shows the remaining gap flat across `hr=0.04..0.8` at about `0.94..0.946`. That is adjoint-phase truncation, not beta-radius tuning. The code agrees: `lt_ep_train.py` calls `holo_a_track`/`holo_a_select2` with a fixed `t2sel`, and `holo_ep.py` selects a finite nudged snapshot rather than enforcing an adjoint residual. + +3. Correct. The same sweep reports `cos(g_transpose,g_BPTT)=0.974976` at free-phase step residual `2.17e-5`. 
`diag_probe.log` shows the exact fixed-point regime, residual around `1e-9`, where `cos(g_transpose,g_BPTT)=1.000000`. This is the finite-horizon/free-residual gap. + +4. Correct in causal direction, with one evidence caveat. `ep_redx.log` shows the sudden event: CE goes from `2.7417` at step 3200 with residual `2.5e-2` to CE `41` with residual `1.1e-1` at step 3300. `redx_traj.log` shows EP gradient quality degrading and residuals growing along the approach. The rho drift and damping-resistant `c` sweep are encoded in `spec_bifurcation.py`, `spec_rho_vs_c.py`, and cited in `t2fix_rho.log`, but the actual c-sweep output is not present in the requested logs. The important point remains: the finite-`T1` residual becomes hypersensitive near marginality. + +5. Correct as the root cause, but not yet empirically closed for `t2sel=160`. `runs/t2fix_rho.log` only has steps 100, 200, 300 at CE about `6`, with `rho~0.794`; it does not prove that `t2sel=160` will pass or fail near CE 2.x. The math below decides the open question: exact equilibrium gradients still optimize the wrong objective for finite-time deployment. + +## Why the exact equilibrium gradient lacks contraction defense + +Let the relaxation map be + +```text +Phi_theta(z) = z + eps F_theta(z) +``` + +and let `z*` satisfy `F_theta(z*) = 0`. The equilibrium objective is + +```text +J_inf(theta) = L(z*(theta)). +``` + +Differentiating the fixed-point equation gives + +```text +F_z dz*/dtheta + F_theta = 0 +dz*/dtheta = -F_z^{-1} F_theta +``` + +Equivalently, solve the equilibrium adjoint + +```text +F_z^T lambda = -L_z(z*) +grad_theta J_inf = L_theta + lambda^T F_theta. +``` + +That is exactly what the EP/AEP estimator is trying to approximate. It contains `F_z^{-1}`, so slow modes can amplify loss sensitivity. 
But it contains no term for the finite relaxation length, no `T`, no initial residual, no `Phi^T`, and no derivative of `rho(Phi_z)` unless changing that contraction also changes `z*` or the equilibrium loss. A parameter that changes the convergence rate while leaving the fixed point and readout loss unchanged has zero exact equilibrium gradient. + +Scalar counterexample: + +```text +F_k(z) = -k (z - z*) +Phi_k(z) = z + eps F_k(z) +L = L(z*) +``` + +For any positive `k`, the fixed point is the same. Therefore + +```text +d L(z*) / d k = 0. +``` + +But the finite state is + +```text +z_T - z* = (1 - eps k)^T (z_0 - z*), +``` + +so + +```text +d L(z_T) / d k +``` + +contains a term proportional to + +```text +T (1 - eps k)^(T-1). +``` + +That is exactly the contraction-defense term. It is large near `rho=1`, precisely where `rho^150` becomes explosive. It vanishes only in the true infinite-time limit when `rho<1` with enough margin. + +For the full model, BPTT differentiates + +```text +z_{t+1} = Phi_theta(z_t) +grad_theta L(z_T) + = L_z(z_T)^T sum_{k=0}^{T-1} + (prod_{s=k+1}^{T-1} Phi_z(z_s)) eps F_theta(z_k) + + direct terms. +``` + +Those products are the same objects that determine finite-time contraction. When they decay slowly, the finite-horizon gradient feels it. Equilibrium EP replaces this whole finite product chain with the fixed-point inverse at `z*` and takes `T=infinity`; the transient residual term is gone. + +The code implements this split exactly. In `lt_ep_train.py`, `ep_step` relaxes to `T1`, optionally refines beyond `T1`, and computes the task gradient at `zs` through `(a * f).sum()`. `bptt_step` unrolls exactly `T1` steps and differentiates `ce(blk, z, y)` at the final unrolled state. `evaluate()` also uses exactly `T1` relaxation steps. Therefore BPTT is optimizing the evaluated computation and EP is optimizing the refined fixed-point computation. 
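The scalar counterexample above can be checked numerically. The following is a minimal sketch, not code from the repository: the constants (`EPS`, `T`, `Z0`, `Z_STAR`, `TARGET`), the quadratic loss, and all function names are illustrative assumptions. `k=0.2` is chosen only so that `1 - eps*k = 0.98` sits near the marginal regime discussed in the verdict.

```python
# Hypothetical constants for the scalar toy problem (not taken from lt_ep_train.py).
EPS, T, Z0, Z_STAR, TARGET = 0.1, 150, 0.0, 1.0, 2.0


def finite_state(k):
    """Unroll z <- z + eps * F_k(z) with F_k(z) = -k (z - z*) for T steps."""
    z = Z0
    for _ in range(T):
        z = z + EPS * (-k * (z - Z_STAR))
    return z


def loss(z):
    """Assumed quadratic readout loss; any loss with L_z(z_T) != 0 would do."""
    return 0.5 * (z - TARGET) ** 2


def grad_equilibrium(k):
    """d L(z*) / dk: the fixed point z* does not depend on k, so this is zero."""
    return 0.0


def grad_finite_analytic(k):
    """d L(z_T) / dk using z_T - z* = (1 - eps k)^T (z_0 - z*)."""
    rho = 1.0 - EPS * k
    z_T = Z_STAR + rho ** T * (Z0 - Z_STAR)
    dzT_dk = -EPS * T * rho ** (T - 1) * (Z0 - Z_STAR)  # the T * rho^(T-1) term
    return (z_T - TARGET) * dzT_dk


def grad_finite_fd(k, h=1e-6):
    """Central finite difference of the unrolled loss, as an independent check."""
    return (loss(finite_state(k + h)) - loss(finite_state(k - h))) / (2.0 * h)


k = 0.2
print("equilibrium grad:", grad_equilibrium(k))
print("finite-T grad (analytic):", grad_finite_analytic(k))
print("finite-T grad (finite diff):", grad_finite_fd(k))
```

In this toy setting the equilibrium gradient with respect to `k` is identically zero while the finite-horizon gradient is not, which is the gap the verdict calls the contraction-defense term.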
+ +## Consequence + +The exact equilibrium adjoint can be correct and still push into a marginal operator, because the equilibrium objective is indifferent to settling time except through its effect on `z*`. The evidence that EP can reach `cos(g_EP,g_transpose)=0.998` only proves that EP can compute the fixed-point gradient. It does not prove that the fixed-point gradient contains BPTT's finite-horizon stabilizer. It does not. + +So the fix is not "set `t2sel=160` and call the adjoint converged." That removes one estimator error. It does not change the target objective. If the deployed model is `T1=150`, the training signal must include finite-horizon dynamics or an explicit contraction objective. + +## Local forward-only fix + +This is fundamental for pure equilibrium EP on `L(z*)`, but not fundamental for local forward-only learning if the objective is changed. + +Concrete construction: finite-horizon forward-mode/RTRL eligibility training for `L(z_T)` plus, if needed, a local contraction penalty. + +Run the physical relaxation forward for `T=150`. Alongside the state, propagate local eligibility traces: + +```text +e_{t+1}^{(p)} = Phi_z(z_t) e_t^{(p)} + eps dF_theta(z_t)/dp +``` + +At `T`, form the local three-factor update + +```text +Delta p proportional to - L_z(z_T)^T e_T^{(p)}. +``` + +This is forward-mode differentiation of the actual finite unroll. It is not reverse BPTT, and it is not a digital root finder. Exact per-parameter RTRL is expensive; practical versions use blockwise, low-rank, or random-direction eligibility traces. But this is the correct class of construction because it preserves the finite product terms that defend contraction. 
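The eligibility recursion above can be sketched on the same scalar toy model `F_k(z) = -k (z - z*)`. This is an illustration under stated assumptions, not the repository's implementation: the constants, the quadratic loss, and the function names are hypothetical, and the single parameter `k` stands in for `p`.

```python
# Hypothetical constants for the scalar toy problem (not taken from lt_ep_train.py).
EPS, T, Z0, Z_STAR, TARGET = 0.1, 150, 0.0, 1.0, 2.0


def relax_with_trace(k):
    """Run the relaxation forward and carry e_t = d z_t / dk alongside the state."""
    z, e = Z0, 0.0
    for _ in range(T):
        phi_z = 1.0 - EPS * k        # Phi_z(z_t) for Phi(z) = z + eps * F_k(z)
        dF_dk = -(z - Z_STAR)        # dF/dk evaluated at z_t (before the update)
        e = phi_z * e + EPS * dF_dk  # e_{t+1} = Phi_z(z_t) e_t + eps dF/dk
        z = z + EPS * (-k * (z - Z_STAR))
    return z, e


def update_direction(k):
    """Three-factor update: Delta k proportional to -L_z(z_T) * e_T."""
    z_T, e_T = relax_with_trace(k)
    L_z = z_T - TARGET               # derivative of the assumed loss 0.5 (z - target)^2
    return -L_z * e_T


def trace_closed_form(k):
    """Closed form of d z_T / dk for this linear toy model, used only as a check."""
    return -EPS * T * (1.0 - EPS * k) ** (T - 1) * (Z0 - Z_STAR)


z_T, e_T = relax_with_trace(0.2)
print("e_T from forward trace:", e_T)
print("e_T closed form:", trace_closed_form(0.2))
print("update direction for k:", update_direction(0.2))
```

No backward pass is used: the trace is accumulated during the same forward relaxation, and it reproduces the `T (1 - eps k)^(T-1)` factor that the exact equilibrium gradient drops.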
+ +If hardware or cost makes forward-mode eligibility too expensive, the alternative local objective is an explicit contraction homeostat: + +```text +R_contr = E_v sum_t log( ||Phi_z(z_t) v_t|| / ||v_t|| ) +``` + +estimated with two nearby physical trajectories or JVP hardware, or a hard monotone/contractive parameterization enforcing a negative log norm. This changes the objective/architecture. It is a valid local fix, but it is not "better EP gradient quality." + +## Single decisive experiment + +Run oracle exact-equilibrium-adjoint training, not merely deeper EP, from the same pre-drift checkpoint and with `resreg=0`. + +At every update: + +1. Relax/refine to `z*`. +2. Solve the exact adjoint `F_z(z*)^T lambda = -L_z(z*)` by GMRES or an equivalent oracle. +3. Apply `grad_theta = L_theta + lambda^T F_theta`. +4. Track `cos(oracle,g_transpose)`, finite-`T1` residual, and free-phase rho on the fixed validation batch every 100 steps. + +Decision rule: + +```text +If oracle equilibrium-adjoint training keeps rho near the BPTT value and clears the wall, (a). +If it still drifts rho toward 1 and blows while cos(oracle,g_transpose) is near 1, (b). +``` + +My ruling is that the second outcome will occur. The exact equilibrium gradient is the wrong gradient for the finite-150-step computation; it cannot contain the missing finite-horizon contraction-defense term by construction. diff --git a/ep_run/EP_DIAGNOSIS_DOSSIER.md b/ep_run/EP_DIAGNOSIS_DOSSIER.md new file mode 100644 index 0000000..22308fa --- /dev/null +++ b/ep_run/EP_DIAGNOSIS_DOSSIER.md @@ -0,0 +1,99 @@ +# EP below-CE-2.1 divergence — complete diagnosis dossier (2026-06-22, CORRECTED) + +## Setup +Equilibrium transformer block: fixed point of a damped relaxation `z ← z + ε·F(z)`, ε=0.1, +`F(z) = x_in − (1+c)z + Attn(LN z) + FFN(LN z)`, c=1. Attention is **non-conservative** (independent +WQ,WK,WV,WO; qknorm RMSNorms q,k). Untied 4× GELU FFN. 
Trained **backprop-free** with **AsymEP** +(Scurria 2602.03670: nudged dynamics get `−2A_J(x⁰)(x−x⁰)`, making the nudged Jacobian = Jᵀ at the +free equilibrium). Code: `lt_ep_train.py` (`force`/`tforce`:81-106, `relax`:123, `ep_step`:140), +`holo_ep.py` (holomorphic estimator). Eval/BPTT use the T1=150 relaxed state. + +**SYMPTOM:** EP descends, then **suddenly** diverges below CE≈2.1 (e.g. val 2.74 → 41 within ~100 steps, +T1-residual 2.5e-2 → 0.42). Exact **BPTT on the identical model trains cleanly to CE 1.83.** + +## CORRECTED diagnosis (measured this round — supersedes earlier framings) + +**Fact 1 — it is a forward LIMIT CYCLE, there is no fixed point at the diverging operator.** +`eval_relax` on the marginal ckpt redx **s3200** (val 2.74, just before the blowup): relax from the +embedding for **6000 steps** → +`res(t): 50→3.6e-2, 150→2.3e-2, 500→2.5e-2, 1000→2.5e-2, 2000→2.3e-2, 4000→2.3e-2, 6000→2.4e-2`, +tail(last 1000) min 2.08e-2 / max 2.73e-2, **non-monotone**. It **floors ~2.3e-2 and oscillates — never +reaches a fixed point.** (Reproduces an earlier lost-run finding: limit cycle, FTLE<0.) + +**Fact 2 — the cycle is driven by the non-conservative ATTENTION.** +Knockout: scale the attention output (`WO ×= α`), eval_relax 3000 steps: +``` +α=1.0: res-floor 2.5e-2, osc 6.0e-3 CYCLE +α=0.7: res-floor 1.6e-2, osc 3.0e-3 CYCLE (smaller) +α=0.4: res-floor 4.1e-3, osc 5.3e-4 nearly gone +α=0.2: res-floor 3.2e-4 CONVERGED (true fixed point restored) +α=0.0: res-floor 1.3e-3, osc 1.2e-3 tiny FFN-only cycle +``` +Reducing the attention monotonically shrinks the cycle and restores convergence. **The attention's +non-conservativity drives the limit cycle.** + +**Fact 3 — hypothesis: a Hopf-type bifurcation.** A relaxation `z←z+εF(z)` (map `M=I+εJ`) can only +*oscillate* (limit cycle) if `M` has a **complex eigenvalue pair crossing |λ|=1**. A symmetric/conservative +J has real eigenvalues → monotone convergence or blow-up, never a cycle. 
As EP training grows the attention +asymmetry/gain, a complex pair crosses → Hopf → limit cycle → readout of a cycle-point degrades → CE explodes. + +## RETRACTED framings (do not anchor on these) +- codex's "(b) structural: equilibrium gradient L(z\*) blind to contraction → forward-mode/RTRL fix" + — **assumed a fixed point z\* exists.** It does not at the diverging operator (limit cycle). The scalar + counterexample (param changing convergence rate but not z\*) is moot when z\* doesn't exist. +- "ρ drifts to 0.998 / slow convergence" — was a **transient artifact** of a ρ-probe window (caught the + initial 3.6e-2→2.3e-2 decay, missed the floor+oscillation). + +## Still-valid facts (about the GRADIENT estimator — separate axis from the forward cycle) +- BPTT (exact grad) → CE 1.83, converges; its trajectory does NOT drive the attention into the cycling regime. +- AsymEP gradient is accurate WHEN a converged fixed point exists: cos(g_EP, exact-adjoint)=0.99 at hr=0.2, + res 1e-9; the "0.94 ceiling" was nudged/adjoint-phase truncation (cos rises 0.74→0.998 as nudge-depth + t2sel 10→160). i.e. the estimator is fine *given a fixed point* — but at the diverging state there is none. + +## AEP paper (Scurria 2602.03670) context +- AsymEP is exact AT the **stationary state** (needs convergence). Appendix G.3 explicitly treats the + **stability of non-conservative dynamics** — they can **oscillate** — controlled by the **asymmetry ratio + r_str** (Eq 37-38: `J = γ(√(1−r_str²)·S̃/‖S̃‖ + r_str·Ã/‖Ã‖)`) + **gain γ** + conservative init + `Var[J]∝1/N`. "AsymEP reduces oscillations." + +--- + +## Q1 (THIS query) — CONFIRM THE MECHANISM +Is the divergence a **non-conservative Hopf bifurcation**: the attention's antisymmetric part A drives a +**complex conjugate eigenvalue pair of the relaxation map M=I+εJ across |λ|=1**, producing the forward limit +cycle (Facts 1-3)? +1. 
Is the evidence (limit cycle in Fact 1 + the attention-scaling knockout in Fact 2) **conclusive** for a + Hopf bifurcation, or what is the gap / what alternative (e.g. real-eigenvalue saddle-node, a 2-cycle from + the discrete Euler step εF, an FFN contribution) is not yet excluded? +2. What is the **single cleanest measurement** to nail it — e.g. compute the eigenvalues of `M=I+εJ` at + s3200 (is there a complex pair with |λ|≥1, vs a real λ≥1)? a Floquet/period analysis of the cycle? an + ε-sweep (does shrinking ε convert the cycle to convergence — distinguishing a continuous-time vs + discrete-Euler instability)? +3. Verify the mechanism against the actual `lt_ep_train.py` force/relax code. + +## Q2 — THE FIX +Given Q1 (a Hopf bifurcation from the attention's non-conservativity), what is the best way to keep the +operator **below** the bifurcation (so a fixed point exists and AsymEP is valid) while preserving as much +attention expressivity as possible? Candidates: (a) **adaptive asymmetry penalty** (our `jacreg` penalizes +‖J_nc‖≈‖A‖, ramped on the residual/cycle onset; the validated 2.40 runs used this, the diverging runs froze +it weak); (b) **structural r_str-style parameterization** (bound the antisymmetric part by construction, +paper Eq 38); (c) **gain control** (γ scaling / qknorm — bound the spectral gain); (d) a **direct +cycle-amplitude / log-norm μ_P(J) penalty**. Which is most effective AND analog-realizable (forward-only, +local)? Give a concrete recipe. + +## Q3 — THE THESIS +Can a non-conservative attention stay **sub-Hopf** (no limit cycle) AND be expressive enough for coherent +language, or is there a **fundamental expressivity-vs-stability tradeoff** (the expressivity needs +asymmetry/gain that triggers the bifurcation)? Estimate the bifurcation threshold (in r_str/γ terms) for this +architecture and whether the sub-threshold regime suffices for an LM. Is a hybrid (bounded-asymmetry core + +thin correction) the realistic ceiling? 
+ +## Q4 — EQUILIBRIUM vs NON-EQUILIBRIUM PRIMITIVE +AsymEP requires a **stationary state**, which does not exist in the limit cycle. Two routes: (i) keep the +operator below the Hopf (fixed point exists → AsymEP exact), accepting the expressivity bound; (ii) **embrace +the non-equilibrium** (limit-cycle) computation with a learning rule native to it (oscillatory / reservoir / +time-averaged). Which is the right primitive for analog hardware, and is (ii) even tractable with a local +forward-only rule? + +--- +Answer **Q1 → Q2 → Q3 → Q4 in order**, each rigorously and grounded in the code/data. Be decisive. diff --git a/ep_run/FUGU_OPTIONS_VERDICT.md b/ep_run/FUGU_OPTIONS_VERDICT.md new file mode 100644 index 0000000..4e3ed25 --- /dev/null +++ b/ep_run/FUGU_OPTIONS_VERDICT.md @@ -0,0 +1,263 @@ +# FUGU_OPTIONS_VERDICT — Q1–Q3 (independently verified) + +Scope: answers grounded in `lt_ep_train.py` (`force`/`tforce` :81-106, `relax` :123-133, +`ep_step` :140-232, `jacreg` :211-219, weight caps :52-53/398-399/563-567), `holo_ep.py`, +the calibration probes (`adaptive_eps_calib.py`, `adaptive_eps_calib2.py`, `eps_sweep_s3200.py`, +`jnc_scaling.py`, `lt_ep_anderson.py`), and the diagnosis dossiers. Each claim is flagged +**[SOLID]** (proved by code/data in repo) or **[UNCERTAIN]** (reasoned, not measured here). + +--- + +## Shared mechanism (the object all three questions act on) + +**[SOLID]** The active free relaxation is explicit (forward) Euler: +`z = z + eps * blk.force(z, xin).detach()` (`relax`, :123-133). In thick mode the force is +`F(z) = -(z - xin) + Attn(LN1 z) + FFN(LN2 z) - c*z` (`tforce`/`force` :81-85, :102-106), c=1. +So the per-step linear stability object is the **discrete map** `M = I + eps*J`, `J = dF/dz`. + +**[SOLID]** For a continuous eigenvalue `mu = a + i b` of `J`, the Euler multiplier is +`lambda = 1 + eps*mu`, and the map is stable iff `|1+eps*mu| < 1`, i.e. +`eps < eps_crit = -2a/(a^2 + b^2)` for `a < 0`. 
A continuous-STABLE rotating mode (`a<0`, `b` large) +is destabilized purely by too-large `eps`. + +**[SOLID]** The ε-monotonicity training data are decisive that this is an *integration* wall, not a +*gradient-quality* wall: eps=0.1 blew @ CE 2.74; eps=0.1 with a strictly better gradient (t2sel=160, +cos 0.998) blew EARLIER @ 3.02; eps=0.05 reached 2.41 before blowing. Better gradient → not later but +earlier; smaller step → strictly lower wall. That is exactly the `|1+eps*mu|>1` signature. + +### One correction to the dossier's "continuous/analog is stable at s3200" framing +**[SOLID — verified, refines prior verdict]** The eps-sweep "CONVERGED at eps=0.01" is measured with a +*different residual* than the cycle floor. `eps_sweep_s3200.py:17` reports the **step** residual +`r = ‖z2-z‖/‖z‖ = eps·‖F‖/‖z‖`; `adaptive_eps_calib.py:15` reports the **force** residual +`g = ‖F‖/‖z‖`. At eps=0.01 the sweep's `r≈8.9e-4` is just `0.01 × 0.089` — i.e. the *same* force-floor +`g≈0.09` that is called a "cycle" at eps=0.1. `FUGU_Q_OPTIONS.md` itself flags this: +"s3200 g floors ~0.09 even at tiny ε (genuinely no fixed point at the marginal op, OR just slow +finite-step convergence — ambiguous)." +**Implication:** the eps-sweep robustly proves *the oscillation/blow-up is a discrete-Euler artifact* +(the cycle amplitude dies as eps→0). It does **not** by itself prove the s3200 operator has a true +attracting fixed point (g→0) in continuous time — the force floor g≈0.09 persists. The clean +continuous-stable case is s2000 (g→0). So "analog HW would have no problem" is **[SOLID]** for the +*oscillatory blow-up* but **[UNCERTAIN]** for "s3200 settles to a usable equilibrium." The decisive +missing measurement remains the leading eigenpair of `J`/`M` at a continued fixed-point branch +(sign of `Re mu`). + +--- + +## Q1 — Evaluate (a) adaptive ε, (b) jacreg, (c) smaller fixed ε + +**Bottom line:** +- **(c) smaller fixed ε — RELOCATES the wall. 
[SOLID]** Already shown empirically (2.74→2.41). +- **(b) jacreg — RAISES/RELOCATES the wall from the model side. [SOLID it raises eps_crit; UNCERTAIN whether it can eliminate]** It lifts `eps_crit` by cutting `|Im mu|`/gain, but at fixed ε it is still a wall in `eps_crit`-space; it also taxes the expressivity it suppresses. +- **(a) adaptive ε — ELIMINATES the fixed-ε wall *iff* its floor stays below the instantaneous `eps_crit`; otherwise it degenerates to (c). [SOLID for the mechanism; the guarantee is conditional]** + +### Ranking +**To remove the measured software wall while preserving the model and the analog target:** +1. **Adaptive ε / robust solver** — only option that removes the *fixed-step* wall with **zero model/expressivity cost** and **zero change to the analog target**. It is pure integration-axis. +2. **jacreg** — effective secondary homeostat; raises `eps_crit`, but changes the learned operator and can cap the non-normality the good (BPTT-1.83) solution uses. +3. **smaller fixed ε** — diagnostic/fallback only; permanently pays the small-step cost on *every* example (including smooth ones) and still fails once stiffening passes the new floor. + +**For the analog (continuous) target specifically:** adaptive ε and smaller fixed ε are *emulator* +choices that leave the model identical to what analog HW runs — they are the right kind of fix. +jacreg *changes the model that analog HW would run* (see Q2). + +### (a) Adaptive ε — grounded in code +**[SOLID]** `adaptive_eps_calib2.py` uses the correct signal: shrink only on **overshoot** +(`g_t > prev*tol` → `eps*=down`), grow otherwise. The naive `adaptive_eps_calib.py` controller +(shrink on slow contraction) is shown to mis-park ε at the floor on all ops — it conflates small-ε's +slow contraction with instability. The corrected controller behaves as a continuous-relaxation +emulator: stiff s3200 → ε to 0.003-0.008; smooth s2000 → ε grows toward 0.1 and reaches g=0. 
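The overshoot rule described above can be sketched on a two-dimensional linear stand-in for the rotating mode. This is not the code in `adaptive_eps_calib2.py`: the matrix `A`, the fixed point `Z_STAR`, the step-rejection choice, and every default value are assumptions made for illustration. With `mu = -1 ± 5i` the Euler bound is `eps_crit = 2/26 ≈ 0.077`, so a fixed `eps=0.1` is unstable even though the continuous system is stable.

```python
import numpy as np

# Hypothetical stiff rotating toy field F(z) = A (z - z*), eigenvalues mu = -1 +/- 5i.
A = np.array([[-1.0, -5.0], [5.0, -1.0]])
Z_STAR = np.array([1.0, -2.0])


def force(z):
    return A @ (z - Z_STAR)


def force_residual(z):
    """Relative force residual g = ||F(z)|| / ||z||."""
    return np.linalg.norm(force(z)) / np.linalg.norm(z)


def relax_fixed(z0, eps, steps):
    """Plain explicit Euler relaxation with a fixed step."""
    z = z0.copy()
    for _ in range(steps):
        z = z + eps * force(z)
    return z


def relax_adaptive(z0, eps0=0.1, steps=2000, tol=1.0, down=0.5, up=1.1,
                   eps_min=1e-4, eps_max=0.1):
    """Shrink eps only on overshoot (force norm grew), grow it otherwise.

    Assumption: an overshooting trial step is rejected rather than kept.
    """
    z, eps = z0.copy(), eps0
    g_prev = np.linalg.norm(force(z))
    for _ in range(steps):
        z_try = z + eps * force(z)
        g_try = np.linalg.norm(force(z_try))
        if g_try > g_prev * tol:
            eps = max(eps * down, eps_min)  # overshoot: reject the step, shrink eps
            continue
        z, g_prev = z_try, g_try            # accept the step
        eps = min(eps * up, eps_max)        # no overshoot: let eps grow again
    return z, eps


z0 = np.array([0.5, 0.5])
z_fixed = relax_fixed(z0, eps=0.1, steps=500)
z_adapt, eps_final = relax_adaptive(z0)
print("fixed eps=0.1 force residual:", force_residual(z_fixed))
print("adaptive force residual:", force_residual(z_adapt), "final eps:", eps_final)
```

On this toy field the controller settles into a band just below `eps_crit` and drives the force residual toward zero, while the fixed-step run diverges. The hard floor `eps_min` is the same conditional guarantee discussed next: if `eps_crit` fell below it, the controller would sit at the floor.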
+ +### Is adaptive ε *guaranteed* to eliminate the wall? — the eps_min question +**[SOLID, decisive]** No, not unconditionally. With a hard floor `eps_min`, adaptive ε eliminates the +wall only while `eps_min < eps_crit = -2a/(a^2+b^2)`. If training keeps stiffening the rotating mode so +`eps_crit` falls below `eps_min`, adaptive ε becomes a fixed small step at the floor — i.e. it +**degenerates into option (c) and merely relocates the wall.** So the guarantee is conditional on the +floor, and equivalently on whether `eps_crit` (hence `|Im mu|`) is bounded away from where the floor +sits. + +### Does |Im μ| (b) saturate or grow unboundedly as CE drops? +This is the crux, and the honest answer is split: + +- **[SOLID] There IS structural stiffness-bounding machinery in the code that argues for saturation.** + (i) `qknorm` RMSNorms q,k → softmax logits are bounded regardless of weight growth (`attn` :63-67); + (ii) **weight-norm caps**: `capw = {WQ,WK,WV,WO,Wm,Wh,fc,pj}` are each projected back to + `capx × initial-norm` every optimizer step (`:52-53`, `:398-399`, `:563-567`); (iii) damping `c=1` + gives a passive `-(1+c)z = xin-2z` contraction floor; (iv) LayerNorm bounds input scale into attn/FFN; + (v) weight decay. With qknorm + capped projections + LN, the per-matrix gains feeding `J_nc` cannot + grow without bound, which bounds `|Im mu|` and therefore keeps `eps_crit` bounded **below**. This is a + genuine reason to expect `|Im mu|` to **saturate** (or at least be bounded) rather than diverge. + +- **[SOLID, opposing data point] But within the *observed* range stiffness was still rising:** fixed + ε=0.1→0.05 moved the wall 2.74→2.41 rather than removing it, i.e. `eps_crit` was still falling across + that CE interval. So saturation, if it exists, had not yet bitten in the measured window. 
+ +- **[UNCERTAIN] No direct eigenvalue/`|Im mu|`-vs-CE trace exists in the repo.** `jnc_scaling.py` + measures `‖J_nc‖` growth-per-step vs width but is not a CE-resolved `|Im mu|` curve. So whether `b` + truly plateaus before `eps_crit` reaches a practical `eps_min` is **not measured**. + +**Synthesis (decisive, hedged correctly):** adaptive ε is the best wall-eliminator and the only +zero-tax, analog-faithful one — **and** the code's caps/qknorm/damping make it *likely* that `|Im mu|` +is bounded, so a sufficiently small `eps_min` should eliminate (not merely relocate) the wall in +practice. But this is a *bounded-floor* guarantee, not an unconditional one: if `|Im mu|` were to grow +without bound, any finite `eps_min` is eventually a wall. **Recommended:** make the floor itself +log an `eps_crit` proxy (overshoot persisting at the floor) and either drop the floor, reject the step, +or hand off to Anderson — i.e. fail-open rather than fail-into-(c). + +--- + +## Q2 — The jacreg paradox + +**Verdict: no paradox. jacreg works by RAISING `eps_crit` from the model side — it fixes the SAME +discretization wall, not a demonstrated continuous-time instability. Relative to adaptive ε it is a +sim-crutch for the measured failure, but it carries a *separate, real* analog benefit (settling +quality), and it would become a genuine fix if a true continuous instability (Re μ≥0) ever emerges.** + +### Why a model-side stiffness penalty fixes a simulation artifact — mechanism +**[SOLID]** `jacreg` is a Hutchinson JVP penalty `R = jacreg·‖J_nc·er‖²/‖er‖²` (`:211-219`), and in thick +mode `nc_force = Attn + FFN` (`:92-97`). Minimizing `‖J_nc‖` reduces the learned non-conservative +gain, which reduces the rotating component `|b|=|Im mu|` (and non-normal amplification). Since +`eps_crit = -2a/(a^2+b^2)`, smaller `|b|` → **larger** `eps_crit` → fixed ε=0.1 stays under the +Euler-stability boundary longer. 
So jacreg moves the *same* `|1+eps*mu|=1` wall by shrinking `b`, while +adaptive ε moves the *same* wall by shrinking `eps`. Two knobs on one inequality. + +### Raising eps_crit vs fixing a continuous-time problem +**[SOLID for measured regime]** For s3200-type failures the relevant mode has `Re mu < 0` (the cycle +dies as eps→0). There is no *established* continuous instability to fix, so jacreg's contribution there +is purely "raise eps_crit" — discretization-wall relief from the model side. +**[UNCERTAIN beyond it]** If training ever drives `Re mu → 0⁺` (a true Hopf), then no integrator +(adaptive ε, implicit, Anderson) can stabilize the original continuous equilibrium; only a model-side +change (jacreg, stronger damping/c, structural monotonicity, gain/asymmetry bounds) is a real fix. +jacreg is the insurance policy for that case. + +### Does the benefit transfer to analog hardware? — two benefits, separated +**[SOLID] (i) The "prevents eps=0.1 Euler blow-up" benefit does NOT transfer.** Analog HW has no `eps` +and does not iterate `z←z+εF`; it performs continuous relaxation. If `Re mu<0`, the analog ODE is +stable and never had this wall. To the extent jacreg only buys eps_crit headroom, it is papering over a +sim artifact analog wouldn't have — a crutch. + +**[SOLID/UNCERTAIN-magnitude] (ii) The "less stiff/less ringy continuous dynamics" benefit DOES +transfer.** Even with `Re mu<0`, a large `|Im mu|` mode has a poor damping ratio: it rings, settles +slowly, demands more bandwidth, longer observation/integration windows, and is more noise/delay +sensitive — all of which degrade the *physical* free-phase settling and the readout of nudged +equilibria on analog HW. Reducing `‖J_nc‖` improves the continuous damping ratio. So jacreg is *also* a +legitimate analog settling/robustness regularizer. **[UNCERTAIN]** the size of this analog benefit is +not measured here. + +### Real fix or sim-crutch, relative to adaptive ε? 
+**[SOLID]** For the *confirmed explicit-Euler artifact*: +- **adaptive ε / Anderson / implicit = the real fix of the emulator** — they preserve the learned + vector field and make the digital sim stop inventing a cycle the analog system wouldn't have. +- **jacreg = a model-changing crutch for that artifact**, but simultaneously a *real* (if secondary) + analog settling regularizer and the *only* lever if a genuine continuous instability appears. + +**Recommended composition (not "either/or"):** (1) use adaptive ε / a real solver as the primary +emulator fix so the sim is faithful; (2) keep jacreg as a **bounded, adaptive** homeostat +(the controller already exists, `:520-529`) sized for analog settling-time/robustness or true +marginality — NOT as a strong fixed penalty that taxes the non-normality the BPTT-1.83 solution needs. +The historical evidence fits this: the validated ~2.40 runs used *adaptive* jacreg; the diverging runs +*froze it weak* — i.e. they removed the homeostat, not the integrator. + +--- + +## Q3 — Anderson acceleration / implicit (IMEX) integrators + +**Verdict: Yes — they can replace explicit Euler as the *solver* and kill the discretization +instability, and they are compatible with AsymEP *provided they converge to the same equilibria of the +same vector fields*. They change nothing about the analog model; they are emulator-fidelity choices. +Implicit Euler is unconditionally stable but per-step expensive (the solve is itself a relaxation). +Anderson is the more practical lever: it both accelerates and can suppress the Euler cycle when a true +fixed point exists, but it is not guaranteed and needs damping/restarts/residual gating.** + +### (i) Compatibility with AsymEP +**[SOLID]** The EP estimator depends on the *states*, not on how they were reached. 
`ep_step` computes +`zs = relax(...)` and treats it as the free equilibrium (`:142-144`); the AsymEP correction uses local +`Jv = jvp(nc_force, zs, v)`, `JTv = vjp(nc_force, zs, v)`, `corr = Jv - JTv` at `zs` (`:172-178`); the +parameter gradient is `(a * f).sum()` with `f = force(zs.detach(), xin, cg=True)` (`:202-205`). None of +this requires explicit Euler — it requires that `zs` is a genuine root `F(zs)≈0` and that the nudged +states are equilibria of the nudged/corrected force. A better solver that returns the *same roots* is +fully compatible, and the `-2A` correction is computed *at* `z*` regardless of the solver that found it. + +**[SOLID — important, refines prior framing] The nudged phase must also be re-solved.** The free phase +is not the only explicit-Euler loop: the nudge (`nudge()` :163-180) and every holomorphic estimator +(`holo_a`, `holo_a_select2`, `holo_a_track`, `holo_a_lockin` in `holo_ep.py`) advance with +`z = z + eps*(f - corr)`. The `-2A` correction lives *inside* these loops. So "swap the integrator" +means swap it in **both** phases; a solver that converges the free `z*` but leaves the nudged phase on +coarse Euler will still corrupt `a = -dz*/dβ`. + +**[SOLID] Hard limit:** if the continuous field has no attracting root in the operating regime, no +solver can manufacture the stationary state AsymEP needs — it will fail, find a spurious root, or +return a numerical artifact. A solver fixes *integration*, not *non-existence of equilibrium*. (This is +why the s3200 force-floor ambiguity from the Shared-mechanism section matters: confirm a true root +exists before trusting AsymEP there.) + +### (ii) Implicit / IMEX — tractable or self-defeating? +**[SOLID, theory]** Backward Euler multiplier is `1/(1-h·mu)`, A-stable: for any `Re mu<0` it is stable +at *every* step size, so it would kill the stiff-rotation Euler cycle outright. 
+**[SOLID, cost]** Each backward step solves `y - h·F(y) - z_n = 0`, where `F` contains LN, causal +softmax attention, and FFN. A Newton/Krylov/Picard solve needs several force evals and matrix-free +JVP/VJP linear solves over the full `B·T·C` state per step — i.e. **the per-step solve is itself a +relaxation/root-find**, which is the self-defeating risk for a default inner loop. +**[UNCERTAIN/qualitative] IMEX nuance:** making only the cheap leak `-(1+c)z` implicit is trivial but +does **not** tame the dangerous learned rotating attention mode (the danger is in `J_nc`, not the leak); +treating `J_nc` implicitly reintroduces the big linear solve. So implicit/IMEX is best as a **robust +fallback / macro-step / offline reference**, not the default per-step integrator. + +### (iii) Anderson — speed only, or stabilization too? +**[SOLID, conceptual]** Anderson (DEQ-style; `lt_ep_anderson.py` stores recent `X`, `G(X)=z+εF`, solves +a small regularized least-squares for the mixing coefficients, extrapolates) is a quasi-Newton/GMRES-on- +the-residual. For a Picard/Euler map whose oscillatory multiplier sits just outside the unit circle, +the residual-minimizing extrapolation can **suppress the limit cycle**, not merely speed a contracting +one — so it is more than acceleration. `lt_ep_anderson.py` is explicitly framed as exactly this test +("can a fixed-point solver converge the free phase where plain relaxation cannot?"). +**[SOLID, caveats]** Not guaranteed: it cannot create a root that doesn't exist; aggressive mixing can +diverge; it needs damping (β-mixing), restarts, and residual-monotonicity gating; and (per (i)) it must +wrap the nudged phase too. Net: **strongest practical candidate** — cheaper than full implicit Newton, +able to stabilize when a root exists, but must be safeguarded. + +### (iv) Does integrator choice matter for the ANALOG target? 
+**[SOLID] For the analog model itself: no.** Analog HW performs the true continuous relaxation of `F`; +it runs no explicit Euler, no Anderson, no backward Euler. The integrator is not part of the deployed +computation. +**[SOLID] For digital training/eval of that target: yes, decisively.** Coarse explicit Euler can invent +a limit cycle the analog system would never exhibit, corrupting both the loss and the equilibrium the +EP gradient is taken at. The correct framing — and the right way to state it in the thesis — is exactly: + +> Analog HW does the true continuous relaxation; the simulator only needs a **faithful + cheap emulator** +> of that relaxation. Adaptive ε, Anderson, and implicit/IMEX are all just *better emulators* — they +> change the simulation's fidelity/cost, not the EP objective or the analog primitive. + +The one asymmetry to keep in mind: **jacreg is NOT in this "just a better emulator" bucket** (it edits +the model the analog HW would run), whereas adaptive ε / Anderson / implicit ARE. That is the precise +sense in which the integrator family is the analog-faithful fix and jacreg is the model-side one. + +### Recommended solver strategy +1. Replace fixed ε=0.1 explicit Euler in the **free** phase with an overshoot/step-rejection adaptive + solver (the corrected `adaptive_eps_calib2.py` logic), with a fail-open floor (Q1). +2. Add **damped Anderson with restarts + residual gating** for both free and nudged phases once the + residual stalls/cycles; solve `F=0` rather than running a fixed Euler count and hoping. +3. Keep **implicit/backward Euler as a reference/fallback**, not the default inner loop (per-step cost). +4. Leave **AsymEP unchanged in principle**: find `z*`, find nudged equilibria, apply `Jv-JTv` at `z*`, + and **gate the update** (`res_gate`, `:153-162`) when residual says no stationary state was found. +5. Retain **jacreg as a bounded adaptive homeostat** (analog settling / true-Hopf insurance), not as the + primary fix. +6. 
For analog claims, report **solver-independent diagnostics**: force residual `‖F(z*)‖/‖z*‖` (NOT just + the eps-scaled step residual — they differ by a factor of eps, which confounded the eps-sweep), and, + when feasible, the leading continuous `mu` (sign of `Re mu`) and settling/ringing time. + +--- + +## Summary table + +| Option | Eliminates or relocates wall | Changes model? | Analog-faithful? | Verdict | +|---|---|---|---|---| +| (a) adaptive ε | Eliminates if eps_min < eps_crit; else relocates | No | Yes (emulator) | **Primary fix** [SOLID mechanism; bounded-floor guarantee] | +| (b) jacreg | Raises eps_crit (relocates in eps_crit-space) | Yes | No for the wall; yes for settling | **Secondary homeostat / crutch + true-Hopf insurance** | +| (c) smaller fixed ε | Relocates only | No | Yes but inefficient | **Diagnostic / fallback** [SOLID] | +| Anderson | Can eliminate cycle if a root exists | No | Yes (emulator) | **Best practical solver, needs safeguards** | +| Implicit/IMEX | Eliminates (A-stable) | No | Yes (emulator) | **Correct but per-step costly; fallback/reference** | + +Key uncertainties flagged: (1) whether `|Im mu|` saturates vs grows as CE drops is **not directly +measured** — code caps/qknorm/damping argue for bounded, but ε=0.1→0.05 data show it was still rising +in-window; (2) whether s3200 has a true continuous fixed point (g→0) vs only a dead oscillation is +**ambiguous** because the eps-sweep's step-residual ≠ force-residual; the clean continuous-stable +evidence is s2000, not s3200. 
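
The step-residual versus force-residual distinction flagged above can be checked on a toy field. This is a sketch under stated assumptions: `force` is a hypothetical 2-D linear field, not the transformer force, and `residuals` is an illustrative helper. It shows that for explicit Euler the one-step residual is exactly `eps` times the force residual, so comparing step residuals across different `eps` values mixes solver and model effects.

```python
import math

def force(z, xin):
    """Hypothetical 2-D toy field: leak -2z plus a rotation of strength 3."""
    return (xin[0] - 2.0 * z[0] - 3.0 * z[1],
            xin[1] - 2.0 * z[1] + 3.0 * z[0])

def residuals(z, xin, eps):
    """Return (force residual, step residual) for one explicit Euler step."""
    f = force(z, xin)
    z_next = (z[0] + eps * f[0], z[1] + eps * f[1])
    z_norm = math.hypot(*z)
    force_res = math.hypot(*f) / z_norm                       # ||F(z)|| / ||z||
    step_res = math.hypot(z_next[0] - z[0], z_next[1] - z[1]) / z_norm
    return force_res, step_res

z, xin = (0.4, -0.3), (1.0, 0.5)
for eps in (0.1, 0.05, 0.01):
    force_res, step_res = residuals(z, xin, eps)
    print(f"eps={eps}: force_res={force_res:.4f} step_res={step_res:.4f} "
          f"ratio={step_res / force_res:.4f}")
```

The force residual is identical at every `eps` for the same state, which is why it is the solver-independent quantity to report.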
diff --git a/ep_run/FUGU_Q1_VERDICT.md b/ep_run/FUGU_Q1_VERDICT.md new file mode 100644 index 0000000..03c3c8d --- /dev/null +++ b/ep_run/FUGU_Q1_VERDICT.md @@ -0,0 +1,123 @@ +# Q1 verdict — below-CE-2.1 divergence mechanism + +## Bottom line + +**Refute “conclusive” as stated.** The dossier evidence is **strong and code-consistent** evidence that the s3200 free relaxation has lost the usable attracting fixed point and is in an **attention-gain-driven non-conservative oscillatory regime**. It is **not yet conclusive evidence for a Hopf / Neimark-Sacker bifurcation** specifically. + +The missing piece is **local spectral evidence at an actual fixed point / continued fixed-point branch**: a leading **complex conjugate eigenvalue pair** of the relaxation map + +\[ +G(z)=z+\varepsilon F(z), \qquad M = DG(z^*) = I + \varepsilon J, \qquad J=\partial F/\partial z|_{z^*} +\] + +crossing the unit circle, with real eigenvalue instabilities absent. The current data show the attractor and the causal knob, but not the local bifurcation class. + +## 1. Is Fact 1 + Fact 2 conclusive for Hopf? + +**No.** It is conclusive for a narrower statement: + +> At the redx s3200 checkpoint, with `eps=0.1`, the implemented forward relaxation does not converge from the evaluated embedding state to a fixed point within 6000 steps; instead the one-step residual floors around `~2.3e-2` and oscillates. Reducing the attention output gain by scaling `WO` monotonically shrinks this oscillation and restores convergence by about `alpha=0.2`. + +That is highly consistent with a non-conservative attention-driven oscillatory instability, but it does **not** uniquely identify a Hopf bifurcation. 
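
A sketch of how a residual floor with tail min/max can be read off a relaxation trace. Everything here is an assumption for illustration: `force` is a hypothetical 2-D field (a leak plus a rotated `tanh`), not the `thick` force, and `relax_residuals` / `tail_stats` are not functions from `eval_relax_s3200.py`. The toy has a single root that is stable in continuous time but repelling under explicit Euler at `eps=0.1`, so it shows the same qualitative signature without reproducing any dossier number.

```python
import math

# Hypothetical 2-D toy field (NOT the transformer force): a leak toward the
# point C plus a 90-degree rotation of tanh(z - C). Its only root is z = C.
C = (1.0, 1.0)
GAIN = 8.0

def force(z):
    u0, u1 = z[0] - C[0], z[1] - C[1]
    return (-2.0 * u0 - GAIN * math.tanh(u1),
            -2.0 * u1 + GAIN * math.tanh(u0))

def relax_residuals(eps, steps=6000, z0=(1.5, 1.0)):
    """Explicit Euler z <- z + eps*F(z); record ||dz|| / ||z|| every step."""
    z, residuals = z0, []
    for _ in range(steps):
        f = force(z)
        dz = (eps * f[0], eps * f[1])
        residuals.append(math.hypot(*dz) / math.hypot(*z))
        z = (z[0] + dz[0], z[1] + dz[1])
    return residuals

def tail_stats(residuals, tail=1000):
    """Min and max of the last `tail` residuals (a floor vs. genuine decay)."""
    window = residuals[-tail:]
    return min(window), max(window)

for eps in (0.1, 0.01):
    lo, hi = tail_stats(relax_residuals(eps))
    print(f"eps={eps}: tail min={lo:.3e} max={hi:.3e}")
```

A tail whose minimum stays well above zero is the non-convergence signature; a tail whose maximum keeps shrinking is ordinary convergence. As the section argues, neither outcome by itself identifies the bifurcation class.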
+ +### What the evidence establishes + +- **Forward non-convergence / cycle-like attractor:** `eval_relax_s3200.py` applies the same explicit relaxation update as training/eval, records the normalized one-step displacement, and the dossier logs a persistent non-monotone residual floor: about `2.3e-2` after thousands of steps, with tail min/max `2.08e-2 / 2.73e-2`. That is incompatible with ordinary monotone convergence to the free fixed point on that trajectory. +- **Attention gain is the main causal knob:** `knockout_s3200.py` scales `blk.WO` by `alpha`, i.e. scales the attention output contribution, and the dossier logs monotonic shrinkage of the residual floor/oscillation from `alpha=1.0` to `0.2`, where convergence is restored. +- **The code really allows non-conservative oscillatory dynamics:** the thick force contains independent attention projections plus an untied FFN inside an explicit Euler relaxation map. There is no energy/gradient-flow guarantee in the active `attn_mode='thick'` path. + +### The gap + +A Hopf/Neimark-Sacker claim is a **local spectral claim** about the derivative of the map at a fixed point branch. The current facts are **trajectory/knockout facts**: + +- Fact 1 shows a sustained oscillatory forward trajectory at `eps=0.1`; it does not show which eigenvalue crossed first. +- Fact 2 shows that reducing total attention output gain removes the oscillation; it does not isolate the **antisymmetric Jacobian part** `A=(J-J^T)/2`, nor does it rule out other nonlinear or discrete-time routes to an oscillatory attractor. +- Because the reported s3200 `alpha=1` trajectory does not converge, an eigenvalue computed at an arbitrary T1 or cycle point would be only an **instantaneous Jacobian**, not a formal Hopf test unless the underlying fixed point/branch is also identified or continued. + +### Alternatives not yet excluded + +1. 
**Real-eigenvalue fixed-point loss / saddle-node-like route.** + A real leading eigenvalue of `M` crossing `+1` would indicate loss of contraction along a non-oscillatory mode. The observed limit cycle could then be a secondary nonlinear attractor reached after the fixed point destabilizes or disappears, not the primary Hopf mechanism. + +2. **Discrete Euler artifact.** + The actual implemented dynamics are not continuous-time integration; they are the map `z <- z + eps*F(z)`. If `J` has eigenvalue `nu=a+ib`, the Euler-map eigenvalue is `mu=1+eps*nu`. It is possible to have `a<0` — stable continuous-time linear dynamics — but `|1+eps*nu|>1` at `eps=0.1` because the step is too large. That would be a numerical/discrete relaxation instability, not a true continuous-time Hopf. A real `mu<-1` would be the clean period-2 / flip case. + +3. **FFN contribution.** + The thick `nc_force` treats `attention + FFN` as the non-conservative part, and the FFN is untied. The knockout log itself shows `alpha=0.0` still has a tiny residual/oscillation (`res-floor ~1.3e-3`, `osc ~1.2e-3`), so an FFN-only contribution is not zero. The data do support attention as the dominant driver, but not attention as the exclusive source. + +4. **qknorm / attention nonlinearity contribution.** + The evaluated block has `blk.qknorm=True`, and q/k RMSNorm is inside attention. Scaling `WO` suppresses the whole attention output path, including effects mediated by qknorm. Therefore the knockout does not separate “antisymmetric attention matrix/gain” from the nonlinear qknorm-shaped attention Jacobian. + +So the rigorous conclusion is: + +> **Plausible and likely:** attention-dominated non-conservative complex-mode instability. +> **Not yet proven:** Hopf/Neimark-Sacker crossing of a complex conjugate pair as the bifurcation mechanism. + +## 2. 
Single cleanest measurement + +**Do the local Jacobian spectrum measurement.** More precisely: compute the leading eigenvalues of both `J = dF/dz` and `M = I + eps*J` on the s3200 checkpoint along a fixed-point branch restored by attention scaling. This is more decisive than a pure epsilon sweep or Floquet analysis, because it directly distinguishes complex-pair Hopf from real-eigenvalue loss and also predicts whether `eps=0.1` is a discrete-Euler artifact. + +### Exact measurement to run later, not now + +Freeze the redx s3200 checkpoint, same batch/sequence and `qknorm=True`. Define + +\[ +F_\alpha(z,x)=-(z-x)+\alpha\,\mathrm{Attn}(\mathrm{LN}_1 z)+\mathrm{FFN}(\mathrm{LN}_2 z)-c z +\] + +where `alpha` is implemented exactly as in `knockout_s3200.py` by scaling `WO`. For `alpha` values bracketing the observed transition (at minimum around `0.2`, `0.4`, `0.7`, `1.0`, then refined near the first loss of convergence), do: + +1. Find a true fixed point `z*_alpha` with `||F_alpha(z*)||/||z*||` very small, using long relaxation where it converges and preferably continuation/Newton from the previous `alpha` so the branch can be followed up to the marginal point. +2. At each `z*_alpha`, compute the leading eigenvalues `nu_i` of + `J_alpha = dF_alpha/dz | z*_alpha` + using JVP/Arnoldi or another matrix-free eigensolver. +3. Convert them to relaxation-map eigenvalues + `mu_i = 1 + 0.1 * nu_i`. +4. Record the leading `|mu_i|`, whether the leading pair is complex or real, and for complex `nu=a+ib` also record `a=Re(nu)` and the Euler stability threshold + \[ + \varepsilon_{\mathrm{crit}} = -\frac{2a}{a^2+b^2} \quad \text{when } a<0. + \] + +### Outcome table + +- **Confirms Hopf / Neimark-Sacker of the implemented relaxation map:** + A complex conjugate pair is the leading spectrum and crosses `|mu|=1` as `alpha` or training step increases; real eigenvalues stay inside the unit circle. The observed oscillation frequency should be compatible with `arg(mu)` per relaxation step. This confirms the map-level Hopf mechanism. 
+ +- **Confirms true continuous-time Hopf rather than Euler artifact:** + The same complex pair has `Re(nu)` crossing through `0` to positive values. Then shrinking `eps` changes the discretization but does not restore continuous-time stability once `Re(nu)>0`. + +- **Shows Euler-step artifact instead:** + The leading pair is complex and `|1+0.1*nu| >= 1`, but `Re(nu) < 0`. Then the continuous-time linearization is damped, while the explicit Euler step is unstable. The predicted stabilizing step is `eps < eps_crit`; an epsilon sweep would be confirmatory, but the spectrum already gives the answer. + +- **Shows real saddle-node / steady instability instead:** + The leading eigenvalue crossing is real near `mu=+1` / `nu=0`. Then the Hopf claim is wrong; the limit cycle is downstream nonlinear behavior after a real fixed-point loss. + +- **Shows flip / two-cycle artifact:** + A real map eigenvalue crosses `mu=-1` or is `< -1`. Then the oscillation is a discrete period-doubling / 2-cycle-type instability, not Hopf. + +- **Shows FFN is materially involved:** + If the unstable/near-unstable pair remains when `alpha=0`, or if the leading antisymmetric contribution is dominated by the FFN block, then “attention antisymmetric part drives it” is overstated. If the pair moves safely inside the unit circle as `alpha` is reduced and disappears with attention removed, then the attention-dominant mechanism is supported. + +Why not make the epsilon sweep the primary measurement? It is useful, but indirect. If smaller `eps` converges, that could indicate an Euler artifact, but it would not by itself distinguish complex Euler instability from real flip or other nonlinear step-size effects. The Jacobian spectrum gives the bifurcation class and the epsilon prediction in one measurement. + +Why not Floquet/period first? 
Floquet multipliers of the observed cycle would quantify stability of the cycle, and period/frequency could corroborate `arg(mu)`, but they do not identify which fixed-point eigenvalue caused the attractor to appear. Use Floquet/period only as a follow-up. + +## 3. Consistency with the actual code + +The proposed mechanism is **consistent with the implemented force and relaxation map**, with the caveat that the code implicates `attention + FFN` as the active non-conservative block, not mathematically pure attention alone. + +- **The thick force is exactly the stated form.** In `lt_ep_train.py`, `tforce` computes layer-normed attention and FFN and returns `-(z - xin) + self.attn(h1) + ff - self.c * z` (`lt_ep_train.py:81-85`). With `c=1`, this is `xin - 2z + Attn(LN z) + FFN(LN z)`. The autograd-enabled `force` path for `attn_mode == 'thick'` computes the same structure and returns it at `lt_ep_train.py:99-106`. + +- **The relaxation is explicit Euler.** `relax` updates `z` by `z = z + eps * blk.force(z, xin).detach()` (`lt_ep_train.py:123-133`). Therefore the linearized relaxation map is exactly `M = I + eps*J`. + +- **The free phase used by EP is this relaxation state.** `ep_step` embeds the input, computes `zs = relax(..., T1, eps)`, then measures a one-step residual from that state (`lt_ep_train.py:140-145`). The code explicitly records this as the T1 free-phase state before any optional refinement (`lt_ep_train.py:146`). + +- **The attention path is non-conservative in the active model.** Attention uses independent `WQ`, `WK`, `WV`, `WO` projections (`lt_ep_train.py:58-68`), and optional q/k RMSNorm when `blk.qknorm` is set (`lt_ep_train.py:63-65`). The eval scripts do set `blk.qknorm=True` (`eval_relax_s3200.py:8`, `knockout_s3200.py:10`). There is no tied-energy construction in the thick path. 
+ +- **The knockout really scales attention output.** `knockout_s3200.py` loads the same checkpoint and performs `blk.WO.mul_(alpha)` before relaxation (`knockout_s3200.py:9-17`). Thus the logged alpha trend is a legitimate intervention on total attention output gain. + +- **The code itself treats FFN as part of the non-conservative component.** In thick mode, `nc_force` returns `attention + FFN`, not attention alone (`lt_ep_train.py:92-97`). The AEP nudged correction also applies `Jv - JTv` of this `nc_force` in the real/thick modes (`lt_ep_train.py:171-179`). In `holo_ep.py`, the holomorphic and real-axis thick forces match the same `-(z-xin)+att+ff-c*z` structure (`holo_ep.py:36-51`, `holo_ep.py:134-152`), and their AEP correction again uses `Jv-JTv` of `blk.nc_force` (`holo_ep.py:76-84`, `holo_ep.py:176-185`). + +## Final verdict + +**The Hopf story is code-consistent and likely, but not proven.** The current evidence nails an attention-dominated non-conservative forward oscillation at the implemented `eps=0.1`; it does **not** yet nail the bifurcation class. The decisive next measurement is the **leading spectrum of `J` and `M=I+eps*J` on the s3200 fixed-point branch under attention-gain continuation**. A complex conjugate pair crossing `|mu|=1`, with real modes stable and with `Re(nu)` interpreted to rule in/out Euler-step instability, would settle the question. diff --git a/ep_run/FUGU_Q_OPTIONS.md b/ep_run/FUGU_Q_OPTIONS.md new file mode 100644 index 0000000..513e21a --- /dev/null +++ b/ep_run/FUGU_Q_OPTIONS.md @@ -0,0 +1,29 @@ +# Fugu query — fix options for the EP below-CE-2.1 divergence (2026-06-23) + +Read for full chain: EP_DIAGNOSIS_DOSSIER.md, EP_BELOW210_DIAGNOSIS_FIX.md (cont.6/cont.7), FUGU_VERDICT_FULL.md. +Code: lt_ep_train.py (force/tforce :81-106, relax :123 `z=z+eps*blk.force(z,xin).detach()`, ep_step :140, jacreg :211-219), holo_ep.py. 
+ +## CONFIRMED diagnosis (measured this session) +The divergence is the **explicit-Euler instability of the *stiffening* rotating (non-conservative) attention mode**: +- relax is explicit Euler `z←z+εF`, ε=0.1 → stability object = discrete map M=I+εJ. Attention non-conservative (indep WQ/WK/WV/WO) → J has a complex eigenvalue μ=a+ib, **a<0 (continuous-STABLE)**, b≠0. Euler unstable when |1+εμ|>1, i.e. ε>ε_crit=−2a/(a²+b²). EP training makes attention expressive → mode stiffens (b grows) → ε_crit shrinks → at fixed ε it crosses → forward LIMIT CYCLE → loss blows. +- **ε-monotonicity (decisive), 3 training runs identical except ε:** ε=0.1→blew@**2.74**; ε=0.1,t2sel=160 (BETTER gradient, cos0.998)→blew@**3.02** (EARLIER — gradient is NOT the lever); ε=0.05→blew@**2.41**. Smaller ε → strictly lower wall; fixed ε only RELOCATES it (as attention keeps stiffening, even ε=0.05 eventually too coarse). +- **⇒ digital-sim integration artifact:** Re μ<0 ⇒ the CONTINUOUS ODE (analog HW, ε→0) is stable; only the coarse explicit-Euler sim cycles. eval ε-sweep on marginal ckpt s3200 confirms (res-floor shrinks monotonically as ε↓; ε=0.01 converges). +- AsymEP gradient is accurate WHEN a fixed point exists (cos 0.99 vs exact adjoint at tight res); the failure is the FORWARD relaxation losing its fixed point, not the gradient. + +## The 3 candidate fixes +(a) **adaptive ε** [#30]: closed-loop step-size — grow ε when contracting, shrink on residual OVERSHOOT. Integration-axis; SIM-only; no model change / no expressivity cost. +(b) **jacreg**: penalize ‖J_nc·v‖ (non-conservative Jacobian = attention+FFN; lt_ep_train.py:211-219 JVP through blk.nc_force) → reduces |Im μ| → raises ε_crit. MODEL change. Early 2.40-validated runs used *adaptive* jacreg; diverging runs froze it weak. +(c) **smaller fixed ε** (0.02/0.03): just moves the wall down; confirmatory, not a fix. 
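
A minimal sketch of candidate (a), the overshoot-driven step-size loop, on a toy problem. Assumptions, stated plainly: `force` is a hypothetical linear field with eigenvalues `-2 ± 8i` (so `ε_crit = 4/68 ≈ 0.059`, below the fixed `ε=0.1`); `adaptive_relax` and its `grow`/`shrink` factors are illustrative, not the controller in the calibration scripts; and the monitored signal is the unnormalised `‖F‖` rather than `g=‖F‖/‖z‖`, which keeps the toy exactly analysable.

```python
import math

B = 8.0              # hypothetical rotation strength of the stiff mode
XIN = (1.0, 0.0)     # hypothetical input

def force(z):
    """Linear toy field with eigenvalues -2 +/- 8i (continuous-time stable)."""
    return (XIN[0] - 2.0 * z[0] - B * z[1],
            XIN[1] - 2.0 * z[1] + B * z[0])

def force_norm(z):
    return math.hypot(*force(z))

def adaptive_relax(z, eps=0.1, eps_min=1e-3, eps_max=0.1,
                   steps=5000, grow=1.1, shrink=0.5):
    g = force_norm(z)
    for _ in range(steps):
        f = force(z)
        z_try = (z[0] + eps * f[0], z[1] + eps * f[1])
        g_try = force_norm(z_try)
        if g_try > g and eps > eps_min:
            eps = max(eps * shrink, eps_min)   # overshoot: reject the step, shrink eps
            continue
        z, g = z_try, g_try                    # accept (fail-open once eps is at the floor)
        eps = min(eps * grow, eps_max)         # contracting: let eps grow again
    return z, g, eps

z_star, g_final, eps_final = adaptive_relax(XIN)
print(f"z*={z_star}, |F|={g_final:.3e}, eps={eps_final:.4f}")
```

In this toy the loop settles into a band just below `ε_crit` and reaches the root that fixed `ε=0.1` cannot. If the mode stiffened until `ε_crit` dropped below `eps_min`, the same loop would only relocate the wall, which is the question Q1 raises.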
+ +## Adaptive-ε calibration so far (eval-only on s3200 cycling-op + s2000 smooth-op; ground-truthed by the ε-sweep) +- NAIVE controller (shrink if g=‖F‖/‖z‖ falls <2%/step): WRONG — parks ε at floor 0.005 on ALL ops, because slow per-step contraction (ρ~0.99) is misread as "ε too big." Conflates small-ε's slow contraction with instability. +- CORRECTED (shrink only on OVERSHOOT g_t>g_{t-1}; grow otherwise): adaptive as desired — s3200 (stiff): ε→0.003-0.008, g floors ~0.10 (≈ε=0.01 benchmark 0.09); s2000 (smooth): converges to **g=0** (true fixed pt), ε grows toward 0.1. Tradeoff in overshoot-tolerance: strict (tol1.0) catches s3200 cycle but over-shrinks near s2000's converged tail (float-noise ticks→spurious shrink); permissive (tol1.02) stays fast on smooth but MISSES s3200's slow cycle (stays ε~0.095, g~0.23). → suggests EMA-smoothed signal / step-doubling local-error / noise-relative tol. +- NOTE: s3200 g floors ~0.09 even at tiny ε (genuinely no fixed point at the marginal op, OR just slow finite-step convergence — ambiguous); s2000→g=0 cleanly. + +## QUESTIONS — answer Q1→Q3 in order, decisive, grounded in code/data, flag solid-vs-uncertain. + +**Q1 — Evaluate (a)/(b)/(c).** Which ELIMINATES the wall vs merely RELOCATES it? Which fits the *analog* target (continuous relaxation)? Rank + justify. Is adaptive ε guaranteed to eliminate the wall, or does its ε_min floor just relocate it (like fixed small ε) IF the rotating mode stiffens without bound — and does b/|Im μ| stiffen unboundedly as CE drops, or saturate? + +**Q2 — The jacreg paradox.** If the divergence is fundamentally a *simulation* discrete-Euler artifact (Re μ<0, continuous/analog stable), WHY does jacreg work? Is it just **raising ε_crit** (cutting |Im μ| so fixed ε becomes stable) — fixing the SAME discretization wall from the model side — or fixing a genuine continuous-time problem? 
**Does jacreg's benefit transfer to analog hardware** (continuous, no ε): (i) papering over a sim artifact analog wouldn't have, or (ii) also improving the continuous dynamics (faster/less-ringing settling — useful on analog)? So: is jacreg a "real fix" or a sim-crutch, relative to adaptive ε? + +**Q3 — Anderson / implicit integrators.** Can Anderson acceleration or implicit/semi-implicit (IMEX) integration replace explicit Euler to kill the discretization instability? (i) Compatible with **AsymEP** — which needs free equilibrium z*, nudged equilibrium, and the local Jacobian at z* (the −2A correction)? Does changing the forward integrator break the EP gradient estimator (which assumes relaxation reaches z*)? (ii) Implicit Euler's per-step nonlinear solve — tractable for a transformer block, or self-defeating (the solve is itself a relaxation)? (iii) Anderson on the fixed-point iteration — only speeds convergence, or also STABILIZES (suppresses the limit cycle)? (iv) Does the integrator choice matter for the **analog** target (continuous, no integrator), or is this purely sim-side acceleration — i.e. is the right framing "analog HW does the true continuous relaxation; the sim just needs a faithful+cheap emulator, and adaptive-ε / Anderson / implicit are all just better emulators"? diff --git a/ep_run/FUGU_VERDICT_FULL.md b/ep_run/FUGU_VERDICT_FULL.md new file mode 100644 index 0000000..3c06293 --- /dev/null +++ b/ep_run/FUGU_VERDICT_FULL.md @@ -0,0 +1,160 @@ +# FUGU_VERDICT_FULL — Q1–Q4 + +## Q1 — Mechanism: confirm/refute the non-conservative Hopf claim + +**Verdict:** confirm the broad failure mode, but do **not** overclaim the exact bifurcation label yet. The code/data are conclusive for an **attention-dominated non-conservative forward oscillatory instability with no usable fixed point** at redx `s3200`. They are **not yet conclusive** that the route is specifically a local continuous-time Hopf bifurcation of a fixed point. 
The best current statement is: + +> The implemented relaxation map `z_{t+1} = z_t + eps*F(z_t)` has crossed from a stationary computation into an attention-driven oscillatory attractor. The most likely local mechanism is a complex-conjugate pair of the map Jacobian `M = I + eps*J` crossing `|lambda| = 1` — a Hopf/Neimark-Sacker-type instability of the Euler relaxation. But the eigenvalue crossing has not yet been measured, so the exact bifurcation class remains a hypothesis. + +Grounding in the code: + +- `relax` is explicit Euler: `z = z + eps * blk.force(z, xin).detach()` in `lt_ep_train.py:123-133`. Therefore the relevant stability object for the implemented computation is the **discrete map** `M = I + eps*J`, not only the continuous vector field `J=dF/dz`. +- In the relevant `attn_mode='thick'` branch, `tforce` / `force` implement + `F(z) = -(z - xin) + Attn(LN1(z)) + FFN(LN2(z)) - c*z` (`lt_ep_train.py:81-85`, `102-106`). With `c=1`, the passive term is `xin - 2z`; learned attention/FFN Jacobian must fit inside that contraction margin. +- Attention is genuinely non-conservative in the implementation: independent `WQ/WK/WV/WO`, causal softmax, optional q/k RMSNorm (`qknorm`) in `lt_ep_train.py:58-68`. It is not the gradient of the tied conservative `attn_energy` path. +- In thick mode, `nc_force` includes **attention plus the untied FFN** (`lt_ep_train.py:92-97`). Thus the knockout supports “attention is dominant,” but the code also explains why a tiny FFN-only oscillation can remain when attention output is zeroed. + +What the measurements prove: + +1. **It is not slow convergence.** At s3200 the residual decays initially and then floors/oscillates around `~2.3e-2` through 6000 relaxation steps, with non-monotone tail `2.08e-2` to `2.73e-2`. That rules out the earlier “rho close to one but still convergent” framing for the actual forward computation. +2. 
**Attention is causally responsible for the large cycle.** Scaling `WO` monotonically shrinks the oscillation: `alpha=1.0` cycles, `0.7` cycles smaller, `0.4` is nearly gone, and `0.2` restores a true fixed point. That is strong causal evidence that attention’s non-conservative/gain component drives the failure. +3. **The estimator is not the primary explanation once no fixed point exists.** `ep_step` assumes `zs = relax(...)` is a free equilibrium and forms the AEP/nudged update around it. The dossier says AsymEP is accurate when a fixed point exists; at s3200 the required object is absent. + +What remains unexcluded: + +- **Discrete Euler artifact vs continuous Hopf.** For a continuous eigenvalue `mu=a+ib` of `J`, the Euler multiplier is `lambda=1+eps*mu`; stability requires `(1+eps*a)^2 + (eps*b)^2 < 1`. A stiff rotating mode with `a<0` can still have `|1+eps*mu|>1` at `eps=0.1`. In that case the digital relaxation cycles even though the underlying continuous-time analog ODE converges, and the Euler map itself would converge at a smaller `eps`. +- **Real-multiplier alternatives.** A real `lambda` crossing `+1` would indicate saddle-node/pitchfork/loss of stationary solutions; a real `lambda` crossing `-1` would indicate a flip/period-doubling route. The observed smooth oscillation and attention scaling favor a complex pair, but do not prove one. +- **Global/coexisting-attractor route.** The long relaxation proves that the trajectory from the embedding does not settle to a stationary computation. It does not, by itself, prove the cycle emerged through a local fixed-point Hopf rather than a global basin/coexisting-attractor mechanism. 
+- **FFN contribution.** Because `alpha=0` still leaves a tiny cycle and thick-mode `nc_force` includes the FFN, the precise claim is “attention-dominated,” not “attention-only.” + +**Single cleanest confirming measurement:** perform an **attention-output-scale continuation with leading eigenvalues of the actual Euler map `M=I+eps*J` at the converged fixed point just below the transition**. + +Concretely, for the s3200 checkpoint: set `WO <- alpha*WO`, solve to tight fixed-point residual for subcritical `alpha`, compute leading eigenvalues of `J=dF/dz` and `M=I+eps*J` at `z*(alpha)`, and increase `alpha` until convergence is lost. This is cleaner than eigenvalues at an arbitrary point on the already-existing cycle, because Hopf is a fixed-point stability statement; Floquet analysis is useful second, but characterizes the cycle after it exists. + +Outcomes: + +- **Complex pair of `M` reaches/crosses `|lambda|=1` at the same `alpha_c` where the fixed point disappears:** confirms the Hopf/Neimark-Sacker mechanism for the implemented relaxation map. +- **The corresponding continuous eigenvalues `mu` have `Re(mu)` crossing zero:** confirms a true continuous-time Hopf, relevant to analog ODE hardware. +- **`|1+eps*mu|>1` while `Re(mu)<0`, and smaller `eps` restores convergence:** the failure is mainly an explicit-Euler/stiff-rotation artifact, not a continuous-time Hopf. +- **A real `lambda>=1`:** not Hopf; look for saddle-node/pitchfork/loss of stationary solution. +- **A real `lambda<=-1`:** not Hopf; a discrete flip/period-doubling route is implicated. +- **All fixed-point multipliers stay inside the unit circle up to loss of convergence:** likely global/coexisting attractor or basin issue rather than local Hopf. +- **Floquet multipliers of the observed cycle all stable except phase:** confirms a stable limit cycle, but still does not identify how the stationary solution was lost. 
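
The local part of the outcome list above can be written as a small decision function. This is a sketch, assuming the leading continuous eigenvalue `mu` of `J` at `z*(alpha)` has already been obtained elsewhere (the eigensolver is not shown); `classify_leading_mode`, `eps_crit`, and the label strings are hypothetical names. The global-attractor and Floquet outcomes are not covered, since they cannot be decided from one eigenvalue.

```python
def eps_crit(mu: complex) -> float:
    """Largest stable explicit-Euler step for a damped mode mu = a + ib (a < 0)."""
    a, b = mu.real, mu.imag
    if a >= 0:
        raise ValueError("mode is not damped in continuous time; no stable step")
    return -2.0 * a / (a * a + b * b)

def classify_leading_mode(mu: complex, eps: float = 0.1, tol: float = 1e-9) -> str:
    """Map one leading eigenvalue of J to the outcome classes listed above."""
    lam = 1 + eps * mu                     # Euler-map multiplier
    if abs(lam) < 1:
        return "map_stable"
    if abs(mu.imag) <= tol:                # real multiplier on or outside the circle
        return "real_plus_one" if lam.real >= 1 else "real_flip"
    if mu.real >= 0:
        return "continuous_hopf"           # Re(mu) has reached or crossed zero
    return "euler_artifact"                # damped ODE, unstable explicit step

# Hypothetical eigenvalues, for illustration only.
for mu in (-2 + 8j, 0.1 + 8j, 0.5 + 0j, -25 + 0j):
    print(mu, classify_leading_mode(mu))
print("eps_crit for -2+8j:", eps_crit(-2 + 8j))
```

For the `euler_artifact` case, `eps_crit` gives the step size below which the same mode becomes map-stable, which is the prediction an epsilon sweep would then confirm.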
+ +So Q1: **confirm attention-driven non-conservative oscillatory non-convergence; keep “Hopf” as the leading, not-yet-proven, local mechanism until the fixed-point eigenvalue continuation is measured.** + +--- + +## Q2 — Fix: keep the operator below Hopf while preserving expressivity + +**Verdict:** the best immediate fix is a **residual-triggered adaptive stability homeostat**, implemented primarily with adaptive `jacreg`, plus `qknorm` and modest attention-gain/spectral guardrails. Direct cycle/residual penalties should be alarms/gates, not the primary shaping objective. A structural `r_str` parameterization is the cleanest long-term analog design, but it is less immediately surgical for the current transformer attention code. + +Important detail: the current `jacreg` is not a pure antisymmetric penalty. In `lt_ep_train.py:211-219`, it estimates `||J_nc v||^2 / ||v||^2` by a JVP through `blk.nc_force`; in thick mode `J_nc` is attention plus FFN (`lt_ep_train.py:92-97`). Thus it penalizes learned non-conservative/gain response — a proxy for dangerous rotating dynamics — not exactly `||(J-J^T)/2||`. That proxy is nevertheless the best-supported control knob in the dossier. + +Why adaptive `jacreg` is the right primary fix: + +- It targets the learned recurrent Jacobian that must remain inside the passive `-2z` contraction margin. +- The controller is already wired to the right observable: the free-phase residual. `lt_ep_train.py:520-529` increases `jr` when `res/res_target` rises and relaxes it when dynamics settle. +- The failure is abrupt; a fixed weak penalty can allow training to walk past the bifurcation. The penalty must adapt to residual/cycle onset. +- The dossier states the validated stable runs used adaptive `jacreg`, while diverging runs froze it weakly. +- It preserves more expressivity than simply shrinking all attention: attention can remain strong where it does not destroy the fixed point. 
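
A sketch of what a residual-triggered homeostat of this kind can look like. This is not the rule at `lt_ep_train.py:520-529`, which is not reproduced in this document: the multiplicative update, the EMA, and every default value below (`res_target`, `jr_min`, `jr_max`, `ema_beta`, `gain`) are assumptions for illustration. Only the names `jr` and `res_target` and the qualitative behaviour (raise `jr` when `res/res_target` rises, relax it when dynamics settle) come from the text above.

```python
def update_jr(jr, res, res_ema, res_target=2e-3,
              jr_min=1e-3, jr_max=16.0, ema_beta=0.9, gain=0.5):
    """One hypothetical homeostat step: returns (new jr, new residual EMA)."""
    res_ema = ema_beta * res_ema + (1.0 - ema_beta) * res   # smooth the residual
    ratio = res_ema / res_target
    jr = jr * ratio ** gain            # ratio > 1 raises jr, ratio < 1 relaxes it
    jr = min(max(jr, jr_min), jr_max)  # nonzero floor, finite ceiling
    return jr, res_ema

jr, res_ema = 0.1, 2e-3
for _ in range(200):                   # residual stuck at a cycle-like floor
    jr, res_ema = update_jr(jr, res=2.3e-2, res_ema=res_ema)
jr_cycling = jr
for _ in range(200):                   # dynamics settle again
    jr, res_ema = update_jr(jr, res=1e-4, res_ema=res_ema)
jr_settled = jr
print(f"jr while cycling: {jr_cycling}, jr after settling: {jr_settled}")
```

The EMA keeps a single noisy residual from moving `jr`, and the nonzero floor means the penalty never switches off entirely, which is the difference from a frozen weak setting.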
+ +Role of the other candidates: + +- **Structural `r_str` bound:** best theoretical/hardware guarantee if the recurrent operator can be parameterized as bounded symmetric plus bounded antisymmetric components. But in this code the attention state-Jacobian is data-dependent through LN, q/k projections, softmax, values, and `WO`; a simple weight-level `r_str` does not directly bound the actual rotating eigenvalues. Use this for redesign, not as the immediate rescue. +- **Gain control / `gamma` / `qknorm`:** necessary guardrail, insufficient alone. `qknorm` is already enabled in the s3200 scripts, yet `alpha=1` cycles. The `WO` knockout proves gain matters; use gain caps, but do not rely on blunt global gain reduction as the main solution. +- **Direct cycle-amplitude / residual penalty:** extremely analog-measurable, but symptom-level. It activates when the operator is already near/off the stationary manifold and may punish slow-but-stable modes. Use it to gate invalid EP updates and drive the homeostat. +- **Log-norm / contraction penalty:** theoretically stronger if computed in the right metric, but global and expensive; less obviously forward-local. Use as an offline diagnostic or occasional calibration, not the main analog training primitive. + +Concrete recipe: + +1. **Keep `qknorm` on** for thick attention. It bounds q/k logits and reduces Jacobian stiffness, but is not sufficient by itself. +2. **Initialize inside the basin.** Use small residual-branch initialization (`resinit < 1`, scaling `WO` and `pj`) and keep `c=1` or stronger initial leak. +3. **Use adaptive `jacreg` with nonzero floor and enough ceiling.** Keep the existing controller structure. Set `res_target` well below the measured cycle floor, roughly `1e-3` to `5e-3`; keep `jr_max` high enough to recover, e.g. the code’s `16` scale; use residual EMA to avoid controller thrash. +4. 
**Turn on `res_gate`.** If free/refined residual exceeds the validity gate, skip task EP/nudge gradients and apply only stabilization. Since the observed cycle floor is around `2.3e-2`, a gate of order `5e-3` to `1e-2` is appropriate. +5. **Prefer branch-aware regularization if modifying code.** Penalize attention’s learned Jacobian more strongly than FFN, because the knockout identifies attention as dominant; keep a lighter FFN penalty because the FFN-only tiny cycle exists. +6. **Add slow attention-output gain rails.** Because post-hoc `WO` scaling restores convergence at `alpha=0.2` and is near-safe by `alpha=0.4`, impose a soft cap/homeostat on `WO` or attention-output spectral/branch gain. Use it as a rail, not the primary objective. +7. **Use `resreg` only as secondary T1 protection.** The `resreg` term (`lt_ep_train.py:220-231`) protects the finite `T1` state used by evaluation/BPTT, but it does not replace fixed-point stability control. +8. **Monitor tail oscillation, not only a short residual probe.** The previous false “slow convergence” framing came from seeing transient decay and missing the residual floor. Track `T1` residual plus tail min/max or autocorrelation. + +Analog-realizable version: + +- Measure `||z_{t+1}-z_t||/||z_t||` or continuous `||F(z)||/||z||` locally during settling. +- Approximate `jacreg` forward-only by injecting small random state perturbations `eta` and measuring `F_nc(z+delta eta)-F_nc(z)`; use that to locally reduce attention/FFN array gains or asymmetry budgets. +- Do not require exact software `vjp`/`J-J^T` as the hardware primitive unless the substrate supports reciprocal probes. For hardware, use forward perturb-and-measure gain/curl proxies plus residual gating. 
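
The forward perturb-and-measure proxy above can be sketched without any autodiff. Assumptions: `nc_force` here is a hypothetical 2-D linear stand-in (a scaled rotation), not `blk.nc_force`, and `gain_proxy`, `delta`, and `probes` are illustrative names and values. The function estimates `||J_nc η||² / ||η||²` from two forward evaluations per probe, which is the quantity a JVP would return directly in software.

```python
import random

def nc_force(z):
    """Hypothetical stand-in for the non-conservative force: J = [[a, -b], [b, a]]."""
    a, b = 0.5, 3.0
    return (a * z[0] - b * z[1], b * z[0] + a * z[1])

def gain_proxy(f, z, delta=1e-3, probes=8, seed=0):
    """Average of ||(f(z + delta*eta) - f(z)) / delta||^2 / ||eta||^2 over random eta."""
    rng = random.Random(seed)
    f0 = f(z)
    total = 0.0
    for _ in range(probes):
        eta = [rng.gauss(0.0, 1.0) for _ in z]
        f1 = f([zi + delta * ei for zi, ei in zip(z, eta)])
        num = sum(((p - q) / delta) ** 2 for p, q in zip(f1, f0))
        den = sum(ei * ei for ei in eta)
        total += num / den
    return total / probes

proxy = gain_proxy(nc_force, (0.4, -0.3))
print(f"gain proxy = {proxy:.6f}")
```

For this stand-in every direction is amplified by `a² + b² = 9.25`, so the proxy can be checked against a known value. On a nonlinear force the result depends on `delta` and on the probe directions, so it is a gain proxy and not an exact curl measurement.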
+ +So Q2: **primary fix = adaptive `jacreg`-style stability homeostasis; guardrails = `qknorm`, small residual initialization, spectral/gain caps, and `res_gate`; long-term clean analog design = structural `r_str/gamma` bounded operator.** + +--- + +## Q3 — Thesis: can sub-Hopf non-conservative attention be expressive enough? + +**Verdict:** yes, at least for this architecture/scale. The data show a practical expressivity-vs-stability tradeoff in rotating/gain budget, but not a fundamental theorem that coherent language requires post-Hopf dynamics. + +The right thesis is: + +> Non-conservative attention can be expressive below the Hopf boundary, but it must operate with a measured stability margin. Beyond that margin, the model becomes an oscillator rather than a valid equilibrium language model. + +Evidence: + +- Exact BPTT on the identical model trains cleanly to CE `1.83` and does not drive the forward operator into the cycling regime. That strongly suggests the architecture contains stable expressive LM solutions. +- The dossier says AsymEP matches the exact adjoint when a true fixed point exists. Therefore the failure is not an inherent fixed-point gradient ceiling; it is the EP trajectory crossing the stationary-computation boundary. +- The knockout gives a local threshold estimate at s3200: `alpha=1.0` cycles, `0.7` cycles, `0.4` is nearly gone but still floored, `0.2` converges. For this checkpoint/batch/`eps=0.1`, the strict fixed-point critical attention-output scale is roughly between `0.2` and `0.4`, plausibly near `0.3` of the trained s3200 attention-output gain. + +In `r_str/gamma` terms: + +- The code does not measure `r_str` directly, so a numeric `r_str` threshold would be fake precision. +- The dangerous quantity is the **effective rotating learned Jacobian relative to the contraction margin**: roughly asymmetric/rotating fraction times total attention/FFN gain, divided by the passive damping from `-(1+c)z`. 
+- For the implemented Euler map, the boundary is `|1+eps*mu|=1`, not just `Re(mu)=0`. With `eps=0.1`, high imaginary frequency can destabilize the map even when continuous-time damping remains negative. +- Operationally, the threshold is wherever the leading complex multiplier of `M=I+eps*J` reaches one. In the s3200 `WO`-scale coordinate, stay on the `alpha<=0.2-0.3` side for a strict fixed-point criterion unless the weights re-adapt under regularization. + +Does the sub-threshold regime suffice? + +- **For a coherent small LM:** yes, the BPTT result is strong evidence. +- **For maximum transformer expressivity:** stability imposes a cost. It limits sharp recurrent routing, high non-normal amplification, and strong directed cycles. Extra capacity should come from width, depth, longer but stable settling, or controlled feedforward correction, not uncontrolled curl. +- **For current AsymEP:** stable runs around CE `2.40` do not prove a ceiling; they show the present local training/control recipe has not yet matched BPTT. + +Is a hybrid the ceiling? + +**Likely yes for competitive analog language hardware.** The realistic design is: + +1. a bounded-asymmetry equilibrium core that stays sub-Hopf and supports exact AsymEP; +2. non-conservative attention inside a measured `r_str/gamma` or spectral-margin budget, with qknorm and gain homeostasis; +3. a thin explicit correction/readout/feedforward/digital-clocked path for operations that would otherwise require too much recurrent curl. + +So Q3: **sub-Hopf non-conservative attention can be expressive enough; the tradeoff is real but practical, not proven fundamental. The local s3200 threshold is `alpha_c ~ 0.2-0.4` — probably near `0.3` — or, generally, the point where the leading complex multiplier of `M` hits the unit circle. A bounded-asymmetry core plus thin correction is the realistic ceiling.** + +--- + +## Q4 — Primitive: equilibrium AsymEP or native non-equilibrium learning? 
+ +**Verdict:** for this codebase and near-term analog hardware, the right primitive is **equilibrium computation kept below Hopf**. Preserve a stationary state `z*`; then AsymEP is exact in the regime it assumes. Do not embrace the accidental limit cycle as the main primitive unless you replace the objective, readout, and learning rule. + +Why the current code requires a stationary state: + +- `ep_step` first computes `zs = relax(...)` and treats it as the free equilibrium (`lt_ep_train.py:142-146`). +- The AEP correction uses `v = z - zs`, `Jv`, and `JTv` at `zs` (`lt_ep_train.py:172-178`). That is a local stationary-state linearization. A phase point on a cycle is not the `z*` required by the implicit-gradient formula. +- The block parameter gradient is formed from `(a * f).sum()` with `f = blk.force(zs.detach(), xin, cg=True)` (`lt_ep_train.py:202-205`). If no fixed point exists, this is not the derivative of a stationary solve. +- `holo_ep.py` has the same assumption: `holo_a`, `holo_a_select2`, and `holo_a_track` expand nudged trajectories around `zs`; `holo_a_lockin` is a demodulated nudge estimator around a base state, not a full learning rule for a self-sustained free limit cycle. + +What a true non-equilibrium route would require: + +1. define the computation as a periodic orbit, phase-averaged state, invariant measure, or reservoir trajectory instead of `z*`; +2. define CE on a time/phase average, not an arbitrary T1 sample; +3. handle the neutral phase mode of the orbit; +4. replace fixed-point adjoints with Floquet/Poincare adjoints, eligibility traces, perturbation learning, or lock-in demodulation over periods; +5. keep the orbit stable while learning all recurrent attention weights; +6. demonstrate locality, forward-only operation, and sample efficiency for language-scale credit assignment. + +Is route (ii) tractable forward-only? 
+ +- **For fixed reservoirs plus readout:** yes; oscillatory analog reservoirs can be useful and trainable with local/readout or perturbation rules. +- **For full recurrent attention weights:** not currently as a clean replacement for AsymEP. It is possible as a research program — phase-demodulated perturbation learning, e-prop-like traces, Floquet-local approximations — but it will be approximate, noisy, phase-sensitive, and likely much less efficient. +- **For this project:** training through the s3200 cycle with existing EP/holo estimators is invalid. Gate task updates when residual indicates loss of equilibrium; apply stabilizing homeostasis; resume only after a fixed point is restored. + +Analog-hardware conclusion: + +- Equilibrium analog hardware has a simple primitive: settle, nudge, measure local contrast, update. +- A limit-cycle primitive requires clocks/phase references, demodulation windows, eligibility storage, and phase-stable credit assignment. That may be viable for special-purpose oscillatory reservoirs, but it gives up the main simplicity and exactness of EP. + +So Q4: **keep the operator below Hopf.** Treat non-equilibrium oscillatory learning as a separate reservoir/auxiliary research direction, not the central primitive for this AsymEP transformer. The system should be a bounded non-conservative equilibrium machine, not an accidental oscillator. diff --git a/ep_run/GPT55_BUG_HUNT.md b/ep_run/GPT55_BUG_HUNT.md new file mode 100644 index 0000000..241adba --- /dev/null +++ b/ep_run/GPT55_BUG_HUNT.md @@ -0,0 +1,249 @@ +# GPT55 EP/AEP Correctness Bug Hunt + +Scope: static review of `lt_ep_train.py`, `holo_ep.py`, `asym_probe.py`, with context from `../EP_BELOW210_DIAGNOSIS_FIX.md`. I did not modify training code and did not run training or GPU jobs. I used only tiny CPU checks for formula consistency. + +## Executive Summary + +The main suspect for the reported `cos(g_EP, g_transpose) ~= 0.94` plateau that is flat over `hr` is not beta noise. 
The code has no nudged-phase/adjoint convergence check. A tight free-phase residual does not imply the two-phase AEP contrast has converged to the fixed-point adjoint. In the linearized corrected dynamics, finite-`T2` error is independent of `hr`, exactly matching the clue. + +I did not find a sign flip in the deterministic AEP correction. The real two-phase contrast, holomorphic two-phase contrast, and exact-adjoint probe sign conventions are mutually consistent. I did find several real correctness hazards: the plain-EP AEP correction clip invalidates the transpose correction, stochastic `fnoise` makes JVP/VJP not derivatives of the same force realization, `t1max` trains at a refined fixed point while evaluation/BPTT use the finite `T1` state, and `asym_probe.py` has probe-specific footguns that can mislead the diagnosis. + +## Ranked Findings + +### 1. Finite nudged/adjoint relaxation is unchecked; this explains an `hr`-flat cosine plateau + +Files/lines: +- `lt_ep_train.py:163-180` runs plain nudges for exactly `T2` steps. +- `lt_ep_train.py:181-197` computes holomorphic/AEP `a` through `holo_a_track`, `holo_a_select2`, `holo_a_select`, or `holo_a`, but ignores whether the nudged contrast actually converged. +- `holo_ep.py:179-185` updates the two real phases for one finite step at a time. +- `holo_ep.py:193-211` selects a snapshot by smallest inter-snapshot increment, not by an adjoint residual. +- `holo_ep.py:229-254` does the same for common-mode tracking. +- `asym_probe.py:727-743` changes `beta/hr` in the diagnostic sweep but keeps the nudged relaxation budget fixed. +- `asym_probe.py:582-595` interprets beta-sweep behavior without a finite-`T2` branch. + +What is wrong: +The estimator validity logic checks/refines the free phase, but not the nudged phase. 
For the corrected local dynamics, with `M = I + eps * J^T` and `ell = dL/dz`, the two-phase contrast approximately obeys: + +```text +a_{t+1} = M a_t + eps * ell +fixed point: J^T a = -ell +``` + +The finite-`T2` error is `M^T2` times the initial adjoint error and is independent of the nudge radius `r/hr` in the linear regime. Therefore a cosine plateau that is flat across `hr=0.04..0.8` is a signature of under-converged nudged/adjoint relaxation, not evidence that beta noise or finite-beta bias has been ruled out. + +Why it matters: +This corrupts training gradients directly. It also corrupts the diagnosis because the probe can report a tight free-phase residual while the AEP contrast is still a truncated adjoint solve. As the operator hardens below CE about 2.1, the adjoint relaxation can slow down even if the free phase is tight. + +Severity: corrupts TRAINING and DIAGNOSIS/probe. + +Confidence: high. The line-level behavior matches the `hr`-insensitive clue exactly. + +Minimal fix: +Add a nudged/adjoint convergence criterion. At minimum, return and log `inc_min / (||a_best|| + eps)` from `holo_a_select2` and `holo_a_track`, and sweep `t2sel` in `asym_probe.py` at fixed `hr`. Better: compute an adjoint residual proxy at the selected `a`, e.g. `||J^T a + ell|| / (||ell|| + eps)` using VJP of the full thick force at the free state, and keep nudging until it is below tolerance. Treat `hr` sweeps as inconclusive unless `T2/t2sel` convergence is also demonstrated. + +### 2. Plain-EP AEP correction is norm-clipped, which invalidates the transpose correction + +Files/lines: +- `lt_ep_train.py:172-178` computes `corr = Jv - JTv`, then replaces it by `corr * (fn / cn)` when `||corr|| > ||f||`. + +What is wrong: +The AsymEP correction is algebraic: subtract `Jv - J^T v` so the local Jacobian becomes `J^T`. 
Scaling that correction by a state-dependent factor changes the corrected dynamics to something like `J - alpha(J - J^T)`, with `alpha < 1` exactly when asymmetry is large. That is no longer the transpose dynamics. + +Why it matters: +This can create a systematic estimator bias in plain EP (`holo=0`). It is especially dangerous near the regime where the antisymmetric correction is large, which is the regime this correction is supposed to fix. + +Severity: corrupts TRAINING for the plain real two-phase path. It does not affect the current `holo_a_select2`/`holo_a_track` path, which does not use this clip. + +Confidence: high. + +Minimal fix: +Remove the correction clip. If stability is needed, clip the total update, reduce `eps`, reduce `beta/hr`, or reject/halve the nudged step while logging that the estimator left its validity region. Do not scale only the antisymmetric correction. + +### 3. With `fnoise > 0`, JVP/VJP are not derivatives of the same force realization + +Files/lines: +- `lt_ep_train.py:87-90` samples fresh multiplicative noise inside `_noisy`. +- `lt_ep_train.py:92-97` applies `_noisy` inside `nc_force`. +- `lt_ep_train.py:171-178` evaluates the noisy force and then separately calls JVP/VJP through `nc_force`. +- `holo_ep.py:150-151` injects fresh noise into `rforce`. +- `holo_ep.py:176-185` and `holo_ep.py:224-239` use `blk.nc_force` for correction JVP/VJP, which can sample different noise again. + +What is wrong: +When `fnoise > 0`, the forward force, JVP force, and VJP force are separate random functions. The correction is no longer `Jv - J^T v` for the same operator used in the state update, and JVP and VJP are not even transposes of the same sampled Jacobian. + +Why it matters: +This corrupts AEP training in the noisy hardware simulation path and makes `navg` average a mixture of stochastic bias and stochastic noise. With `fnoise=0`, this issue is inactive. + +Severity: corrupts TRAINING when `--fnoise > 0`; otherwise inactive. 
+ +Confidence: high. + +Minimal fix: +Do not sample inside differentiable force functions. Sample a fixed noise mask/device realization outside and pass it into both the forward force and the JVP/VJP force, or keep dynamic per-pass noise out of AEP correction and use deterministic mismatch for differentiable hardware probes. + +### 4. `t1max` trains the EP task gradient at a refined state while eval/BPTT use the finite-`T1` state + +Files/lines: +- `lt_ep_train.py:143-152` first computes `zT` at `T1`, then optionally refines `zs` up to `t1max`. +- `lt_ep_train.py:203-210` computes the EP block gradient and readout gradient at refined `zs`. +- `lt_ep_train.py:260-265` BPTT differentiates exactly `T1` unrolled steps. +- `lt_ep_train.py:279-286` validation evaluates exactly `T1` relaxed steps. + +What is wrong: +With `t1max > T1`, EP optimizes the refined fixed-point state, while the reported validation objective and BPTT reference use the finite-`T1` state. That is a real objective mismatch. + +Why it matters: +If `z_T1` drifts away from the refined fixed point, EP can improve the wrong state while validation and the practical finite-time model degrade. This was also identified in `../EP_BELOW210_DIAGNOSIS_FIX.md:11-20`. + +Severity: corrupts TRAINING when `t1max > T1` and the finite-`T1` state is the real objective. + +Confidence: high. + +Minimal fix: +Choose one objective and make all paths use it. If the objective is finite `T1`, compute the EP gradient/readout gradient at `zT` or add a principled finite-time/contraction term. If the objective is the fixed point, evaluate and compare BPTT against the same refined state. + +### 5. `resreg` is hard-wired to thick `tforce` and its scaling includes already-added non-task gradients + +Files/lines: +- `lt_ep_train.py:220-224` computes the residual penalty with `blk.tforce(zT, xin0)`. +- `lt_ep_train.py:225-228` scales by `gtask` after prior task and `jacreg` gradients may already be in `grads`. 
+ +What is wrong: +`blk.tforce` is the thick-block force only (`lt_ep_train.py:81-85`). If `resreg` is used with `attn_mode` other than `thick`, the residual penalty is for the wrong dynamics. Also, `gtask` is described as task-gradient norm but includes any gradients already added to `grads`, including `jacreg` from `lt_ep_train.py:211-219`. + +Why it matters: +This can apply a residual penalty in the wrong direction for non-thick modes and makes the `resreg` ratio slightly different from its stated meaning. + +Severity: corrupts TRAINING conditionally: non-thick `resreg` is high risk; thick with `jacreg` is a smaller scaling bug. + +Confidence: high. + +Minimal fix: +Guard `resreg` with `assert blk.attn_mode == 'thick'` or compute the residual through `blk.force(..., cg=True)` for the active mode. Capture the pure task-gradient norm before adding `jacreg` and `resreg`. + +### 6. Complex masked softmax is numerically unstable because masked logits affect the row shift + +Files/lines: +- `holo_ep.py:26-29` computes `c = a.real.amax(...)` before masking, then multiplies `exp(a - c)` by `mask`. +- `holo_ep.py:48` passes the causal mask as a complex tensor into this helper. +- Real attention masks before softmax at `lt_ep_train.py:66-68`. + +What is wrong: +Mathematically, the row shift cancels if arithmetic is exact. Numerically, a large masked future logit can dominate `c`, causing all valid entries to underflow or lose precision. The real path masks before softmax and does not have this issue. + +Why it matters: +With `qknorm` enabled this is mitigated because logits are bounded, but without `qknorm` it can bias or NaN holomorphic phases. + +Severity: corrupts TRAINING conditionally in complex holomorphic paths, especially without `--qknorm`. + +Confidence: medium. 
+ +Minimal fix: +Keep the mask boolean and compute the shift over valid entries only: + +```python +c = a.real.masked_fill(~mask, -float("inf")).amax(-1, keepdim=True) +w = torch.exp(a - c).masked_fill(~mask, 0) +``` + +### 7. Holomorphic EP helpers silently implement only the thick force, but `ep_step` allows them for any mode + +Files/lines: +- `lt_ep_train.py:181-197` calls `holo_ep` whenever `holo > 0`, without checking `blk.attn_mode`. +- `holo_ep.py:36-51` implements `cforce` as thick LN + attention + FFN. +- `holo_ep.py:134-152` implements `rforce` as the same thick real-axis force. +- `lt_ep_train.py:349` default `--attn_mode` is `real`, while `lt_ep_train.py:358-359` allow `--holo`. + +What is wrong: +If a user runs `--holo` with `attn_mode=real`, `energy`, or `mono`, the nudged force used to estimate `a` is not the model force. + +Why it matters: +This silently corrupts training for a legal CLI flag combination. + +Severity: corrupts TRAINING for non-thick `--holo` runs. + +Confidence: high. + +Minimal fix: +Add a hard guard in `ep_step`: `if holo > 0 and blk.attn_mode != 'thick': raise ValueError(...)`, or implement holomorphic force extensions for the other modes. + +### 8. `asym_probe.py` hard-codes model construction choices that may not match the checkpoint + +Files/lines: +- `asym_probe.py:31-50` exposes `--gelu`, `--T1`, `--T2`, `--hr`, etc. +- `asym_probe.py:105-119` forces `attn_mode="thick"`, `c=1.0`, `qknorm=True`, `fnoise=0.0`, `track=True`, and assigns `blk.gelu = cfg.gelu`. +- `lt_ep_train.py:81-120` never reads `blk.gelu`; GELU is hard-coded to tanh-form in the active thick force. + +What is wrong: +The probe can analyze a different model than the checkpoint was trained with. The `--gelu` flag is especially misleading because assigning `blk.gelu` has no effect in the current `EQBlock`. + +Why it matters: +For the current qknorm/thick/c=1/tanh runs this is probably harmless. 
For c-bump, non-qknorm, non-thick, or historical erf/tanh comparisons, it can make `g_transpose`, `g_BPTT`, and `g_EP` refer to the wrong dynamics. + +Severity: DIAGNOSIS/probe only, unless probe conclusions are used to choose training changes. + +Confidence: high. + +Minimal fix: +Save the training config in checkpoints and load `attn_mode`, `c`, `qknorm`, GELU mode, and relevant flags from it. Remove `--gelu` or implement it in `EQBlock`. + +### 9. `asym_probe.py` labels `ep_step`'s returned residual as estimator/free-phase convergence, but it is the pre-refinement `T1` residual + +Files/lines: +- `lt_ep_train.py:143-152` computes `res` at `T1`, then may refine `zs` and store `res_used`. +- `lt_ep_train.py:232` returns `res`, not `res_used`. +- `asym_probe.py:840` prints that value as `EP estimator free-phase residual from ep_step`. +- `asym_probe.py:505-522` separately computes and prints the refined exact-reference residual. + +What is wrong: +The probe can conflate three different residuals: `T1` residual, refined free-phase residual, and nudged/adjoint residual. Only the last one diagnoses whether the EP contrast has converged. + +Why it matters: +This can make a run look "tightly converged" or "not tightly converged" depending on which print line the reader tracks. It also reinforces the wrong conclusion that free-phase convergence alone validates the estimator. + +Severity: DIAGNOSIS/probe only. + +Confidence: high. + +Minimal fix: +Return both `res_T1` and `res_refined` from `ep_step`, print both in the probe, and add a separate nudged/adjoint residual for `a`. + +### 10. `holo_ep.py` self-test/debug main is broken by unreachable code and an undefined function + +Files/lines: +- `holo_ep.py:257-280` defines `holo_a_lockin` and returns. +- `holo_ep.py:281-290` contains unreachable code that looks like a missing `holo_grads` function body. +- `holo_ep.py:329-332` calls `holo_grads`, which is not defined. 
+ +What is wrong: +Running `python holo_ep.py` as a diagnostic script will fail. + +Why it matters: +This does not affect `lt_ep_train.py` imports of `holo_a`, `holo_a_select2`, or `holo_a_track`, but it can break or mislead standalone estimator checks. + +Severity: DIAGNOSIS/probe only. + +Confidence: high. + +Minimal fix: +Move `holo_ep.py:281-290` into a real `def holo_grads(...)` or delete the stale self-test. + +## Checked And Found Correct + +- AEP correction sign: `lt_ep_train.py:171-178`, `holo_ep.py:181-185`, and `holo_ep.py:233-239` subtract `Jv - J^T v`, which is the correct sign for making the local differential dynamics use `J^T`. +- Two-phase contrast sign: `lt_ep_train.py:199-200` and `holo_ep.py:193-195` compute `(z_- - z_+) / (2 beta/r)`, which matches `lambda` solving `J^T lambda = -dL/dz`. +- Exact-adjoint probe sign: `asym_probe.py:443-445` solves `J^T lambda = -ell`, and `asym_probe.py:457-465` computes `lambda^T F_theta`. That is the correct implicit fixed-point gradient. +- Deterministic force consistency for thick mode: `lt_ep_train.py:81-85`, `lt_ep_train.py:102-106`, `holo_ep.py:134-152`, and `holo_ep.py:36-51` match on the real axis. Tiny CPU check with qknorm gave max `|tforce-rforce| = 1.19e-7` and max `|tforce-cforce.real| = 2.09e-7`. +- GELU consistency in current code: `lt_ep_train.py:84`, `lt_ep_train.py:96`, `lt_ep_train.py:105`, `holo_ep.py:32-33`, and `holo_ep.py:148` all use tanh-form GELU. Tiny CPU check found max difference from `F.gelu(..., approximate='tanh')` of `2.38e-7`. +- qknorm consistency in current thick paths: real attention uses qknorm at `lt_ep_train.py:63-65`, complex force at `holo_ep.py:44-46`, and real nudged force at `holo_ep.py:142-144`. +- Common-mode AEP anchor: `holo_ep.py:231-239` correctly computes `zbar`, duplicates it as the anchor, and applies the antisymmetric correction to `Z - zbar`. 
+- Block-parameter gradient scope: `lt_ep_train.py:203-205` computes `grad((a * f).sum(), blk.block)`, and `asym_probe.py:457-465` uses the same clamp-gradient path for token/position parameters. This is the right scope for force parameters. +- Readout-head gradient: `lt_ep_train.py:208-210` computes only the direct CE gradient for `Wh`. Since `Wh` is not in the force, there is no missing implicit force term and no double-counting. +- BPTT probe unroll: `asym_probe.py:819` uses `bptt_step`, which unrolls the same `blk.force` update as training at `lt_ep_train.py:260-265`. +- `no_grad` around JVP/VJP is not itself a bug. A tiny CPU check confirmed both `torch.func.jvp/vjp` and `torch.autograd.functional.jvp/vjp` still return derivatives inside surrounding `torch.no_grad()` blocks. + +## Highest-Value Next Static/CPU Checks + +1. Add a CPU-sized linearized test that compares `a_T` from `holo_a_track` against a direct solve of `J^T a = -ell` while sweeping `T2` and `hr`. Prediction: the 0.94-style error should move with `T2`, not `hr`. +2. In `asym_probe.py`, add a `--t2-sweep` diagnostic at fixed `hr`, and print `||a_t - a_{t-K}|| / ||a_t||` plus, if affordable, `||J^T a + ell|| / ||ell||`. +3. Re-run the existing beta/hr sweep only after proving the selected `a` is converged for each point. 
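Check 1 above can be prototyped without the model at all. The sketch below is a toy under stated assumptions, not the repository's `holo_a_track`: it uses a hypothetical 2x2 damped-rotation Jacobian, runs the linearized two-phase nudged dynamics for `T2` steps, and compares the contrast `a_T` with the direct solve of `J^T a = -ell`. All names (`two_phase_contrast`, `exact_adjoint`, `JT`, `ELL`) are illustrative.

```python
import math


def matvec(A, v):
    return [A[0][0] * v[0] + A[0][1] * v[1],
            A[1][0] * v[0] + A[1][1] * v[1]]


def cosine(u, v):
    dot = u[0] * v[0] + u[1] * v[1]
    return dot / (math.hypot(*u) * math.hypot(*v))


# Assumed toy operator: J = [[-1, 2], [-2, -1]] (damped rotation), so J^T is:
JT = [[-1.0, -2.0], [2.0, -1.0]]
ELL = [1.0, 0.3]  # stand-in for dL/dz at the free fixed point


def exact_adjoint(jt, ell):
    """Direct 2x2 solve of J^T a = -ell."""
    det = jt[0][0] * jt[1][1] - jt[0][1] * jt[1][0]
    inv = [[jt[1][1] / det, -jt[0][1] / det],
           [-jt[1][0] / det, jt[0][0] / det]]
    return matvec(inv, [-ell[0], -ell[1]])


def two_phase_contrast(jt, ell, hr, T2, eps=0.1):
    """Linearized corrected nudged phases around the fixed point:
    v+ is nudged by -hr*ell, v- by +hr*ell; returns (v- - v+) / (2*hr)."""
    vp, vm = [0.0, 0.0], [0.0, 0.0]
    for _ in range(T2):
        fp, fm = matvec(jt, vp), matvec(jt, vm)
        vp = [vp[i] + eps * (fp[i] - hr * ell[i]) for i in range(2)]
        vm = [vm[i] + eps * (fm[i] + hr * ell[i]) for i in range(2)]
    return [(vm[i] - vp[i]) / (2.0 * hr) for i in range(2)]


A_STAR = exact_adjoint(JT, ELL)

for T2 in (10, 40, 160):
    row = [cosine(two_phase_contrast(JT, ELL, hr, T2), A_STAR)
           for hr in (0.04, 0.2, 0.8)]
    print("T2=%d" % T2, ["%.6f" % c for c in row])
```

In this linear toy the cosine is identical across `hr` at a fixed `T2` and improves only as `T2` grows, which is the signature predicted for the real sweep. The real model is nonlinear, so an `hr` dependence can still appear there once `T2` is large enough.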
diff --git a/ep_run/adaptive_eps_calib.py b/ep_run/adaptive_eps_calib.py new file mode 100644 index 0000000..475e296 --- /dev/null +++ b/ep_run/adaptive_eps_calib.py @@ -0,0 +1,40 @@ +import torch, pickle, math +from pathlib import Path +import lt_ep_train as L +from lt_ep_train import EQBlock +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; B=8; T=256 +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +ck=torch.load('runs/redx_traj/s3200.pt',map_location=dev) +def mkblk(): + blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True + with torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + return blk +def force_g(blk,z,xin): + f=blk.force(z,xin).detach(); return f, f.norm().item()/(z.norm().item()+1e-9) +def run(adaptive, e0=0.1, emin=0.005, emax=0.1, down=0.5, up=1.1, theta=0.98, N=10000): + blk=mkblk() + with torch.no_grad(): + xin=blk.embed(idx).detach(); z=xin.clone(); eps=e0; prev=None; gs=[]; eh=[] + for t in range(N): + f,g=force_g(blk,z,xin); gs.append(g); eh.append(eps) + if adaptive and prev is not None: + if g>theta*prev: eps=max(emin,eps*down) + elif g<0.9*prev: eps=min(emax,eps*up) + prev=g; z=z+eps*f + tail=gs[-500:] + return dict(gmin=min(tail), gmean=sum(tail)/len(tail), avg_eps=sum(eh)/len(eh), final_eps=eh[-1]) +print("=== adaptive-eps controller CALIBRATION on s3200 (cycling op) ===") +print("ground truth: fixed eps=0.1 cycles (g~0.23); fixed eps=0.01 converges (g~0.09)") +print("-- benchmarks (fixed eps) --") +for e in (0.1, 0.01): + r=run(False, e0=e, emin=e, emax=e); print(f" fixed eps={e}: g_tail[min={r['gmin']:.4f} mean={r['gmean']:.4f}]") +print("-- adaptive configs (want: g <= 0.01-benchmark, avg_eps as HIGH as possible = fewer effective steps) --") +for name,kw in [("C1 cons", dict(down=0.5,up=1.1,theta=0.98)), + ("C2 mod", dict(down=0.7,up=1.2,theta=0.98)), + ("C3 caut", 
dict(down=0.5,up=1.05,theta=0.99)), + ("C4 aggr", dict(down=0.6,up=1.3,theta=0.95))]: + r=run(True, **kw); + print(f" {name} {kw}: g_tail[min={r['gmin']:.4f} mean={r['gmean']:.4f}] avg_eps={r['avg_eps']:.4f} final_eps={r['final_eps']:.4f}") +print("=== DONE ===") diff --git a/ep_run/adaptive_eps_calib2.py b/ep_run/adaptive_eps_calib2.py new file mode 100644 index 0000000..ab78f7e --- /dev/null +++ b/ep_run/adaptive_eps_calib2.py @@ -0,0 +1,39 @@ +import torch, pickle +from pathlib import Path +import lt_ep_train as L +from lt_ep_train import EQBlock +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; B=8; T=256 +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +CK={'s3200':torch.load('runs/redx_traj/s3200.pt',map_location=dev), + 's2000':torch.load('runs/redx_traj/s2000.pt',map_location=dev)} +def mkblk(name): + blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True + with torch.no_grad(): + for p,w in zip(blk.allp,CK[name]['allp']): p.copy_(w.to(dev)) + return blk +def fg(blk,z,xin): + f=blk.force(z,xin).detach(); return f, f.norm().item()/(z.norm().item()+1e-9) +# corrected controller: shrink on OVERSHOOT (g rose), grow otherwise +def run(name, e0=0.05, emin=0.003, emax=0.1, up=1.05, down=0.7, tol=1.0, N=8000): + blk=mkblk(name) + with torch.no_grad(): + xin=blk.embed(idx).detach(); z=xin.clone(); eps=e0; prev=None; gs=[]; eh=[] + for t in range(N): + f,g=fg(blk,z,xin); gs.append(g); eh.append(eps) + if prev is not None: + if g > prev*tol: eps=max(emin, eps*down) # residual climbed -> eps too big + else: eps=min(emax, eps*up) # contracting -> grow for speed + prev=g; z=z+eps*f + tail=gs[-500:] + return dict(gmin=min(tail), gmean=sum(tail)/len(tail), avg_eps=sum(eh)/len(eh), final_eps=eh[-1]) +print("=== corrected adaptive-eps (shrink on OVERSHOOT) — calibrate on stiff + smooth ===") +print("target: s3200 converges (g~0.09) at 
avg_eps>0.005 (faster than naive); s2000 stays eps~0.1") +for name in ('s3200','s2000'): + print(f"-- {name} --") + for tag,kw in [("A up1.05 dn0.7", dict(up=1.05,down=0.7)), + ("B up1.1 dn0.5", dict(up=1.1,down=0.5)), + ("C up1.03 dn0.8 tol1.02", dict(up=1.03,down=0.8,tol=1.02))]: + r=run(name,**kw) + print(f" {tag}: g_tail[min={r['gmin']:.4f} mean={r['gmean']:.4f}] avg_eps={r['avg_eps']:.4f} final_eps={r['final_eps']:.4f}") +print("=== DONE ===") diff --git a/ep_run/alert.sh b/ep_run/alert.sh new file mode 100755 index 0000000..6fe5929 --- /dev/null +++ b/ep_run/alert.sh @@ -0,0 +1,13 @@ +#!/bin/bash +LOG=ep_run/runs/ep_resreg_warm.log +cd /home/yurenh2/ept +while true; do + sleep 900 + if [ -n "$(find "$LOG" -mmin +45 2>/dev/null)" ]; then echo "LOG STALE >45min (resreg_warm dead/stuck)"; break; fi + LAST=$(grep -E "val CE" "$LOG" | tail -1) + BEST=$(echo "$LAST" | grep -oE "best [0-9.]+" | grep -oE "[0-9.]+$") + EMA=$(echo "$LAST" | grep -oE "ema=[0-9.]+" | grep -oE "[0-9.]+$") + awk "BEGIN{exit !($BEST < 2.02)}" 2>/dev/null && { echo "NEW BEST <2.02 (full recovery + improvement): $LAST"; break; } + awk "BEGIN{exit !($EMA > 4.0)}" 2>/dev/null && { echo "RE-COLLAPSE ema>4: $LAST"; break; } +done +echo "FIRED: $LAST" diff --git a/ep_run/analogET_extracted.txt b/ep_run/analogET_extracted.txt new file mode 100644 index 0000000..b139640 --- /dev/null +++ b/ep_run/analogET_extracted.txt @@ -0,0 +1,1861 @@ + Dense Associative Memories with Analog Circuits + Marc Gong Bacvanski1 , Xincheng You2 , John Hopfield3 , and Dmitry Krotov4 + 1 + MIT + 2 + Independent Researcher + 3 + Princeton University + 4 + IBM Research + + December 16 2025 +arXiv:2512.15002v1 [cs.NE] 17 Dec 2025 + + + + + Abstract: The increasing computational demands of modern AI systems have exposed fundamental + limitations of digital hardware, driving interest in alternative paradigms for efficient large-scale inference. 
+ Dense Associative Memory (DenseAM) is a family of models that offers a flexible framework for repre- + senting many contemporary neural architectures, such as transformers and diffusion models, by casting + them as dynamical systems evolving on an energy landscape. In this work, we propose a general method + for building analog accelerators for DenseAMs and implementing them using electronic RC circuits, cross- + bar arrays, and amplifiers. We find that our analog DenseAM hardware performs inference in constant + time independent of model size. This result highlights an asymptotic advantage of analog DenseAMs + over digital numerical solvers that scale at least linearly with the model size. We consider three settings + of progressively increasing complexity: XOR, the Hamming (7,4) code, and a simple language model + defined on binary variables. We propose analog implementations of these three models and analyze the + scaling of inference time, energy consumption, and hardware. Finally, we estimate lower bounds on the + achievable time constants imposed by amplifier specifications, suggesting that even conservative existing + analog technology can enable inference times on the order of tens to hundreds of nanoseconds. By har- + nessing the intrinsic parallelism and continuous-time operation of analog circuits, our DenseAM-based + accelerator design offers a new avenue for fast and scalable AI hardware. + + + 1 Introduction + The unprecedented growth of artificial intelligence (AI) has driven demand for increasingly large and + powerful models. At present, the field of generative AI is primarily driven by two settings: autore- + gressive transformers [1] and diffusion models [2]. While these settings have demonstrated remarkable + capabilities, they do so at a substantial computational cost. 
Their current implementations utilize digital + computation, which faces fundamental challenges in energy efficiency, scalability, and latency, especially + as model sizes and deployment demands continue to grow [3, 4, 5]. These limitations have prompted + interest in alternative computational paradigms that can efficiently handle the demands of modern AI + workloads [6]. + Dense Associative Memories (DenseAMs) [7, 8], a promising class of AI models which generalize + Hopfield networks [9], offer a new angle for tackling these problems. Unlike conventional feed-forward + models, DenseAM inference can be defined through the temporal evolution of a state vector that is + governed by a system of differential equations [10]. The state vector can be thought of as a particle + exploring the surface of a high-dimensional energy landscape, which is the Lyapunov function of these + dynamical equations. DenseAMs have been demonstrated to be flexible and expressive computational + frameworks, capable of representing many primitives of modern AI architectures, such as attention + mechanism [11], transformers [12], and diffusion models [13, 14, 15]. Furthermore, DenseAMs are + error-correcting systems [16], a property ensuring that small perturbations of the desired temporal evolution + of the state vector are corrected away by the dynamics of the network itself, rather than accumulated + in time. Finally, DenseAMs are asymptotically stable: during the course of temporal evolution the + computation happens during a finite transient period of time, which is followed by a steady state of +neural activities. [Code available at https://github.com/mbacvanski/AnalogET.] This asymptotic stabilization of dynamical trajectories removes the requirement to read +out the “answer” to the computation problem at a precise moment of time, making DenseAMs robust +to several classes of hardware imperfections.
The confluence of the above properties makes DenseAMs +appealing networks for analog hardware implementations that, on the one hand, are grounded in the +physics of stable error-correcting dynamical systems and, on the other hand, are capable of representing +computation in state-of-the-art AI networks. + In 1989, Hopfield argued that analog neural hardware can exceed the efficiency of digital implemen- +tations when the device physics directly instantiate the computational dynamics of the model itself [17]. +Here, we revisit this idea with DenseAM models: we propose an analog circuit-based hardware accel- +erator design whose dynamics directly realize those of the DenseAM. We find that analog DenseAM +hardware enables constant-time inference independent of model size, which is in stark contrast to GPU +solvers and digital implementations. This intrinsic property makes DenseAM a natural fit for analog AI +accelerators, and it highlights our circuit architecture as a viable hardware path to realize them. Using +component specifications already demonstrated in fabricated devices, analog DenseAM hardware may +achieve inference times on the order of tens to hundreds of nanoseconds, several orders of magnitude +faster than digital systems. + By leveraging the natural dynamics of analog systems, this work establishes a new design of fast and +scalable AI accelerators. The framework of DenseAMs and their efficient analog hardware implementa- +tions suggest a pathway for fundamentally redesigning the hardware-software interface for AI, enabling +a new paradigm for fast, energy-efficient, and scalable computation. + + +2 Dense Associative Memory basics +The DenseAM framework [10, 18] provides a model that has straightforward neuronal dynamics, yet is +surprisingly expressive in its ability to represent AI models including transformer attention, diffusion +models, and associative memories. 
In its simplest form it is defined by two sets of neurons (typically
called visible and hidden neurons) and a system of coupled nonlinear differential equations governing
their behavior, see Figure 1. The visible neurons are characterized by their internal states $v_i$ and
their outputs $g_i$, with index $i = 1, \dots, N_v$, while the hidden neurons have internal states $h_\mu$
and outputs $f_\mu$, with index $\mu = 1, \dots, N_h$. From the AI perspective, one can think of the
internal state of a neuron as its pre-activation, and of the output as its post-activation, which is
obtained by applying an activation function to the pre-activation. From the biological perspective, one
can think of the internal state of the neuron as a membrane voltage potential, and of the output of that
neuron as an axonal output, or a firing rate. This framework admits both neuron-wise activation functions
($g_i = g(v_i)$, where $g(\cdot)$ is some continuous function, e.g., a ReLU) and collective activation
functions such as softmax or layer normalization, which depend on the states of multiple neurons.
    The network parameters are stored in the synaptic weights $\xi \in \mathbb{R}^{N_h \times N_v}$, whose
matrix elements, denoted by $\xi_{\mu i}$, can be either hand-engineered or learned. The time decay
constants for the two groups of neurons are $\tau_v$ and $\tau_h$. With these conventions, the temporal
evolution of the two groups of neurons can be expressed as

$$
\begin{aligned}
\tau_v \frac{dv_i}{dt} &= \sum_{\mu=1}^{N_h} \xi_{\mu i} f_\mu + a_i - v_i \\
\tau_h \frac{dh_\mu}{dt} &= \sum_{i=1}^{N_v} \xi_{\mu i} g_i + b_\mu - h_\mu
\end{aligned}
\tag{1}
$$

This forms a bipartite graph of neuronal connections, where the state of the hidden neurons is updated
by the state of the visible neurons, and vice versa. Importantly, the same matrix $\xi$ appears in both
equations, once as $\xi$ and again as $\xi^\top$. Although this is sometimes described as using "symmetric"
weights, $\xi$ is not assumed to be symmetric in the linear-algebraic sense; it is simply the same matrix
used in both directions.
Finally, $a_i$ and $b_\mu$ denote biases, which are additional weights of the system and
whose values may be hard-coded or learned depending on the application.

Figure 1: Top left: Bipartite neural network formulation, where hidden neurons $h_\mu$ and visible neurons
$v_i$ are connected via symmetric synaptic weights $\xi$. Top right: Circuit realization of the symmetric
weight matrix via a resistive crossbar array. Each crosspoint encodes a weight $\xi_{\mu i}$ by its
resistance $R_{\mu i} = 1/\xi_{\mu i}$. Lower right: Circuit schematic of a single hidden neuron. It drives
its row of the crossbar array with a voltage according to its activation $f_\mu$, and its internal dynamics
are driven by the incoming current flowing into it from the crossbar array. Lower left: Softmax activation
function built from bipolar junction transistors (some components not shown).

    The most important aspect of this model is the existence of a global energy function (Lyapunov
function) that describes neuronal dynamics. To demonstrate this, it is most convenient to use the
Lagrangian formalism [10, 18, 16]. Each set of neurons is defined through a Lagrangian function of their
internal states. The activation functions are defined as partial derivatives of that Lagrangian with respect
to internal states. The total energy is the sum of energies of each set of neurons, plus the interaction
energy. The energy of each set of neurons is a Legendre transformation of the corresponding Lagrangian
(plus the term proportional to the bias). Thus, the global energy of Equation 1 is given by

$$
E = \underbrace{\Big[\sum_{i=1}^{N_v} g_i (v_i - a_i) - L_v\Big]}_{\text{energy of visible neurons}}
  + \underbrace{\Big[\sum_{\mu=1}^{N_h} f_\mu (h_\mu - b_\mu) - L_h\Big]}_{\text{energy of hidden neurons}}
  - \underbrace{\sum_{\mu=1}^{N_h} \sum_{i=1}^{N_v} f_\mu \xi_{\mu i} g_i}_{\text{interaction energy}}
\tag{2}
$$

where the activation functions are defined as partial derivatives of the Lagrangians

$$
g_i = \frac{\partial L_v}{\partial v_i}, \qquad f_\mu = \frac{\partial L_h}{\partial h_\mu}
$$

For convex Lagrangians this global energy decreases with time on the dynamical trajectories of
Equation 1.
If, additionally, the activation functions (and corresponding Lagrangians) are chosen in such a
way that this energy is bounded from below, the dynamical trajectories are guaranteed to arrive at a
stable fixed point of activations. The dynamical equations typically have many asymptotic fixed points,
which correspond to local minima of the energy function in Equation 2. Both properties above (convexity
of Lagrangians and lower-bounded energy) are satisfied for all settings studied in this paper. By picking
different nonlinear activation functions (or corresponding Lagrangians), this system yields a variety of
models that can describe associative memory, softmax attention, and other commonly used settings in
AI [10, 11, 18, 19, 20].
    A particularly relevant example for modern sequence modeling is the Energy Transformer (ET) [12],
which reformulates the transformer's inference pass as a gradient flow on an energy function defined over
the set of tokens. The ET block contains two contributions to the energy function: the attention energy and
the Hopfield network energy. The energy attention module routes information between the tokens, while the
Hopfield module aligns the tokens with the manifold of token embeddings. In our implementation, the
context tokens act as a set of dynamically instantiated memories that interact with the predicted token
through a DenseAM-like energy. In Section 6 we exploit this connection to construct an Analog Energy
Transformer (Analog ET) whose continuous-time dynamics are implemented directly in hardware using
our DenseAM circuit primitives.


3 Related work
Early analog implementations of associative memories focused on the classical Hopfield network.
Foundational designs, such as continuous-time analog circuits [21, 22] and later demonstrations using
amorphous-silicon resistors [23], memristive devices [24, 25], and phase-change memories [26], targeted
the quadratic Hopfield energy function.
These works emphasize device engineering and memory-cell design rather than +system-level dynamics, and inherit the limited storage capacity and representational power of traditional +Hopfield networks. That line of research is largely concerned with how to fabricate programmable re- +sistance elements themselves; our work assumes programmable conductances as a given primitive and +focuses on the continuous-time dynamics that operate on top of them. Our work also differs from these +works by addressing DenseAMs with higher-order energy functions and continuous-valued states. + Another direction is the use of cavity-QED systems for associative memory. Marsh et al. [27] analyze +a confocal cavity implementation of a quadratic Hopfield network and show that the cavity dynamics +induce a descent-like relaxation rule on spin states. Their model remains restricted to quadratic energies +and binary spins, and operates in a cryogenic, cavity-QED setting. Our work instead targets higher-order +DenseAMs with continuous states, and emphasizes scalable, room-temperature analog microelectronics +with explicit hardware-aware dynamical analysis. + More recent physical implementations move beyond purely quadratic energies. Musa et al. [28] +propose a free-space optical realization of the higher-order DenseAM energy. Their system constructs a +static physical representation of the energy landscape, but inference relies on an external digital controller +that performs iterative spin-flip updates. Thus, the hardware computes energies, while the optimization +dynamics remain digital. In contrast, our analog microelectronic architecture embeds the gradient flow +itself into hardware: inference is performed by a single continuous-time evolution rather than by discrete +digital updates. 


4 DenseAM circuit design
Here, we introduce a novel architecture for a class of analog electronic hardware accelerators that
realizes the system of nonlinear differential equations in Equation 1 through the time evolution of the
circuit itself. Our DenseAM design, shown in Figure 1, comprises two sets of neurons that interact through
a resistive crossbar array. The resistive crossbar array turns voltage differences between neurons into
currents flowing between the neurons according to synaptic weights, and each neuron's internal circuitry
converts those currents into dynamics that reproduce Equation 1.

Resistive weights as a crossbar array. The crossbar array construction is a canonical design for
matrix-vector multiplication using analog electronics [17, 29], and is a natural fit for the weight matrix
$\xi$ in our model. Traditionally, each crosspoint between a row and column line is connected by a resistor
(often memristors, RRAM, or other analog memories that produce resistances), a vector of input voltages
is applied at the row lines, and the column lines are held at ground, typically via a transimpedance
amplifier. By Ohm's law, each resistive crosspoint produces a current that multiplies the row's input
voltage by the inverse of the resistance. Because currents add along each column line, the total current
output at a column is the inner product between the vector of input voltages and the column's conductance
vector. Thus, the array as a whole implements a parallel analog matrix multiplication of the form
$I_{\mathrm{out}} = G V_{\mathrm{in}}$, where $G$ is the matrix of conductances (inverse of resistances).
    Unlike a traditional crossbar array whose rows are driven at a fixed voltage and whose columns
are held at ground, our DenseAM circuit design uses each weight bidirectionally, exactly representing
the bidirectional connections between visible and hidden neurons.
As a result, the current flowing into
each neuron corresponds to the weighted sum of the differences between visible and hidden neuron
activations. For example, for hidden neuron $\mu$, this current is $\sum_i \xi_{\mu i} (g_i - f_\mu)$. This
construction enables weight symmetry to be enforced by hardware sharing: both forward and reverse weights
are realized by the same resistive elements. Importantly, as long as weights are represented as
conductances, they must be non-negative.

[Figure 2 plot: visible neuron activations, hidden neuron activations, and energy versus time (s).]
Figure 2: Solving XOR with a DenseAM. Visible neuron $g_3 = v_3$ serves as the output, while the two
input neurons (unlabeled, thin lines) are clamped at 1 and 0 for True and False. Output $v_3$ is
initialized at 0.5 and converges to a positive prediction of 1. The hidden neuron $f_3$ for the
truth-table row (1, 0, 1) becomes highly activated, while the others (fine lines) are suppressed by
softmax. Energy (2), or equivalently (5), decreases monotonically along the inference trajectory.

[Figure 3 plot: energy versus $v_3$ in four panels, one per clamped input (1, 0), (1, 1), (0, 0), (0, 1).]
Figure 3: XOR energy landscape of neuron $v_3$ under different settings of visible input neurons $v_1$ and
$v_2$. Minima in the energy function correspond to stationary points of the dynamics. Gradient flow
dynamics bring $v_3$ to these attractor points, resulting in correct XOR outputs.

Design of a single neuron. Each neuron in the circuit computes its dynamics by integrating the
currents it receives from the crossbar array, which represent weighted differences between its own
activation and those of connected neurons.
Considering a hidden neuron (the design for visible neurons is symmetric),
the neuron's internal voltage $h_\mu$ is stored on capacitor $C_1$ and evolves in continuous time,
while the neuron's activation $f_\mu$ is obtained by passing $h_\mu$ through a nonlinear function (e.g.
ReLU or softmax).
    The current flowing into hidden neuron $\mu$ is produced by its interaction with all visible neurons
via the synaptic weights $\xi_{\mu i}$ for $i = 1, \dots, N_v$. Specifically, this is a weighted sum of the
differences between neuron activations: $\sum_i \xi_{\mu i} (g_i - f_\mu)$. Inside each neuron, a "self"
path scales $f_\mu$ to produce the voltage $s_\mu = f_\mu \sum_i \xi_{\mu i}$. This term is added to the
value of the incoming current so that the $-f_\mu \sum_i \xi_{\mu i}$ term is cancelled inside each neuron.
As a result, the hidden state, represented as the voltage across capacitor $C_1$, integrates only the
desired weighted input plus any external stimulus $b_\mu$. Its dynamics reduce to the canonical DenseAM
form with a time constant of $R_2 C_1$:

$$
R_2 C_1 \frac{dh_\mu}{dt} = \sum_{i=1}^{N_v} \xi_{\mu i} g_i + b_\mu - h_\mu \tag{3}
$$

Elementwise (or vectorized) nonlinearities then produce activations $g_i = g(v_i)$ and $f_\mu = f(h_\mu)$
(e.g., ReLU, softmax) across the visible and hidden neurons. See Appendix A for the full circuit derivation.


5 Analog DenseAM Examples
We begin by studying two examples of the proposed design: the XOR task, and the (7,4) error-correcting
Hamming code.

5.1 XOR
The XOR problem is a canonical test for nonlinear representation and inference, as it cannot be solved
by any linear model. We show a minimal DenseAM model for the XOR task, illustrating how energy-based
dynamics can solve this simple task with a continuous-time analog system. The network consists
of $N_v = 3$ visible neurons and $N_h = 4$ hidden neurons. At $t = 0$, visible neurons $v_1$ and $v_2$ are
initialized at the values of the input bits. The last visible neuron $v_3$ is initialized at $v_3 = 0.5$.
The hidden neurons are initialized at zero.
The two input visible neurons remain clamped during the
dynamics, while the third output visible neuron and the hidden neurons evolve in time according to (1).
Each row of the memory matrix $\xi$ corresponds to a row of the XOR truth table. The visible neurons
use an identity activation function where $g_i = v_i$, and the hidden neurons use a softmax activation. The
biases are set as

$$
a_i = 0, \qquad b_\mu = -\frac{1}{2} \sum_{i=1}^{N_v} \xi_{\mu i}^2
$$

    Figure 2 shows the temporal evolution of visible and hidden neuron activations, as well as the total
energy, during inference on the XOR input (1, 0). The output visible neuron's activation $g_3$ gradually
converges to the correct prediction of 1, while the hidden neuron associated with that memory, $f_3$,
becomes strongly activated and the remaining hidden neurons are suppressed by the softmax nonlinearity.
The system's energy decreases monotonically throughout the trajectory and stabilizes once the network
reaches its fixed-point prediction. Figure 3 depicts the system's energy landscape as a function of output
neuron $v_3$ for different clamped input configurations $(v_1, v_2)$. In each case, the energy exhibits a
clear convex minimum at the correct XOR output, demonstrating that gradient flow along the energy surface
drives $v_3$ reliably toward the correct prediction. As shown in Appendix C, we validate our circuit design
and dynamics using SPICE simulation.
    To analyze this DenseAM, it is instructive to consider the limit $\tau_h \to 0$. Since the second
equation in (1) is linear in hidden units $h_\mu$, they can be integrated out.
With $\sum_{\mu=1}^{N_h} f_\mu = 1$, the resulting dynamics
of the visible neurons can be written as

$$
\tau_v \frac{dv_i}{dt} = \sum_{\mu=1}^{N_h} \big(\xi_{\mu i} - v_i\big) f_\mu
\qquad \text{where} \qquad
f_\mu = \mathrm{softmax}\Big[-\frac{\beta}{2} \sum_{i=1}^{N_v} (\xi_{\mu i} - v_i)^2\Big]
\tag{4}
$$

The effective energy on the visible neurons can be written as

$$
E^{\mathrm{eff}}(v) = -\frac{1}{\beta} \log \sum_{\mu=1}^{N_h} \exp\Big[-\frac{\beta}{2} \sum_{i=1}^{N_v} (\xi_{\mu i} - v_i)^2\Big]
\tag{5}
$$

Intuitively, each hidden neuron computes a squared Euclidean distance between the visible state and its
stored pattern $\xi_\mu$. The softmax nonlinearity assigns higher weight to the pattern closest to the
current state of the visible neurons. The resulting visible neuron dynamics are gradient flow for this
effective energy. It is important to note that memories in this implementation are represented by
conductances of the crossbar array, which are always positive. For this reason, matrix elements of
memories $\xi_{\mu i}$ must be positive, necessitating the use of the bias terms, which are just voltage
sources that can be arbitrarily signed.
    While a time constant of $\tau_h = 0$ is impossible to physically construct due to finite conductances
and nonzero capacitances, setting $\tau_h \ll \tau_v$ realizes the same adiabatic limit in practice. When
hidden neurons evolve much faster than visible ones, they reach their steady state almost instantaneously
for each configuration of visible neurons. The result is an adiabatic elimination of hidden dynamics,
yielding the effective visible-only dynamics above. In practice, for the XOR task, even a relatively modest
ratio of $\tau_h = \tau_v / 10$ yields perfect performance.

5.2 Hamming (7,4) code
The Hamming (7,4) code is an error-correcting code that encodes 4 data bits into a 7-bit codeword by
adding 3 parity bits. The resulting 7-bit strings are special: only certain patterns are valid codewords,
and they are spaced apart so that if a single bit is flipped, the error can be detected and corrected [30].
Table 1 lists the 16 codewords corresponding to four arbitrary data bits.
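As a minimal sketch of the code itself (not of the circuit), the 16 codewords can be regenerated from four generator rows and checked for the single-error-correcting property. The generator rows below are read off Table 1 (the codewords for data words 1000, 0100, 0010, and 0001); the function names are illustrative, and the brute-force nearest-codeword search stands in for a conventional digital decoder rather than for the DenseAM dynamics.

```python
from itertools import product

# Generator rows read off Table 1: codewords for data words 1000, 0100, 0010, 0001.
G = [
    (1, 0, 0, 0, 0, 1, 1),
    (0, 1, 0, 0, 1, 0, 1),
    (0, 0, 1, 0, 1, 1, 0),
    (0, 0, 0, 1, 1, 1, 1),
]


def encode(data):
    """Encode 4 data bits as the mod-2 sum of the selected generator rows."""
    return tuple(sum(d * row[j] for d, row in zip(data, G)) % 2 for j in range(7))


def hamming_distance(a, b):
    """Number of positions in which two equal-length words differ."""
    return sum(x != y for x, y in zip(a, b))


# All 16 codewords, in the same order as the data words 0000 ... 1111.
CODEWORDS = [encode(d) for d in product((0, 1), repeat=4)]


def min_distance(words):
    """Smallest pairwise Hamming distance within a set of words."""
    return min(
        hamming_distance(a, b) for i, a in enumerate(words) for b in words[i + 1:]
    )


def decode_nearest(word):
    """Brute-force decoder: return the codeword closest to the received word."""
    return min(CODEWORDS, key=lambda c: hamming_distance(c, word))


if __name__ == "__main__":
    print("minimum distance:", min_distance(CODEWORDS))
    received = (0, 1, 1, 1, 0, 0, 0)  # codeword 0111100 with its 5th bit flipped
    print("decoded:", decode_nearest(received))
```

A minimum distance of 3 is what makes every single-bit error correctable: a word with one flipped bit is at distance 1 from its original codeword and at distance at least 2 from every other codeword.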

[Figure 4 plot: visible neuron activations, hidden neuron activations, and energy versus time (s).]
Figure 4: Correcting a bit error in a Hamming (7,4) code. Visible neuron $g_5$ flips, indicating that the
bit-flip error happened on the 5th codeword bit. $f_7$ is the hidden neuron corresponding to the memory
of the correct codeword. Thin lines correspond to the other neuron activations.

Data bits (d1 d2 d3 d4)    Codeword (c1 c2 c3 c4 c5 c6 c7)
0000                       0000000
0001                       0001111
0010                       0010110
0011                       0011001
0100                       0100101
0101                       0101010
0110                       0110011
0111                       0111100
1000                       1000011
1001                       1001100
1010                       1010101
1011                       1011010
1100                       1100110
1101                       1101001
1110                       1110000
1111                       1111111

Table 1: Valid codewords of the Hamming (7,4) code, ordered by their 4-bit data payload.

    Unlike the XOR case where the only evolving neuron is the readout bit, the Hamming (7,4) code may
require flipping the value of any one of the visible neurons. During inference, the visible neurons are
initialized to the corrupted 7-bit input word. All neurons are left free to evolve, and the dynamics relax
the state toward the nearest stored codeword. Energy minima are located at the valid codewords, so the
network converges to the correct codeword provided the error is within the Hamming radius of 1. Thus, the
DenseAM replicates the standard decoding property of the Hamming (7,4) code: any single-bit flip is
corrected automatically. Figure 4 illustrates a case where a flipped bit $g_5$ is restored during
convergence.
    The Hamming (7,4) model's 7 visible neurons, each corresponding to a codeword bit, are connected
to 16 hidden neurons, each representing one valid codeword. The weight matrix
$\xi \in \{0, 1\}^{16 \times 7}$ is formed by stacking the 16 codewords as its rows. Visible neurons have
the identity activation, hidden neurons use a softmax activation, and biases are chosen as in the XOR case
to give the same integrated-out visible dynamics as (4).
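The decoding behaviour described above can be illustrated numerically with a minimal sketch that integrates the visible-only dynamics (4) by forward Euler, starting from a corrupted codeword. This is a digital stand-in for the circuit, not the SPICE or hardware model: the inverse temperature `beta`, the step size `dt`, the number of steps, and all function names are assumptions chosen for illustration, and the generator rows are read off Table 1.

```python
import math
from itertools import product

# Generator rows read off Table 1; each stored memory xi_mu is one of the 16 codewords.
G = [
    (1, 0, 0, 0, 0, 1, 1),
    (0, 1, 0, 0, 1, 0, 1),
    (0, 0, 1, 0, 1, 1, 0),
    (0, 0, 0, 1, 1, 1, 1),
]
CODEWORDS = [
    tuple(sum(d * row[j] for d, row in zip(data, G)) % 2 for j in range(7))
    for data in product((0, 1), repeat=4)
]


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def relax(v0, memories, beta=8.0, tau_v=1.0, dt=0.1, steps=100):
    """Forward-Euler integration of tau_v dv_i/dt = sum_mu (xi_mu_i - v_i) f_mu."""
    v = list(v0)
    for _ in range(steps):
        # f_mu = softmax(-(beta / 2) * squared distance between memory mu and v)
        scores = [
            -0.5 * beta * sum((x - vi) ** 2 for x, vi in zip(xi, v))
            for xi in memories
        ]
        f = softmax(scores)
        # Each visible neuron moves toward the softmax-weighted average of the memories.
        v = [
            vi + (dt / tau_v) * sum(fm * (xi[i] - vi) for fm, xi in zip(f, memories))
            for i, vi in enumerate(v)
        ]
    return v


def decode(word):
    """Relax from the received word and threshold the final state at 0.5."""
    return tuple(int(x > 0.5) for x in relax(word, CODEWORDS))


if __name__ == "__main__":
    received = (0, 1, 1, 1, 0, 0, 0)  # codeword 0111100 with its 5th bit flipped
    print("decoded:", decode(received))
```

With these assumed settings the softmax concentrates on the codeword at distance 1 from the start state, so the flipped bit is pulled back while the other six bits stay near their initial values. A smaller `beta` flattens the softmax and blends neighbouring codewords more strongly, so the behaviour in that regime would need to be checked separately.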


6 Analog Energy Transformer (Analog ET) via DenseAM

Figure 5: Analog ET circuit demonstrating the autoregressive inference procedure. A newly inferred
token is decoded, sampled, and re-embedded to obtain the weight vector $\xi^{\mathrm{attn}}_{L+1}$, which is
set as the weight vector for a new hidden neuron $h^{\mathrm{attn}}_{L+1}$ in the energy attention block
(light gray on right). For this layout we have flipped the crossbar array, so that indices $A$ and $\mu$ run
horizontally and index $i$ runs vertically.

Our DenseAM circuit construction can be used to build more complex energy-based models, such as
the transformer-like architecture proposed in the Energy Transformer paper [12]. For causal next-token
prediction with a single attention head, the Energy Transformer's energy function can be written as
follows (see Appendix J for the full derivation):

$$
\begin{aligned}
E ={}& \tfrac{1}{2} \lVert v - a \rVert^2
 - v^\top \Big( (\xi^{\mathrm{attn}})^\top f^{\mathrm{attn}} + (\xi^{\mathrm{hopf}})^\top f^{\mathrm{hopf}} \Big)
 + (f^{\mathrm{attn}})^\top \big( h^{\mathrm{attn}} - b \big)
 + (f^{\mathrm{hopf}})^\top \big( h^{\mathrm{hopf}} - c \big) \\
& - L^{\mathrm{attn}}\big( h^{\mathrm{attn}} \big) - L^{\mathrm{hopf}}\big( h^{\mathrm{hopf}} \big)
\end{aligned}
\tag{6}
$$

with the activation functions and their Lagrangians defined as

$$
f^{\mathrm{attn}}_A = \mathrm{softmax}(\beta h^{\mathrm{attn}})_A, \qquad
L^{\mathrm{attn}}(h) = \frac{1}{\beta} \log \sum_{A=1}^{L} e^{\beta h_A}
\tag{7}
$$

$$
f^{\mathrm{hopf}}_\mu = \mathrm{ReLU}(h^{\mathrm{hopf}}_\mu), \qquad
L^{\mathrm{hopf}}(h) = \frac{1}{2} \sum_{\mu=1}^{M} \big[ \mathrm{ReLU}(h_\mu) \big]^2
\tag{8}
$$

where $a$, $b$, and $c$ correspond to the biases of the visible neurons, attention hidden neurons, and
Hopfield network hidden neurons, respectively. The $L$ context tokens are indexed by $A$, and the $M$ hidden
neurons of the Hopfield network are indexed by $\mu$. Because the visible units use an identity activation
function,
$g_i = v_i$ using the language of Equation 1, the gradient flow of the energy yields the dynamics:

$$
\tau_v \dot{v} = -\frac{\partial E}{\partial v}
 = (\xi^{\mathrm{attn}})^\top f^{\mathrm{attn}} + (\xi^{\mathrm{hopf}})^\top f^{\mathrm{hopf}} + a - v
\tag{9}
$$

$$
\tau_h \dot{h}^{\mathrm{attn}} = -\frac{\partial E}{\partial f^{\mathrm{attn}}}
 = \xi^{\mathrm{attn}} v + b - h^{\mathrm{attn}}
\tag{10}
$$

$$
\tau_h \dot{h}^{\mathrm{hopf}} = -\frac{\partial E}{\partial f^{\mathrm{hopf}}}
 = \xi^{\mathrm{hopf}} v + c - h^{\mathrm{hopf}}
\tag{11}
$$

In this formulation, $v$ represents the embedding of the output (next) token, and its evolution is driven by
two terms: one term from the energy attention with weights $\xi^{\mathrm{attn}}$ and hidden neuron
activations $f^{\mathrm{attn}}$, and one term from the Hopfield network with weights $\xi^{\mathrm{hopf}}$
and hidden neuron activations $f^{\mathrm{hopf}}$. The weights of the energy attention DenseAM depend on the
context: for a token dimension $D$, context length $L$, and the task of predicting the token at index
$L + 1$, the weights $\xi^{\mathrm{attn}} \in \mathbb{R}^{L \times D}$ are generated by applying a learned
embedding matrix to each context token. In contrast, the Hopfield network weights $\xi^{\mathrm{hopf}}$ are
learned during training and fixed at inference. The number of memories in the Hopfield network is a
hyperparameter $M$, such that $\xi^{\mathrm{hopf}} \in \mathbb{R}^{M \times D}$.
    This system suggests a hardware implementation where $v$ interacts with two independent DenseAMs,
one for the energy attention and one for the Hopfield term, which can share the same physical crossbar
structure. Figure 5 shows that the circuit structure remains a crossbar array (like Figure 1), but with
two distinct classes of hidden neurons. Because of the summation of currents along each row of the
crossbar array, the incoming current to visible neuron $v_i$ is the sum of contributions from the energy
attention block and from the Hopfield network block. The energy attention hidden neurons
$h^{\mathrm{attn}}$ use a softmax activation function, while the Hopfield network hidden neurons
$h^{\mathrm{hopf}}$ use a ReLU activation.
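To make the coupled dynamics (9) to (11) concrete, the following minimal sketch integrates them with forward Euler for a toy system with $D = 2$, $L = 2$ attention neurons, and $M = 1$ Hopfield neuron. Everything numeric here is an assumption for illustration: the weights are hand-picked rather than trained, the biases are zero, and forward Euler replaces both the analog circuit and the ODE solver used in the paper's experiments. The helper names (`simulate`, `matvec`, `mat_t_vec`) are hypothetical.

```python
import math

# Hypothetical toy parameters (not trained weights): D = 2, L = 2, M = 1.
XI_ATTN = [[1.0, 0.0], [0.0, 1.0]]  # L x D attention weights (context embeddings)
XI_HOPF = [[0.5, 0.5]]              # M x D Hopfield weights
A_BIAS = [0.0, 0.0]                 # visible biases a
B_BIAS = [0.0, 0.0]                 # attention hidden biases b
C_BIAS = [0.0]                      # Hopfield hidden biases c


def softmax(h, beta):
    """Numerically stable softmax(beta * h)."""
    m = max(h)
    exps = [math.exp(beta * (x - m)) for x in h]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(w, v):
    """Compute w @ v for a row-major matrix w."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def mat_t_vec(w, f):
    """Compute w.T @ f for a row-major matrix w."""
    return [sum(w[r][i] * f[r] for r in range(len(w))) for i in range(len(w[0]))]


def simulate(v0, beta=1.0, tau_v=1.0, tau_h=0.1, dt=0.01, steps=2000):
    """Forward-Euler integration of Eqs. (9)-(11); hidden states start at zero."""
    v = list(v0)
    h_attn = [0.0] * len(XI_ATTN)
    h_hopf = [0.0] * len(XI_HOPF)
    for _ in range(steps):
        f_attn = softmax(h_attn, beta)           # Eq. (7) activation
        f_hopf = [max(0.0, h) for h in h_hopf]   # Eq. (8) activation
        drive_attn = mat_t_vec(XI_ATTN, f_attn)
        drive_hopf = mat_t_vec(XI_HOPF, f_hopf)
        in_attn = matvec(XI_ATTN, v)
        in_hopf = matvec(XI_HOPF, v)
        # Eq. (9): visible neurons sum both feedback currents, plus bias, minus leak.
        v = [vi + dt / tau_v * (da + dh + a - vi)
             for vi, da, dh, a in zip(v, drive_attn, drive_hopf, A_BIAS)]
        # Eqs. (10) and (11): each hidden neuron relaxes toward its weighted input.
        h_attn = [h + dt / tau_h * (x + b - h)
                  for h, x, b in zip(h_attn, in_attn, B_BIAS)]
        h_hopf = [h + dt / tau_h * (x + c - h)
                  for h, x, c in zip(h_hopf, in_hopf, C_BIAS)]
    return v, h_attn, h_hopf


if __name__ == "__main__":
    v, h_attn, h_hopf = simulate([0.2, -0.3])
    print("v:", v, "h_attn:", h_attn, "h_hopf:", h_hopf)
```

For these particular toy weights the fixed-point condition $v = (\xi^{\mathrm{attn}})^\top f^{\mathrm{attn}} + (\xi^{\mathrm{hopf}})^\top f^{\mathrm{hopf}}$ is solved by $v = (1, 1)$, which gives a simple check that the integration has settled. The choice $\tau_h = \tau_v / 10$ mirrors the ratio quoted for the XOR example and is otherwise arbitrary.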

6.1 Analog Energy Transformer on the parity task
We build and evaluate the Analog ET on the $L$-bit parity task, which can be thought of as an elementary
"language model": given bits $\mathrm{bit}_1, \dots, \mathrm{bit}_L$, predict
$\mathrm{bit}_{L+1} = \sum_{A=1}^{L} \mathrm{bit}_A \bmod 2$. Parity is instructive
because it requires a representation of a global, order-$L$ interaction, precluding linear and shallow
models from representing it efficiently. A successful model must be able to form high-order interactions in
order to generalize. We formulate parity as a next-token prediction problem: given an $L$-bit string as
context, predict its parity in the next token.
    We train the Analog ET model digitally using backpropagation through time [31] implemented with
JAX's automatic differentiation. The resulting weights can be deployed onto the analog hardware; in
our experiments we simulate the dynamics of hardware with the Diffrax [32] ODE solver library. On
the 8-bit parity task, our model achieves 100% accuracy on the held-out validation set of 52 bit strings,
demonstrating clear generalization capabilities. See Appendix H.1 for more details on training and model
design.

[Figure 6 plot: two columns, for input strings 11001010 (parity 0) and 01000110 (parity 1); rows show
visible neurons, prediction, and energy versus time t.]
Figure 6: Inference of parity Analog ET on two example 8-bit strings. Top row plots the visible neurons
$v_i$ over time, middle row plots the decoded token prediction, bottom row plots the energy, which
decreases monotonically during inference. After a transient period of computation, the network arrives at a
steady state, making the result of the computation robust against the precise timing of the readout.

    Figure 6 shows the dynamics of the visible neurons and energy during two example inference runs
of the Analog ET.
Notably, the visible neuron values are constant by the end of the inference period,
meaning that the inference is robust to mismatch and delay in readout timing. A
single sample-and-hold and switching circuit would enable a single Analog-to-Digital Converter (ADC) to
read out all the visible neurons at convergence, significantly reducing mismatch and saving
device area, complexity, and energy. The intrinsic stability of attractor points arises from
the continuous-time dynamics of the DenseAM, making these models particularly well suited to analog
hardware.

6.2 Autoregressive inference
Dashed lines in Figure 5 illustrate the autoregressive inference procedure of the Analog ET. To generate
the $L$-th token given context tokens $x^{(1)}, \dots, x^{(L-1)}$, each token is first embedded, and the
embeddings are stacked to form the attention weight matrix

$$
\xi^{\mathrm{attn},(L-1)} =
\begin{bmatrix} e^{(1)} \\ e^{(2)} \\ \vdots \\ e^{(L-1)} \end{bmatrix}
\in \mathbb{R}^{(L-1) \times D}
$$

These rows are loaded into the Analog ET's energy attention weight matrix $\xi^{\mathrm{attn}}$ by
programming the corresponding crossbar resistances. During inference, the visible state $v(t)$ evolves
according to the Analog ET dynamics until convergence. A decoder readout (e.g., a linear layer) applied to
the converged $v(t = T)$ values produces logits, from which the next token $x^{(L)}$ is sampled. This token
is then embedded to form $e^{(L)}$, and appended to the existing context. The cycle repeats with the updated
attention weight matrix

$$
\xi^{\mathrm{attn},(L)} =
\begin{bmatrix} \xi^{\mathrm{attn},(L-1)} \\ e^{(L)} \end{bmatrix}
\in \mathbb{R}^{L \times D}
$$

which now includes the new embedding $e^{(L)}$. In hardware, this corresponds to connecting an additional
hidden neuron in the energy attention block of Figure 5, and setting its resistive weights with $e^{(L)}$.
Because the physical order of hidden neurons does not affect the energy function, this new neuron can
be placed in any position among the hidden neurons.
When the context length is fixed, the hidden
neuron corresponding to the earliest token can simply be reprogrammed with the new vector of weights
$e^{(L)}$, resulting in the hardware equivalent of a sliding-window context. In practice, an external
digital controller, e.g., a Field-Programmable Gate Array (FPGA) or Application-Specific Integrated Circuit
(ASIC), would orchestrate crossbar programming and token decoding, while the DenseAM dynamics
perform the far more substantial workload of computing each next-token embedding.
    This procedure is analogous to key-value (KV) caching in standard transformer inference [33]. Context
tokens $x^{(1)}, \dots, x^{(L-1)}$ produce key and value vectors $k^{(1)}, \dots, k^{(L-1)}$ and
$v^{(1)}, \dots, v^{(L-1)}$ respectively. When a new token $x^{(L)}$ is generated, its corresponding
$k^{(L)}$ and $v^{(L)}$ vectors are appended to the cache, allowing all previous $k^{(<L)}$ and $v^{(<L)}$
to be reused without recomputation. When the key and value matrices are tied so that $k^{(A)} = v^{(A)}$,
the ET's row-append operation is equivalent to the standard KV-cache update. The ET performs an
autoregressive rollout that reproduces the same recurrence structure as KV-cached transformer inference,
but implemented physically through the addition of new neurons and weights without touching existing
hardware. For a formal derivation of the equivalence between ET attention and conventional attention with
tied keys and values, see [12].


7 Scaling properties
Inference time and energy consumption are crucial characteristics of our system. This section investigates
how these metrics scale with the network size.

7.1 Inference time scaling
We consider the model defined by (4) and (5).
In the adiabatic limit ($\tau_h \to 0$), which is satisfied by our hardware
implementation, the time derivative of the energy can be written as

$$
\frac{dE^{\mathrm{eff}}}{dt}
 = \sum_{i=1}^{N_v} \frac{\partial E^{\mathrm{eff}}}{\partial v_i} \frac{dv_i}{dt}
 = -\frac{1}{\tau_v} \sum_{i=1}^{N_v} \Big( \frac{\partial E^{\mathrm{eff}}}{\partial v_i} \Big)^2
 \sim -\frac{N_v}{\tau_v}
\tag{12}
$$

This derivative is never positive, since the dynamical system performs gradient descent on the
energy landscape. The derivative vanishes eventually when the network state vector $v$ converges to the
steady state. Since the state vector $v_i$ is typically initialized in the vicinity of the memory vectors,
which are chosen to be of order one ($\sim 1$), the right-hand side of (4) is of order one too, independent
of the network size. This results in the characteristic value of the temporal derivative shown in (12).
    At the same time, the typical value¹ of the energy (5) is

$$
|E^{\mathrm{eff}}| \sim N_v + \frac{1}{\beta} \log(N_h) \tag{13}
$$

During the inference dynamics the network is initialized in a high-energy state, which has the
characteristic value of energy (13), and performs energy descent to a lower value of the energy (which has a
similar order of magnitude). In order to estimate the scaling of the time required to perform this energy
descent, one can take the ratio of the energy drop to the rate of the energy decrease (12). This gives the
following estimate

$$
T^{\mathrm{conv}} \sim \frac{|E^{\mathrm{eff}}|}{\left| \frac{dE^{\mathrm{eff}}}{dt} \right|}
 \sim \tau_v \Big( 1 + \frac{1}{\beta} \frac{\log(N_h)}{N_v} \Big)
 \sim \tau_v
\tag{14}
$$

The last $\sim$ sign holds since in none of the designs presented here does $N_h$ grow super-exponentially
in $N_v$. In fact, in all the use cases $N_h$ is sub-exponential in $N_v$.

¹ We estimate the absolute value of the energy, since it can be both positive and negative depending on the
mutual arrangement of memories, the state vector, and the number of hidden units.

    This back-of-the-envelope estimation provides the core intuition behind the scaling relationship.
The inference time is constant, and independent of the size of the network.
A more careful analysis (Appendix E) shows that in the high-$\beta$ regime the worst-case dependence is
$O\!\left( \frac{\tau_v}{\beta} \frac{\log N_h}{N_v} \right)$, which
remains bounded for all architectures we consider. Thus, for our settings the convergence time is
effectively constant in $N_v$ and $N_h$. Based on amplifier gain-bandwidth, slew-rate, and output-current
constraints, we estimate achievable inference times of tens to hundreds of nanoseconds using existing
CMOS technology (see Appendix I.2).

7.2 Scaling of energy consumption
We now analyze how the total inference energy scales with network size. Energy dissipation arises
primarily from (i) Ohmic loss in the resistive weights, (ii) charging of neuron-state capacitors, and (iii)
constant per-neuron overhead from amplifiers and bias currents. We show that, under bounded voltage
swings and fixed conductance budgets, total energy grows only linearly with the number of neurons.

Weight dissipation. Let the neuron output voltages be proportional to activations: $u = \kappa g$ and
$w = \kappa f$, where $\kappa$ is a fixed voltage swing. Such a bounded swing can always be enforced by
global rescaling of $\xi$, $\beta$, and voltage units without changing the dynamics (see Appendix F). The
instantaneous power dissipated by the resistive crossbar array is

$$
P_{\mathrm{weights}}(t) = \sum_{\mu=1}^{N_h} \sum_{i=1}^{N_v} \xi_{\mu i} (u_i - w_\mu)^2 \tag{15}
$$

With $0 \le g_i \le 1$, softmax-normalized $f$, and row/column conductance budgets
$\sum_\mu \xi_{\mu i} \le C_c$, $\sum_i \xi_{\mu i} \le C_r$, the total power obeys

$$
P_{\mathrm{weights}}(t) \le 2 \kappa^2 (C_c N_v + C_r) = O(N_v) \tag{16}
$$

For a runtime of duration $T \sim T^{\mathrm{conv}}$, the energy dissipated by the weights is therefore
$E_{\mathrm{weights}} = O(N_v T)$, where $T \sim 1$ from subsection 7.1.

Capacitive and overhead energy. Each neuron charges a local capacitor a finite number of times
by at most $V_{\mathrm{swing}} \sim \kappa$, giving

$$
E_{\mathrm{cap}} \le \kappa^2 \Big( \sum_i C_i^{(v)} + \sum_\mu C_\mu^{(h)} \Big) = O(N_v + N_h) \tag{17}
$$

Active bias and amplifier inefficiencies contribute fixed per-neuron power, yielding
$E_{\mathrm{other}} = O((N_v + N_h) T)$.

Total energy scaling.
With bounded voltage swing and conductance budgets,

$$E_{\mathrm{total}} = O(N_v + N_h) \tag{18}$$

Hence, the total inference energy scales only linearly with system size. For the full derivation, see Appendix G.

7.3 Scaling of hardware area
The area is dominated by two components: the area taken up by the synaptic weights, which are implemented as a crossbar array with programmable weights, and the area taken up by the neurons feeding the crossbar array. The area of the crossbar array scales as the number of weights, $O(N_v N_h)$. The area of the neurons scales as $O(N_v + N_h)$.


8 Conclusion
In this paper, we have presented an analog accelerator architecture for Dense Associative Memories, implemented using resistive crossbar arrays and continuous-time RC neuron dynamics. Our design implements DenseAM inference as time evolution of a physical dynamical system, rather than a sequence of discrete numerical update steps. We demonstrated this architecture with three representative settings of increasing complexity: XOR, Hamming (7,4) error decoding, and an Energy Transformer-style sequence model. These examples show that the analog DenseAM accelerator architecture covers both associative memory tasks and attention-based sequence models.
Our analysis shows that DenseAM accelerators enjoy favorable asymptotic scaling properties. Inference time is constant in the model size, meaning that inference time is governed primarily by the physical time constants of the circuit. This is in sharp contrast to digital implementations of the same dynamics, whose runtime must grow at least linearly with model size.
To assess hardware feasibility, we derived lower bounds on the neuronal time constants imposed by amplifier gain-bandwidth product, slew rate, and output current limits in our neuron design.
Reported +figures from representative CMOS OTAs in the literature give inference times on the order of tens-to- +hundreds of nanoseconds, even with conservative design margins. Combined with the constant scaling of +inference with model size, these estimates suggest that DenseAM accelerators can match or exceed the +latency of digital GPUs as models grow, without requiring exotic devices or beyond-CMOS technologies. + Our results highlight DenseAMs as a natural abstraction for analog AI hardware. Their error cor- +recting dynamics and asymptotic stability directly address long-standing concerns about robustness and +readout timing: small perturbations are corrected by the dynamics instead of accumulated, and the final +state is stable when readout happens over a wide temporal window. At the same time, the DenseAM +framework is expressive enough to capture modern primitives such as attention and transformer-like ar- +chitectures, as illustrated by our Analog Energy Transformer construction. These properties suggest that +DenseAM-based analog accelerators may be a promising substrate for future AI systems, and motivate +further co-design of models, dynamics, and devices. + +Acknowledgements +MGB would like to thank Faiz Muhammad for exploratory attempts at SPICE simulations. DK would +like to thank Kwabena Boahen for helpful discussions. + + +References + [1] Ashish Vaswani. “Attention is all you need”. In: arXiv preprint arXiv:1706.03762 (2017). + [2] Jascha Sohl-Dickstein et al. “Deep unsupervised learning using nonequilibrium thermodynamics”. + In: International conference on machine learning. pmlr. 2015, pp. 2256–2265. + [3] Norman P Jouppi et al. “In-datacenter performance analysis of a tensor processing unit”. In: + Proceedings of the 44th annual international symposium on computer architecture. 2017, pp. 1–12. + [4] Eric Masanet et al. “Recalibrating global data center energy-use estimates”. In: Science 367.6481 + (2020), pp. 984–986. 
+ [5] David Patterson et al. “Carbon emissions and large neural network training”. In: arXiv preprint + arXiv:2104.10350 (2021). + [6] Maxwell Aifer et al. “Solving the compute crisis with physics-based ASICs”. In: arXiv preprint + arXiv:2507.10463 (2025). + [7] Dmitry Krotov and John J Hopfield. “Dense associative memory for pattern recognition”. In: + Advances in neural information processing systems 29 (2016). + [8] Dmitry Krotov and John Hopfield. “Dense associative memory is robust to adversarial inputs”. In: + Neural computation 30.12 (2018), pp. 3151–3167. + [9] John J Hopfield. “Neural networks and physical systems with emergent collective computational + abilities.” In: Proceedings of the national academy of sciences 79.8 (1982), pp. 2554–2558. +[10] Dmitry Krotov and John J Hopfield. “Large Associative Memory Problem in Neurobiology and + Machine Learning”. In: International Conference on Learning Representations. 2021. +[11] Hubert Ramsauer et al. “Hopfield networks is all you need”. In: arXiv preprint arXiv:2008.02217 + (2020). +[12] Benjamin Hoover et al. “Energy transformer”. In: Advances in Neural Information Processing + Systems 36 (2024). + + + + 12 +[13] Benjamin Hoover et al. “Memory in plain sight: A survey of the uncanny resemblances between + diffusion models and associative memories”. In: arXiv preprint arXiv:2309.16750 (2023). +[14] Luca Ambrogioni. “In search of dispersed memories: Generative diffusion models are associative + memory networks”. In: arXiv preprint arXiv:2309.17290 (2023). +[15] Bao Pham et al. “Memorization to generalization: Emergence of diffusion models from associative + memory”. In: arXiv preprint arXiv:2505.21777 (2025). +[16] Dmitry Krotov et al. “Modern methods in associative memory”. In: arXiv preprint arXiv:2507.06211 + (2025). +[17] JJ Hopfield. “The effectiveness of analogue’neural network’hardware”. In: Network: Computation + in Neural Systems 1.1 (1990), p. 27. +[18] Dmitry Krotov. 
“Hierarchical associative memory”. In: arXiv preprint arXiv:2107.06446 (2021). +[19] Fei Tang and Michael Kopp. “A remark on a paper of krotov and hopfield [arxiv: 2008.06996]”. In: + arXiv preprint arXiv:2105.15034 (2021). +[20] Benjamin Hoover et al. “A universal abstraction for hierarchical hopfield networks”. In: The Sym- + biosis of Deep Learning and Differential Equations II. 2022. +[21] John J Hopfield. “Neurons with graded response have collective computational properties like those + of two-state neurons.” In: Proceedings of the national academy of sciences 81.10 (1984), pp. 3088– + 3092. +[22] David W Tank and John J Hopfield. “Simple “Neural” optimization networks: an A/D converter, + signal decision circuit, and a linear programming circuit”. In: Artificial neural networks: theoretical + concepts. 1988, pp. 87–95. +[23] HP Graf et al. “VLSI implementation of a neural network memory with several hundreds of neu- + rons”. In: AIP conference proceedings. Vol. 151. 1. American Institute of Physics. 1986, pp. 182– + 187. +[24] Xinjie Guo et al. “Modeling and experimental demonstration of a Hopfield network analog-to- + digital converter with hybrid CMOS/memristor circuits”. In: Frontiers in neuroscience 9 (2015), + p. 488. +[25] SG Hu et al. “Associative memory realized by a reconfigurable memristive Hopfield neural net- + work”. In: Nature communications 6.1 (2015), p. 7522. +[26] Sukru B Eryilmaz et al. “Brain-like associative learning using a nanoscale non-volatile phase change + synaptic device array”. In: Frontiers in neuroscience 8 (2014), p. 205. +[27] Brendan P Marsh et al. “Enhancing associative memory recall and storage capacity using confocal + cavity QED”. In: Physical Review X 11.2 (2021), p. 021048. +[28] Khalid Musa et al. “Dense Associative Memory in a Nonlinear Optical Hopfield Neural Network”. + In: arXiv preprint arXiv:2506.07849 (2025). +[29] Carver Mead and Mohammed Ismail. Analog VLSI implementation of neural systems. Vol. 80. 
+ Springer Science & Business Media, 2012. +[30] Richard W Hamming. “Error detecting and error correcting codes”. In: The Bell system technical + journal 29.2 (1950), pp. 147–160. +[31] Paul J Werbos. “Backpropagation through time: what it does and how to do it”. In: Proceedings + of the IEEE 78.10 (2002), pp. 1550–1560. +[32] Patrick Kidger. “On Neural Differential Equations”. PhD thesis. University of Oxford, 2021. +[33] Zihang Dai et al. “Transformer-xl: Attentive language models beyond a fixed-length context”. + In: Proceedings of the 57th annual meeting of the association for computational linguistics. 2019, + pp. 2978–2988. +[34] Jacob Sillman. “Analog Implementation of the Softmax Function”. In: arXiv preprint arXiv:2305.13649 + (2023). +[35] John J Hopfield and David W Tank. “Computing with neural circuits: A model”. In: Science + 233.4764 (1986), pp. 625–633. +[36] Aldo Pena Perez and Franco Maloberti. “Performance enhanced op-amp for 65nm CMOS tech- + nologies and below”. In: 2012 IEEE International Symposium on Circuits and Systems (ISCAS). + IEEE. 2012, pp. 201–204. + + + 13 + Figure 7: Circuit for a single neuron. + + +[37] Rida S Assaad and Jose Silva-Martinez. “The recycling folded cascode: A general enhancement of + the folded cascode amplifier”. In: IEEE Journal of Solid-State Circuits 44.9 (2009), pp. 2535–2542. +[38] Alec Yen and Benjamin J Blalock. “A High Slew Rate, Low Power, Compact Operational Ampli- + fier Based on the Super-Class AB Recycling Folded Cascode”. In: 2020 IEEE 63rd International + Midwest Symposium on Circuits and Systems (MWSCAS). IEEE. 2020, pp. 9–12. +[39] Mohammad H Naderi, Suraj Prakash, and Jose Silva-Martinez. “Operational transconductance + amplifier with class-B slew-rate boosting for fast high-performance switched-capacitor circuits”. + In: IEEE Transactions on Circuits and Systems I: Regular Papers 65.11 (2018), pp. 3769–3779. +[40] Franz Schlögl and Horst Zimmermann. 
“A design example of a 65 nm CMOS operational amplifier”. In: International Journal of Circuit Theory and Applications 35.3 (2007), pp. 343–354.


A Neuron Design
Figure 7 shows the circuit design of a single neuron, with labels corresponding to this being a hidden neuron at index $\mu$. We derive the dynamics of the neuron internal state $h_\mu$ and activation output voltage $f_\mu$. We proceed using only Kirchhoff's Current Law (KCL) and the definition of an ideal op-amp.

Assumptions and conventions.
 • Ideal op-amps: infinite open-loop gain, infinite input impedance (no input current), zero output impedance. Under stable negative feedback this enforces a virtual short $V_+ = V_-$.
 • Current $J_\mu$: we define $J_\mu$ as the current which flows from $f_\mu$ to $m_\mu$ through $R_1$.
 • Op-amp input labels: We denote the inverting and noninverting inputs of each op-amp explicitly, e.g. $U2^-$ for the inverting input of U2, $U3^+$ for the noninverting input of U3, etc.
 • Node labels: Label $m_\mu$ as the output of U1, $s_\mu$ as the output of U2, and $d_\mu$ as the output of U3. The neuron pre-activation state is labeled $h_\mu$, and the post-activation state is labeled $f_\mu$. Voltage $b_\mu$ (as an ideal voltage source) drives the bias for this neuron. Voltages $h_\mu$, $b_\mu$, and $f_\mu$ correspond directly to the state variables in equation (1).

Block U1: buffer of activation voltage $f_\mu$. Op-amp U1 buffers the output of the activation function $f(\cdot)$ and drives the output of the neuron, $f_\mu$. Because no current can flow into $U1^-$, all the current flowing into this neuron must flow through $R_1$ to $m_\mu$ and is sourced or sunk by U1's output node.

Block U2: non-inverting stage producing $s_\mu$ from $f_\mu$ and $m_\mu$. The positive input of U2 is $U2^+ = f_\mu$, and by U2's virtual short, the negative input $U2^- = U2^+ = f_\mu$. By KCL at $U2^-$,

$$\frac{U2^-}{R_{10}} = \frac{s_\mu - U2^-}{R_9} \;\Rightarrow\; s_\mu = \left(1 + \frac{R_9}{R_{10}}\right) f_\mu \tag{19}$$

Block U3: non-inverting stage producing $d_\mu$ from $s_\mu$, $b_\mu$, and $m_\mu$.
By KCL at the positive input of U3,

$$\frac{b_\mu - U3^+}{R_3} + \frac{s_\mu - U3^+}{R_4} = \frac{U3^+}{R_5} \;\Rightarrow\; U3^+ = \frac{R_4 R_5 b_\mu + R_3 R_5 s_\mu}{R_4 R_5 + R_3 R_5 + R_3 R_4} \tag{20}$$

KCL at the negative input of U3 gives us

$$\frac{m_\mu - U3^-}{R_6} + \frac{-U3^-}{R_7} = \frac{U3^- - d_\mu}{R_8} \;\Rightarrow\; d_\mu = U3^-\left(1 + R_8\left(\frac{1}{R_6} + \frac{1}{R_7}\right)\right) - \frac{R_8}{R_6} m_\mu \tag{21}$$

The virtual short of U3 means $U3^- = U3^+$. Combining equations (20) and (21), we get

$$d_\mu = \frac{R_6 R_7 + R_8 (R_6 + R_7)}{R_6 R_7} \cdot \frac{R_4 R_5 b_\mu + R_3 R_5 s_\mu}{R_4 R_5 + R_3 R_5 + R_3 R_4} - \frac{R_8}{R_6} m_\mu \tag{22}$$

Dynamics of RC circuit. $R_2$ and $C_1$ form an RC circuit driven by voltage $d_\mu$. The voltage across the capacitor, $h_\mu$, follows the relation

$$R_2 C_1 \frac{dh_\mu}{dt} = -h_\mu + d_\mu = -h_\mu + \frac{R_6 R_7 + R_8 (R_6 + R_7)}{R_6 R_7} \cdot \frac{R_4 R_5 b_\mu + R_3 R_5 s_\mu}{R_4 R_5 + R_3 R_5 + R_3 R_4} - \frac{R_8}{R_6} m_\mu \tag{23}$$

With incoming current. Take the incoming current $J_\mu = \sum_i \xi_{\mu i}(g_i - f_\mu)$. This produces a voltage drop across $R_1$ such that $m_\mu = f_\mu - R_1 J_\mu = f_\mu - R_1 \sum_i \xi_{\mu i}(g_i - f_\mu)$. Then, the dynamics of $h_\mu$ from equation (23) are

$$R_2 C_1 \frac{dh_\mu}{dt} = -h_\mu + \frac{R_6 R_7 + R_8 (R_6 + R_7)}{R_6 R_7} \cdot \frac{R_4 R_5 b_\mu + R_3 R_5 s_\mu}{R_4 R_5 + R_3 R_5 + R_3 R_4} - \frac{R_8}{R_6}\left(f_\mu - R_1 J_\mu\right) \tag{24}$$

Substituting in $s_\mu$ from equation (19) and $J_\mu$:

$$R_2 C_1 \frac{dh_\mu}{dt} = -h_\mu + \frac{R_6 R_7 + R_8 (R_6 + R_7)}{R_6 R_7} \cdot \frac{R_4 R_5 b_\mu + R_3 R_5 \left(1 + \frac{R_9}{R_{10}}\right) f_\mu}{R_4 R_5 + R_3 R_5 + R_3 R_4} - \frac{R_8}{R_6}\left(f_\mu - R_1 \sum_i \xi_{\mu i}(g_i - f_\mu)\right) \tag{25}$$

Equal-resistance special case. Set $R_1 = R_3 = R_4 = R_5 = R_6 = R_7 = R_8$. The first prefactor in (25) then equals 3, the voltage divider reduces to $(b_\mu + s_\mu)/3$, and $R_8/R_6 = 1$. With the weights expressed in units of $1/R_1$, so that the factor $R_1$ is absorbed into $\xi$, equation (25) reduces to

$$R_2 C_1 \frac{dh_\mu}{dt} = -h_\mu + b_\mu + \frac{R_9}{R_{10}} f_\mu + \sum_i \xi_{\mu i}(g_i - f_\mu) \tag{26}$$

Selection of the $R_9/R_{10}$ self-term gain. To match the form of equation (1), we need to cancel the $-f_\mu \sum_i \xi_{\mu i}$ term that appears on the right hand side of equation (26). The $R_9/R_{10}$ term allows us to do that by setting

$$\frac{R_9}{R_{10}} = \sum_i \xi_{\mu i} \tag{27}$$

Taking equation (27)'s assignment to $R_9$ and $R_{10}$ simplifies equation (26) into

$$R_2 C_1 \frac{dh_\mu}{dt} = \sum_i \xi_{\mu i} g_i - h_\mu + b_\mu \tag{28}$$

which exactly matches our desired dynamics.

Figure 8: Crossbar Array. Each pentagon contains a neuron of design in Figure 7.
In this layout we have flipped the crossbar array, so that index $\mu$ runs horizontally and index $i$ runs vertically.


A.1 Activation function
The voltage across $C_1$ gives us the dynamics of the neuron internal state $h_\mu$. Figure 7 contains a block representing a nonlinear amplifier, denoted $f(\cdot)$, whose input is $h_\mu$ and whose output is $f_\mu = f(h_\mu)$. This voltage is buffered with U1 onto the neuron output line, labeled $f_\mu$, which is what other neurons “see” in the crossbar array. The chosen activation function does not affect the rest of the dynamics of the neuron. In particular, the activation function need not be element-wise: a vector-wise activation function like softmax can be readily applied instead.

A.2 Neurons interacting in a network
So far we have examined the dynamics of a single neuron, treating as an assumption that the neuron will receive an incoming current $J_\mu = \sum_i \xi_{\mu i}(g_i - f_\mu)$. Now, we will show how to wire these neurons together to realize this. Figure 8 shows the simplest DenseAM construction, where each pentagonal node is a circuit of the design in Figure 7. Each neuron exposes a single node whose voltage is driven at the activation of the neuron, and which accepts an incoming current which it uses to drive its dynamics. Each hidden neuron $f_\mu$ is connected to a visible neuron $g_i$ via a resistance $R_{\mu i} = 1/\xi_{\mu i}$ that is the inverse of the weight it represents. The current flowing into node $f_\mu$ is $J_\mu = \sum_i \frac{1}{R_{\mu i}}(g_i - f_\mu)$, which is the assumption needed for equation (24). This same analysis holds for other hidden and visible neurons, and so together they realize the large dynamical system of (1).

A.3 SPICE Netlist
Following is the SPICE netlist for the single neuron circuit, using ideal op-amps. Component values are omitted for brevity. There is no nonlinearity here; adding one would be a matter of inserting a nonlinear amplifier between node h_µ and XU1's positive terminal.
R1 f_µ m_µ
XU1 f_µ h_µ m_µ opamp Aol=100K GBW=10Meg
XU2 u2- f_µ s_µ opamp Aol=100K GBW=10Meg
R2 u2- 0
R3 s_µ u2-
R4 u3+ s_µ
R5 u3+ 0
XU3 u3- u3+ d_µ opamp Aol=100K GBW=10Meg
R6 u3- m_µ
R7 d_µ u3-
R8 d_µ h_µ
C1 h_µ 0
V§b_µ N001 0
R9 u3+ N001
R10 u3- 0


B Softmax Circuit
For demonstration purposes, we follow the construction of an analog softmax circuit using bipolar junction transistors (BJTs) described in [34]. Figure 9 shows the design of a four-way softmax circuit using BJTs. The softmax function we aim to produce is:

$$\mathrm{softmax}_i = \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}}, \qquad i = 1, \ldots, N \tag{29}$$

Figure 9: Softmax circuit design

For the $\mu$th BJT in the circuit, the collector current $I_{C,\mu}$ can be expressed in terms of the base voltage $h_\mu$ and the emitter voltage $V_E$ when in the forward-active mode as:

$$I_{C,\mu} = I_S\, e^{V_{BE,\mu}/V_T}, \qquad V_{BE,\mu} = h_\mu - V_E \;\Rightarrow\; I_{C,\mu} = I_S\, e^{(h_\mu - V_E)/V_T} \tag{30}$$

where $I_S$ is the BJT's saturation current and $V_T$ is the thermal voltage. Assuming large BJT β (note: this β is unrelated to the softmax β) [2], we can neglect base currents, so that $I_{C,\mu} = I_{E,\mu}$. Applying KCL at the shared emitter node $V_E$, the total current is $I_{EE} = \sum_{\mu=1}^{N_h} I_{C,\mu}$. We can expand the expression for the collector currents to get the currents in terms of node voltages:

$$I_{EE} = \sum_{\mu=1}^{N_h} I_S\, e^{(h_\mu - V_E)/V_T} = \sum_{\mu=1}^{N_h} \frac{I_S\, e^{h_\mu/V_T}}{e^{V_E/V_T}} \tag{31}$$

Simultaneously, the current $I_{EE}$ is also fixed by the ideal current source, so $I_{C,\mu}$ can also be expressed as the ratio of the branch current to the total current: $I_{C,\mu} = \frac{I_{C,\mu}}{I_{EE}} I_{EE}$. Plugging in (30) for $I_{C,\mu}$ and (31) for $I_{EE}$ in the denominator and canceling the term containing $V_E$,

$$I_{C,\mu} = \frac{e^{h_\mu/V_T}}{\sum_{j=1}^{N_h} e^{h_j/V_T}}\, I_{EE} \tag{32}$$

This already looks very much like the ideal softmax function. The voltage at node $f_\mu$ is created by current flowing through resistor $R_\mu$, producing a voltage drop relative to $V_{CC}$. Specifically, the voltage $f_\mu = V_{CC} - \frac{e^{h_\mu/V_T}}{\sum_{j=1}^{N_h} e^{h_j/V_T}}\, I_{EE} R_\mu$.
When $I_{EE} R_\mu = 1$ V, this voltage $f_\mu$ is a negated and shifted softmax in the range of 1 volt. This scale and negation can be easily corrected with an op amp, which is also needed to isolate the node and prevent loading. Note that $V_{CC}$ must be chosen to be a positive supply in order for the BJTs to remain in the forward-active mode.

[2] Footnote: In BJTs, β denotes the ratio of the collector current to the base current. High BJT β indicates the transistor is able to amplify a small base current into a much larger collector current, allowing the BJT to function as an amplifier or switch. A high β reflects that the BJT can efficiently transmit carriers from emitter to collector, without losing them to the base.

 Parameter               Value
 R_F                     1000 Ω
 R_T                     1 Ω
 R_1                     1 Ω
 R_2, R_3, ..., R_8      10 000 Ω
 R_S                     40 Ω
 C                       10 µF
 a_3                     0 V
 b_1                     0 V
 b_2                     −1 V
 b_3                     −1 V
 b_4                     −1 V

 Table 2: Component and parameter values.


C XOR DenseAM Circuit
Figure 10 is a full circuit diagram of the DenseAM that solves the XOR problem. Given input voltages at V1, V2 ∈ {0, 1}, the output voltage at $g_3$ is the result of the XOR operation between V1 and V2. In this model, the visible neuron is linear, and the hidden neurons share a softmax activation function implemented by a set of bipolar junction transistors. Table 2 lists the component values used in simulation.

Visible neurons. In the XOR task, only one visible neuron is left evolving, corresponding to the output column of the truth table. As such, the first two neurons are clamped to the input voltages, represented by V1 and V2. The third visible neuron, highlighted in blue, is a linear unit with no nonlinear activation: the internal state voltage $v_3$ directly drives the output, setting $g_3 = v_3$. This is the same circuit described in Appendix A, except that the activation block is not present.

Hidden neurons. The XOR task requires four hidden neurons, highlighted in green.
These are identical circuit constructions with the exception of the voltage sources $b_\mu$ for the biases, which are set according to the values in Table 2. Unlike the visible neuron, the hidden neurons have a softmax activation function, such that $f_\mu = \mathrm{softmax}_\mu(h)$.

Softmax activation function. The red highlight marks the same softmax circuit described in Appendix B, comprised of BJTs, resistors, a voltage source for $V_{CC}$ and a current source for $I_{EE}$. We use 2N5088 transistors in our model, reflecting a standard and widely available BJT. Noninverting buffers (U10, U11, etc.) are used to prevent loading effects on the state capacitors $C_\mu$ from current draw of the BJT base in forward-active mode. As discussed in Appendix B, the softmax circuit itself produces an output voltage of

$$V_{CC} - \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}}, \qquad i = 1, \ldots, N$$

When $V_{CC} = 5$ V as in this circuit, this requires extra circuitry, highlighted in yellow, to shift and negate the softmax output. This is done by first buffering the voltage output to prevent loading effects, followed by a summing op amp that subtracts $V_{CC}$ and inverts the softmax output. For the first hidden neuron $h_1$ (lower left of figure), op-amp U2 buffers the voltage output, while U1 is configured in an inverting summing configuration to add −5 V (the inverse of $V_{CC}$) to the buffered voltage output, producing the correct softmax output.

Weight matrix. The weight matrix $\xi$ is realized by resistors $R_1$–$R_{12}$. These are set directly according to the XOR truth table, where each row corresponds to one hidden neuron. A boolean value of 1 ($R_T$) is mapped to a high conductance (a 1 Ω resistor), while a boolean value of 0 ($R_F$) is mapped to a relatively small conductance (a 1 kΩ resistor).
The gain $s_i/g_i$ governing the value of $s_i$ is set to be the sum of the conductances in that neuron's crossbar column. The column for neuron 1 has three $R_F$ resistors, whose conductances sum to $3 \times 10^{-3}$ S.
Hence, neuron 1's $R_{47}/R_{46} = 3/1000$. The crossbar resistors for neurons 2, 3, and 4 comprise two $R_T$ resistors and one $R_F$ resistor, whose conductances sum to approximately 2. Hence, we approximate $R_{59}/R_{56} = 2000/1000$, and similarly for hidden neurons 3 and 4.

Figure 10: Full schematic for XOR DenseAM built with 1 evolving linear visible neuron and 4 hidden neurons with softmax activation. Blue: visible neuron. Green: hidden neurons. Yellow: buffers for softmax activation circuit. Red: analog softmax circuit.


D Design and implementation variations
A large design space remains open across analog electronics and other substrates for realizing DenseAMs, with clear speed–energy–area–precision trade-offs. In electronics, the core primitives admit multiple realizations: passive, nonvolatile weights (e.g., memristors, triode-region or floating-gate transistors, and other programmable conductors); active, gained weights via OTAs; and nonlinearities via diode clamps, reverse-biased diode/BJT exponentials, MOS quadratic regions, or translinear blocks. Architectures in the spirit of [35, 23] are compact but couple synaptic values to neuronal time constants, so the dynamics drift when a single weight changes, which is problematic for learning and consistent timing. Our decoupled neuron, by contrast, preserves a fixed time constant under weight updates. Simpler neuron/network topologies likely exist and can be attractive in resource-constrained regimes, provided their deviations from the target ODEs are validated not to degrade performance. Beyond CMOS, photonics (e.g., overdamped, low-Q microring resonators) can naturally implement first-order ODEs and can offer extreme bandwidth with distinct calibration and noise constraints.
Across these options, open problems include robust weight storage/programmability and drift control, mixed-signal learning rules compatible with device limits, scaling under current/GBW/SR constraints, tolerance to mismatch/noise, and algorithm–circuit co-design to exploit substrate-specific advantages.


E Scaling of inference time
There are two conditions under which inference times should be studied, dependent on the softmax temperature β. In the low-β regime, the DenseAM reaches equilibria with multiple hidden neurons “competing” in the softmax, while in the high-β regime, the DenseAM reaches equilibria with only one hidden neuron “winning out” in the softmax. Intuitively, the high-β regime corresponds to exact memory recall, while the low-β regime corresponds to interpolation. The XOR and Hamming (7,4) code are in the high-β regime, while the energy transformer lies in the low-β regime. In both regimes, we find that the DenseAM converges in time that is constant with respect to the number of neurons.

Assumptions.
(A1) There is a per-synapse device limit of $0 \le \xi_{\mu i} \le G_{\max}$, where $G_{\max}$ is the maximum conductance set by the physics of the crossbar crosspoints. Because $f$ is the output of a softmax, $f_\mu \le 1\ \forall \mu$ and $\sum_\mu f_\mu = 1$, which means

$$\sum_\mu \xi_{\mu i} f_\mu \le G_{\max} \tag{33}$$

     so the RHS of the visible neuron dynamics is $O(1)$.
     There exist both column-sum and row-sum budgets that are enforced by the hardware, since each neuron's output stage can only source/sink a finite amount of current while maintaining GBW/SR margins. This dictates a per-column and per-row conductance budget to stay within this maximum current, resulting in

$$\sum_{i=1}^{N_v} \xi_{\mu i} \le C_r \ \ \forall \mu, \qquad \sum_{\mu=1}^{N_h} \xi_{\mu i} \le C_c \ \ \forall i \tag{34}$$

     Weights can only be non-negative since conductances can only be non-negative, so $\xi_{\mu i} \ge 0$.
     As a corollary of (A1), note also that we can bound $\|\xi_\mu\|_2 \le S\ \forall \mu$, and since $\|\xi_\mu\|_2 \le \|\xi_\mu\|_1$, then $\|\xi_\mu\|_2 \le C_c\ \forall \mu$.
(A2) Bounded biases. $|a_i| \le A$, $|b_\mu| \le B$ for all $i$, $\mu$.
In realistic regimes, this typically holds, for example the typical choice in boolean functions of $b_\mu = -\frac{\beta}{2}\|\xi_\mu\|^2$ (seen in Section 5.1).

Model. Take the system of equation (1) with a softmax activation on hidden neurons and an identity activation on visible neurons. For clarity we assume 0 biases on visible neurons, but they do not change the analysis.

$$\tau_v \dot{v} = \xi^\top f + a - v, \qquad \tau_h \dot{h} = \xi v + b - h, \qquad f = \mathrm{softmax}_\beta(h) \tag{35}$$

Integrating out the hidden units,

$$\tau_v \dot{v} = \xi^\top f(v) - v, \tag{36}$$

$$f(v) = \mathrm{softmax}\big(\beta(\xi v + b)\big) \tag{37}$$

yields the effective energy function expressed in terms of visible neurons:

$$E(v) = \frac{1}{2}\|v\|^2 - \frac{1}{\beta}\log\sum_\mu \exp\!\big(\beta(\xi_\mu^\top v + b_\mu)\big) \tag{38}$$

where $\nabla E(v) = v - \xi^\top f(v)$. Because $\tau_v \dot{v} = -\nabla E(v)$, we see that the dynamical trajectory causes the energy to monotonically decrease over time:

$$\frac{d}{dt} E(v(t)) = \nabla E(v(t))^\top \dot{v} = -\frac{1}{\tau_v}\|\nabla E(v(t))\|^2 \le 0 \tag{39}$$

E.1 Low-β regime
The energy landscape in the low-β regime exhibits uniform strong convexity, so the gradient flow dynamics cause the energy gap to decay exponentially, reaching an ϵ-fraction of the original energy gap in constant time. To show $E(v)$ is α-strongly convex, we must show $\nabla^2 E(v) \succeq \alpha I$ for some $\alpha > 0$. This means that all the eigenvalues of the Hessian are $\ge \alpha$. Equivalently, $\lambda_{\min}(\nabla^2 E) \ge \alpha$. Denote $G(f) = \mathrm{Diag}(f) - f f^\top \succeq 0$, which is the Jacobian of the softmax function $f(v) = \mathrm{softmax}(\beta(\xi v + b))$.

$$\nabla^2 E(v) = I - \beta\, \xi^\top G(f)\, \xi \tag{40}$$

$$\lambda_{\min}\big(\nabla^2 E(v)\big) = \lambda_{\min}\big(I - \beta\, \xi^\top G(f)\, \xi\big) \tag{41}$$

$$= 1 - \beta\, \lambda_{\max}\big(\xi^\top G(f)\, \xi\big) \tag{42}$$

$$\Rightarrow\ \nabla^2 E(v) \succeq \big(1 - \beta\, \lambda_{\max}(\xi^\top G(f)\, \xi)\big) I \tag{43}$$

Because $G(f) \preceq \mathrm{Diag}(f) \preceq I$ is PSD and therefore $\xi^\top G(f)\, \xi$ is also PSD, and $G(f)$ is a probability-weighted covariance where $\sum_\mu f_\mu = 1$,

$$\lambda_{\max}\big(\xi^\top G(f)\, \xi\big) \le \mathrm{tr}\big(\xi^\top G(f)\, \xi\big) \le \sum_\mu f_\mu \|\xi_\mu\|^2 \le \max_\mu \|\xi_\mu\|^2 \tag{44}$$

Denote $S^2 = \max_\mu \|\xi_\mu\|^2 \le C_c$ as in (A1). Therefore, the Hessian of $E$ can be bounded as

$$\nabla^2 E(v) \succeq (1 - \beta S^2) I = \alpha I \tag{45}$$

where $\alpha = 1 - \beta S^2$. Then $\alpha > 0$ when $\beta < 1/\max_\mu \|\xi_\mu\|^2$.
This is a sufficient (but not necessary) condition for the system to be in the low-β (uniformly convex) regime, where the softmax is diffuse enough that its covariance term does not contribute so much negative curvature as to overwhelm the positive curvature contributed by the identity term. In this regime, the uniform lower bound on the Hessian implies α-strong convexity, which gives the PL inequality

$$\frac{1}{2}\|\nabla E(v)\|^2 \ge \alpha\,(E(v) - E^\star) \tag{46}$$

Together with (39), this allows us to bound the time constant of gradient flow:

$$\frac{d}{dt}\big(E(v(t)) - E^\star\big) = -\frac{1}{\tau_v}\|\nabla E(v(t))\|^2 \le -\frac{2\alpha}{\tau_v}\big(E(v(t)) - E^\star\big) \tag{47}$$

If the curvature is bounded below by α, then the gradient magnitude grows at least linearly with distance to the minimum, so the energy function is “steep enough” to give exponential convergence. Integrating,

$$E(v(t)) - E^\star \le \big(E(v(0)) - E^\star\big)\, e^{-\frac{2\alpha}{\tau_v} t} \tag{48}$$

This indicates exponential decay of the energy gap. Reaching an ϵ-fraction of the original energy gap takes time

$$T(\epsilon) \le \frac{\tau_v}{2\alpha}\log\frac{1}{\epsilon} = O(\tau_v \log(1/\epsilon)) \tag{49}$$

which is entirely independent of system size $N_v$ and $N_h$. In the energy transformer case, this means that convergence time is entirely independent of context length $L$ and token dimension $D$.

E.2 High-β regime
E.2.1 $T_I$: Basin selection
Denote

$$s_\mu(v) := \xi_\mu^\top v + b_\mu, \qquad m(v) := \max_\mu s_\mu(v), \qquad f := \mathrm{softmax}(\beta s) \tag{50}$$

Define the basin of attraction around the winning softmax logit $k$ by the margin $\gamma > 0$:

$$B_k(\gamma) = \{\, v : s_k(v) - \max_{j \ne k} s_j(v) \ge \gamma \,\} \tag{51}$$

Let $T_I$ be the first time $t$ such that $v(t) \in \cup_k B_k(\gamma)$. Defining the softmax component of the energy function (38) as

$$\mathrm{LSE}_\beta(s) = \frac{1}{\beta}\log\sum_{\mu=1}^{N_h} e^{\beta s_\mu}$$

then for every $v$, we can bound the LSE as

$$m(v) \le \mathrm{LSE}_\beta(s(v)) \le m(v) + \frac{1}{\beta}\log N_h \tag{52}$$

Thus, the “softmax slack” $\delta(v) := \mathrm{LSE}_\beta(s(v)) - m(v)$ obeys $0 \le \delta(v) \le \frac{1}{\beta}\log N_h$.
In the high-β regime, there are no critical points other than the softmax basins (those within $\cup_k B_k(\gamma)$ for any reasonable $\gamma > \epsilon > 0$). To reduce δ from its initial value to the cusp of one of the basins requires dissipating at most

$$\Delta E_{\mathrm{softmax}} \le \frac{1}{\beta}\log N_h \tag{53}$$

Since $\frac{\partial E}{\partial v_i} = -\tau_v \dot{v}_i$, and outside winning basins $\tau_v \dot{v}_i \sim 1$, the squared magnitude of the gradient grows at least linearly in $N_v$:

$$\|\nabla E(v)\|^2 = \sum_{i=1}^{N_v}\left(\frac{\partial E}{\partial v_i}\right)^2 \ge c N_v \tag{54}$$

for some $c > 0$ independent of $N_v$ and $N_h$, for all $v$ in the trajectory outside a winning basin. Therefore, the energy dissipation rate satisfies

$$-\dot{E}(t) = \frac{1}{\tau_v}\|\nabla E(v(t))\|^2 \ge \frac{c}{\tau_v} N_v \tag{55}$$

Under assumptions (A1)–(A2), the visible state $v$ remains in a bounded box, so the quadratic part of the energy contributes at most $O(N_v)$ to the energy difference between any two points on the trajectory. Since the energy dissipation rate during $T_I$ scales proportionally to $N_v$, the quadratic component of the energy contribution is dissipated in constant time. The only nontrivial $N_h$ dependence is due to the softmax slack. Together with the bound on $\Delta E_{\mathrm{softmax}}$, the total time this phase takes is characteristically

$$T_I = O\!\left(\frac{\tau_v}{\beta}\frac{\log N_h}{N_v}\right) \tag{56}$$

E.2.2 $T_{II}$: Contractive convergence within a winning basin
Consider the basin $B_k(\gamma)$ that is entered at $t_{\mathrm{in}} = T_I$. We will now show local strong convexity within this basin, allowing us to invoke the PL inequality and find exponential convergence within the basin. Define $G := \mathrm{Diag}(f) - f f^\top$. First, consider that the non-winning softmax mass is $1 - f_k$, which is

$$1 - f_k = \sum_{j \ne k} f_j \le (N_h - 1)\, e^{-\beta\gamma} \tag{57}$$

Additionally, since $\|f\|^2 = f_k^2 + \sum_{j \ne k} f_j^2 \ge f_k^2$ and $0 \le f_k \le 1$,

$$\lambda_{\max}(G(f)) \le \mathrm{tr}(G(f)) = 1 - \|f\|^2 \le 1 - f_k^2 \le 2(1 - f_k) \le 2(N_h - 1)\, e^{-\beta\gamma} \tag{58}$$

Hence, with $S^2 = \max_\mu \|\xi_\mu\|^2$,

$$\lambda_{\max}\big(\xi^\top G(f)\, \xi\big) \le S^2 \lambda_{\max}(G(f)) \le 2 S^2 (N_h - 1)\, e^{-\beta\gamma} \tag{59}$$

This gives a bound on the largest eigenvalue of $G(f)$ in a way that incorporates the softmax β.
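
The chain of inequalities in (57) and (58) can be checked numerically on a small example. The following is a minimal sketch, not part of the paper's analysis: the pre-activation values in `s_example` are hypothetical, and the helper names `softmax` and `jacobian_trace_chain` are introduced here for illustration only. It evaluates $\mathrm{tr}(G(f)) = 1 - \|f\|^2$ (an upper bound on $\lambda_{\max}(G(f))$), then $2(1 - f_k)$, then $2(N_h - 1)e^{-\beta\gamma}$ for several values of β.

```python
import math


def softmax(logits):
    """Numerically stable softmax of a list of floats."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def jacobian_trace_chain(s, beta):
    """Return tr(G), 2(1 - f_k), 2(Nh - 1)exp(-beta*gamma), and the margin gamma."""
    f = softmax([beta * x for x in s])
    k = max(range(len(s)), key=lambda j: s[j])                # winning index
    gamma = s[k] - max(x for j, x in enumerate(s) if j != k)  # margin as in (51)
    trace_G = 1.0 - sum(p * p for p in f)                     # tr(Diag(f) - f f^T)
    mass_bound = 2.0 * (1.0 - f[k])                           # twice the non-winning mass
    size_bound = 2.0 * (len(s) - 1) * math.exp(-beta * gamma)
    return trace_G, mass_bound, size_bound, gamma


# Hypothetical pre-activations s_mu(v) for Nh = 4 hidden neurons.
s_example = [0.3, 1.7, 0.9, 1.1]
chains = {beta: jacobian_trace_chain(s_example, beta) for beta in (1.0, 4.0, 16.0)}

for beta, (trace_G, mass_bound, size_bound, gamma) in chains.items():
    print(f"beta={beta:5.1f}  tr(G)={trace_G:.3e}  "
          f"2(1-f_k)={mass_bound:.3e}  2(Nh-1)exp(-beta*gamma)={size_bound:.3e}")
```

In this sketch the three quantities are ordered as in (58) for every β, and all of them shrink as β grows, which is the mechanism that makes the in-basin curvature bound independent of size once the margin is large enough.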
Now, we can show local strong convexity in the winning basin:

$$\nabla^2 E(v) = I - \beta\, \xi^\top G(f)\, \xi \succeq \big(1 - 2\beta S^2 (N_h - 1)\, e^{-\beta\gamma}\big) I \equiv \alpha(\beta, \gamma)\, I \tag{60}$$

for all $v \in B_k(\gamma)$. In particular, if

$$e^{-\beta\gamma}(N_h - 1) \le \frac{1}{4\beta S^2} \tag{61}$$

then $\alpha(\beta, \gamma) \ge \frac{1}{2}$, independent of $N_h$, $N_v$. Note that this is always possible: if the softmax is not peaked enough to make this inequality true, simply keep moving in trajectory “Phase I” for a little longer until the margin γ grows slightly larger such that the condition holds true. This strong convexity within $B_k(\gamma)$ implies the PL inequality

$$\frac{1}{2}\|\nabla E(v)\|^2 \ge \alpha(\beta, \gamma)\,(E(v) - E^\star), \qquad \forall v \in B_k(\gamma) \tag{62}$$

Therefore, along the trajectory within the basin for times $t \ge t_{\mathrm{in}}$,

$$\frac{d}{dt}\big(E(v(t)) - E^\star\big) = -\frac{1}{\tau_v}\|\nabla E(v(t))\|^2 \le -\frac{2\alpha(\beta, \gamma)}{\tau_v}\big(E(v(t)) - E^\star\big) \tag{63}$$

Integrating,

$$E(v(t)) - E^\star \le e^{-\frac{2\alpha(\beta,\gamma)}{\tau_v}(t - t_{\mathrm{in}})}\big(E(v(t_{\mathrm{in}})) - E^\star\big) \tag{64}$$

Impose a relative-to-initial convergence criterion:

$$E(v(t)) - E^\star \le \epsilon\,\big(E(v(0)) - E^\star\big), \qquad \epsilon \in (0, 1)$$

Since $E$ is non-increasing along the trajectory, $E(v(t_{\mathrm{in}})) - E^\star \le E(v(0)) - E^\star$, so it suffices that

$$e^{-\frac{2\alpha(\beta,\gamma)}{\tau_v}(t - t_{\mathrm{in}})} \le \epsilon$$

Hence the in-basin time satisfies

$$T_{II} \le \frac{\tau_v}{2\alpha(\beta, \gamma)}\log\frac{1}{\epsilon} = O\!\left(\tau_v \log\frac{1}{\epsilon}\right) \tag{65}$$

which is independent of $N_h$ and $N_v$.

E.2.3 Combined bound
Altogether, in the high-β regime, to reach a relative-to-initial tolerance of

$$E(v(t)) - E^\star \le \epsilon\,\big(E(v(0)) - E^\star\big) \tag{66}$$

the combined convergence time satisfies

$$T(\epsilon) = \underbrace{O\!\left(\frac{\tau_v}{\beta}\frac{\log N_h}{N_v}\right)}_{\text{winner selection } (T_I)} + \underbrace{O\!\left(\tau_v \log\frac{1}{\epsilon}\right)}_{\text{convergence within basin } (T_{II})} \tag{67}$$

For fixed ϵ, β, and $\tau_v$, $T_{II}$ is independent of $N_v$ and $N_h$, while $T_I$ carries all the model-size dependence. The dependence of the convergence time on $N_h$ and $N_v$ in the high-β regime is

$$T(\epsilon) = O\!\left(\frac{\tau_v}{\beta}\frac{\log N_h}{N_v}\right). \tag{68}$$

The convergence time is at most logarithmic in the number of hidden neurons $N_h$, and actually decreases as $1/N_v$ in the number of visible neurons.
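
To make the two-phase picture concrete, the sketch below integrates the visible dynamics $\tau_v \dot{v} = \xi^\top f(v) - v$ with a forward-Euler step and records the energy (38) along the way. It is an illustration under stated assumptions, not the paper's simulation: the memory matrix is the XOR truth table, the biases $b_\mu = -\frac{1}{2}\|\xi_\mu\|^2$, the value β = 8, the initial state, and the step size are all choices made here, and the function names are hypothetical.

```python
import math


def hidden_softmax(v, xi, b, beta):
    """Return f(v) = softmax(beta * (xi v + b)) and log sum_mu exp(beta * s_mu)."""
    logits = [beta * (sum(w * x for w, x in zip(row, v)) + bias)
              for row, bias in zip(xi, b)]
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps], top + math.log(total)


def energy_and_gradient(v, xi, b, beta):
    """Energy E(v) of (38) and its gradient v - xi^T f(v)."""
    f, log_sum_exp = hidden_softmax(v, xi, b, beta)
    energy = 0.5 * sum(x * x for x in v) - log_sum_exp / beta
    recall = [sum(f[mu] * xi[mu][i] for mu in range(len(xi))) for i in range(len(v))]
    return energy, [x - r for x, r in zip(v, recall)]


def relax(v0, xi, b, beta, tau_v=1.0, dt=0.05, steps=400):
    """Forward-Euler integration of tau_v * dv/dt = -grad E(v)."""
    v, energies = list(v0), []
    for _ in range(steps):
        energy, grad = energy_and_gradient(v, xi, b, beta)
        energies.append(energy)
        v = [x - (dt / tau_v) * g for x, g in zip(v, grad)]
    return v, energies


def time_to_fraction(energies, eps, dt):
    """First time at which the energy gap falls to eps times its initial value."""
    e_star = energies[-1]  # final recorded energy used as a proxy for E*
    gap0 = energies[0] - e_star
    for step, e in enumerate(energies):
        if e - e_star <= eps * gap0:
            return step * dt


# Assumed example: XOR truth table as memories, biases -0.5 * ||xi_mu||^2.
xi = [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0]]
b = [-0.5 * sum(w * w for w in row) for row in xi]
beta = 8.0

v_final, energies = relax([0.9, 0.2, 0.8], xi, b, beta)
_, grad_final = energy_and_gradient(v_final, xi, b, beta)
grad_norm = math.sqrt(sum(g * g for g in grad_final))
t_eps = time_to_fraction(energies, eps=1e-3, dt=0.05)

print("final state:", [round(x, 4) for x in v_final])
print("time to reach 1e-3 of the initial energy gap (units of tau_v):", t_eps)
```

The step size is chosen small relative to the largest curvature of this example so that the discrete energies are non-increasing, mirroring (39). Repeating the run with larger memory matrices is one way to probe, in simulation, how the measured relaxation time depends on $N_v$ and $N_h$.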

E.3 Limitations
Our analysis assumes that the timescales of the crossbar array are much faster than the fastest neuronal timescales. In practice, as the crossbar array gets bigger, it may contribute to the timescales of the entire system, since wires have non-zero capacitances. Once the crossbar array is large enough to significantly modify the timescales of the neurons, our analysis and the scaling argument become invalid. For this reason, one cannot scale this design to infinitely large sizes. Analyzing that boundary is outside the scope of our paper, because it depends on fabrication and design parameters, which are at a different level of abstraction than our present paper.

F Design invariance under voltage scaling
Given hardware constraints of $G_{\max}$, $C_c$, and $C_r$, we can still implement models with arbitrarily large weights. Convergence bounds rely on the weight matrix constraints, which can be made feasible by global normalization at the hardware level, keeping the effective model weights unchanged. Consider the scaling factor for any non-negative $\xi$:

$$\kappa = \min\left\{1,\ \frac{G_{\max}}{\max_{\mu,i}\xi_{\mu i}},\ \frac{C_c}{\max_i \sum_\mu \xi_{\mu i}},\ \frac{C_r}{\max_\mu \sum_i \xi_{\mu i}}\right\} \tag{69}$$

Set $\tilde\xi = \kappa\xi$. Then $\tilde\xi$ satisfies all the hardware constraints of assumption (A1):

$$0 \le \tilde\xi_{\mu i} \le G_{\max}, \qquad \sum_i \tilde\xi_{\mu i} \le C_r \ \ \forall\mu, \qquad \sum_\mu \tilde\xi_{\mu i} \le C_c \ \ \forall i \tag{70}$$

So any $\xi$ matrix can be mapped onto budgets with one scalar $\kappa$. Consider the pre-softmax arguments for the hidden neurons: if we scale weights $\xi \to \tilde\xi = \kappa\xi$, rescale the voltage unit $v \to \tilde v = \kappa v$ and biases $b \to \tilde b = \kappa^2 b$, and set $\tilde\beta = \beta/\kappa^2$, then

$$\tilde\beta\,(\tilde\xi_\mu^\top \tilde v + \tilde b) = \beta\,(\xi_\mu^\top v + b) \tag{71}$$

so the softmax outputs $f$ and the system's attractors are unchanged. The visible ODE $\tau_v \dot v = \xi^\top f(v) - v$ is preserved up to units, as the $\kappa$ terms can be absorbed into the gain of U2 and U3 without affecting the convergence time bounds.
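The invariance in (69) to (71) is easy to confirm numerically. The sketch below is an illustration under assumed values: the weight matrix, biases, and hardware budgets are random or hypothetical, and `scale_factor` is a name introduced here, not taken from the paper.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax of a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


def scale_factor(xi, g_max, c_c, c_r):
    """Scalar kappa from (69); xi has shape (N_h, N_v) with non-negative entries."""
    return min(
        1.0,
        g_max / xi.max(),
        c_c / xi.sum(axis=0).max(),   # column budget: sum over mu for each visible i
        c_r / xi.sum(axis=1).max(),   # row budget: sum over i for each hidden mu
    )


rng = np.random.default_rng(0)
n_h, n_v = 6, 4
xi = rng.uniform(0.0, 5.0, size=(n_h, n_v))   # hypothetical weights
v = rng.normal(size=n_v)
b = rng.normal(size=n_h)
beta = 2.0
g_max, c_c, c_r = 1.0, 3.0, 2.0               # hypothetical hardware budgets

kappa = scale_factor(xi, g_max, c_c, c_r)
xi_t, v_t, b_t, beta_t = kappa * xi, kappa * v, kappa**2 * b, beta / kappa**2

f = softmax(beta * (xi @ v + b))
f_t = softmax(beta_t * (xi_t @ v_t + b_t))
print("kappa =", kappa, "max |f - f_tilde| =", np.abs(f - f_t).max())
```

The rescaled weights meet all three budgets in (70), and the softmax outputs agree to floating-point precision, as (71) predicts.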

G Scaling of energy consumption
The energy consumption of DenseAM circuits can be broken up into two parts: the energy dissipated by the weights as a result of Ohm's Law, and the energy from engineering overhead found in amplifiers and active circuitry. The energy dissipated by the weights in the crossbar array can be expressed as the integral of the power dissipated by each resistor of resistance $R_{\mu i}$ from time 0 until convergence at $T_{\text{conv}}$.

Energy consumption of weights. Let the neuron output voltages be proportional to activations: $u_i = \kappa g_i$ and $w_\mu = \kappa f_\mu$, where $\kappa$ is a fixed voltage scale. We assume rail-bounded outputs $|u_i| \le \kappa$ and $|w_\mu| \le \kappa$ (by Appendix F, global rescaling of $\xi$, voltages, and $\beta$ preserves the DenseAM dynamics, so this choice of $\kappa$ does not affect behavior). The instantaneous power in the resistive crossbar is:

$$P_{\text{weights}}(t) = \sum_{i,\mu} \xi_{\mu i}\,(u_i - w_\mu)^2 \tag{72}$$

Using the row/column conductance budgets $\sum_\mu \xi_{\mu i} \le C_c$ and $\sum_i \xi_{\mu i} \le C_r$ (Appendix E) and the inequality $(a - b)^2 \le 2a^2 + 2b^2$,

$$P_{\text{weights}}(t) \le 2\left(\sum_{i,\mu} \xi_{\mu i} u_i^2 + \sum_{i,\mu} \xi_{\mu i} w_\mu^2\right) \tag{73}$$

$$= 2\left(\sum_i u_i^2 \Bigl(\sum_\mu \xi_{\mu i}\Bigr) + \sum_\mu w_\mu^2 \Bigl(\sum_i \xi_{\mu i}\Bigr)\right) \tag{74}$$

$$\le 2\left(C_c \sum_i u_i^2 + C_r \sum_\mu w_\mu^2\right) \tag{75}$$

If the hidden layer uses a softmax activation, then $\sum_\mu f_\mu^2 \le 1$ and so $\sum_\mu w_\mu^2 \le \kappa^2$; and rail bounds give $\sum_i u_i^2 \le N_v \kappa^2$. Therefore,

$$P_{\text{weights}}(t) \le 2\kappa^2 (C_c N_v + C_r) = O(N_v) \tag{76}$$

Therefore, a system taking time $T_{\text{conv}}$ to converge results in an energy consumption of

$$E_{\text{weights}} = \int_0^{T_{\text{conv}}} P_{\text{weights}}(t)\,dt \le 2\kappa^2 (C_c N_v + C_r)\,T_{\text{conv}} \tag{77}$$

According to the convergence time bounds of Appendix E, $T_{\text{conv}} = O(\tau_v)$. Thus, $E_{\text{weights}} = O(N_v)$, as a function of system size.

Energy consumption of capacitors. Let each neuron node voltage be bounded by hardware limits $|u_i(t)|, |w_\mu(t)| \le \kappa$. Charging a capacitor of capacitance $C$ from a supply through a resistive path draws $CV^2$ from the power supply.
The number of times each capacitor charges is finite because the Lyapunov energy of the DenseAM forbids limit cycles. This means the total supply energy per node can be bounded by a constant. Therefore, the total energy needed to (re)charge all neuron capacitors is bounded by

$$E_{\text{capacitors}} \le O(1)\cdot\kappa^2\left(\sum_{i=1}^{N_v} C_i^{(v)} + \sum_{\mu=1}^{N_h} C_\mu^{(h)}\right) = O(N_v + N_h) \tag{78}$$

Energy consumption of amplifiers, bias, control, and overhead. Per neuron, the energy expenditure due to amplifier inefficiency, bias terms, and general overhead does not depend on system size. For a runtime of duration $T_{\text{conv}}$, the energy consumption of these elements in the entire network scales as

$$E_{\text{other}} = O\bigl((N_v + N_h)\,T_{\text{conv}}\bigr) \tag{79}$$

Combined energy consumption. All together, the total energy consumption can be written as

$$E_{\text{total}} = O(N_v + N_h) \tag{80}$$

H Model Specifications and Details
Table 3, Table 4, and Table 5 summarize the model design for the XOR, Hamming (7,4), and parity DenseAM models.

Table 3: XOR model specification

| Component | Specification |
|---|---|
| Visible neurons $v_i$ | $N_v = 3$ (inputs $v_1, v_2$ clamped to $\{0,1\}$; output $v_3$ free) |
| Hidden neurons $h_\mu$ | $N_h = 4$ (one per truth-table row) |
| Visible activation and Lagrangian | Identity: $g_i = v_i$, $L_v = \frac{1}{2}\sum_{i=1}^{N_v} v_i^2$ |
| Hidden activation and Lagrangian | Softmax: $f_\mu = \mathrm{softmax}(\beta h)_\mu$, $L_h = \frac{1}{\beta}\log\sum_{\mu=1}^{N_h} e^{\beta h_\mu}$ |
| Visible biases | $a_i = 0$ |
| Hidden biases | $b_\mu = -\frac{1}{2}\sum_{i=1}^{N_v} \xi_{\mu i}^2$ |
| Weights $\xi$ | $\xi \in \{0,1\}^{4\times 3}$, rows encode memories: $\xi = \begin{pmatrix} 0 & 0 & 0 \\ 0 & 1 & 1 \\ 1 & 0 & 1 \\ 1 & 1 & 0 \end{pmatrix}$ |
| Inference protocol | Clamp $(v_1, v_2)$ to input values; read out $v_3$ at convergence |

Table 4: Hamming (7,4) model specification

| Component | Specification |
|---|---|
| Visible neurons ($N_v$) | 7 (codeword bits) |
| Hidden neurons ($N_h$) | 16 (one per valid codeword) |
| Visible activation | Identity: $g_i = v_i$ |
| Hidden activation | Softmax over $\mu \in \{1, \ldots, 16\}$ with inverse temperature $\beta$ |
| Visible biases | $a_i = 0$ |
| Hidden biases | $b_\mu = -\frac{1}{2}\sum_{i=1}^{N_v} \xi_{\mu i}^2$ |
| Weights $\xi$ | $\xi \in \{0,1\}^{16\times 7}$, each row is a valid Hamming(7,4) codeword |
| Inference protocol | Initialize visible neurons to corrupted 7-bit input codeword; let all visible and hidden neurons evolve; converged visible neurons give the corrected codeword |

Table 5: 8-bit parity model specification

| Component | Specification |
|---|---|
| Visible neurons $v_i$ | $N_v = 16$ (dimension of embedding $D$) |
| Hidden neurons (energy attention) $h_A^{\text{attn}}$ | $N_h^{\text{attn}} = 8$ (context length $L$) |
| Hidden neurons (Hopfield network) $h_\mu^{\text{hopf}}$ | $N_h^{\text{hopf}} = 16$ (Hopfield network memories $M$) |
| Hidden neurons (total) | $N_h = 24$ ($L + M$) |
| Visible activation | Identity: $g_i = v_i$ |
| Hidden activation (energy attention) | Softmax: $f_A^{\text{attn}} = \mathrm{softmax}(\beta h^{\text{attn}})_A$ for $A = 1, \ldots, L$ |
| Hidden activation (Hopfield network) | ReLU: $f_\mu^{\text{hopf}} = \max(h_\mu^{\text{hopf}}, 0)$ for $\mu = 1, \ldots, M$ |
| Weights (energy attention) | $\xi^{\text{attn}} \in \mathbb{R}^{L\times D}$, where $\xi_A^{\text{attn}}$ is the embedded $A$'th context token |
| Weights (Hopfield network) | $\xi^{\text{hopf}} \in \mathbb{R}^{M\times D}$, static after training |
| Inference protocol | Embed $L$ context tokens to obtain $\xi^{\text{attn}}$. Let visible neurons evolve until convergence |

H.1 Bit string energy transformer implementation
As described in Table 5, our trained model uses an embedding matrix of $2 \times D = 32$ parameters, the Hopfield network with $D \times M = 256$ parameters, an additional $D \times 2 = 32$ parameter matrix to decode embeddings to logits, a total of $D + L + M = 40$ neuron bias terms, and 2 biases for the linear decoder. This is a total of 362 parameters.
In training and inference we use time constants $\tau_v = 0.1$ and $\tau_h = 0.01$. We train with Euler steps of 1e-3, and test with Euler steps of 1e-4 for a time horizon of $T = 1$ second. JAX's automatic differentiation was used to implement backpropagation through time. We encourage the model to reach fixed points by penalizing $\dot v$ at time $T$.
This yields models that are more robust to hardware imperfection due to the intrinsic stability of attractor points. The convergence to an attractor also means the inference remains stable to mismatch and delay in timing during readout.

I Hardware analysis
I.1 Hardware speed analysis
As discussed in subsection 7.1, the convergence time of analog DenseAMs is governed not by system size, but rather primarily by the timescales of the dynamics in hardware. These timescales are set by the time constants $\tau_v$ and $\tau_h$. The smaller these time constants, the faster the dynamics move, and the faster the system converges. In this section, we derive bounds on the minimum time constant $\min\{\tau_v, \tau_h\}$ of the DenseAM, which is limited by the constraints of active components like amplifiers.
The maximum speed of neuronal dynamics is limited by the ability of active stages (op-amps/buffers) to track changing signals. If the input slope to an active stage exceeds its slew rate (SR), the output distorts; if the signal spectrum approaches or exceeds the stage's closed-loop bandwidth, attenuation and phase lag appear. Here, we derive lower bounds on the time constants $\tau_v$, $\tau_h$ imposed by (i) finite gain–bandwidth product (GBW) and (ii) finite SR of the three active stages in the neuron design (Appendix A). Without loss of generality we will express the derivation for the hidden neurons, with the derivations for visible neurons following by symmetry. Throughout, define the following:

 • State swing: $|v_i(t)| \le A_v$, so that $|\dot v_i| \lesssim A_v/\tau$. Similarly, $|h_\mu(t)| \le A_h$, so that $|\dot h_\mu| \lesssim A_h/\tau$.
 • Activation swing: Visible activation $g(\cdot)$ is Lipschitz with slope bound $L_g = \sup_x |g'(x)|$. Then, $|\dot g_i| \le L_g |\dot v_i| \le L_g A_v/\tau$. Similarly, hidden activation $f(\cdot)$ is Lipschitz with slope bounded by $L_f = \sup_x |f'(x)|$. Then, $|\dot f_\mu| \le L_f |\dot h_\mu| \le L_f A_h/\tau$.
 • Weights $\xi \ge 0$.
Hardware normalization gives per-row/column conductance budgets, so the self-term gain for hidden neuron $\mu$ is $A_{\text{self},\mu} = \sum_i \xi_{\mu i} = O(1)$.
We will derive three independent lower bounds and then take the max:

$$\tau_{\min} \ge \max\{\underbrace{\tau_{\text{GBW}}}_{\text{tracking small signals}},\ \underbrace{\tau_{\text{SR}}}_{\text{edge/large signals}},\ \underbrace{\tau_{\text{I-limit}}}_{\text{output current}}\} \tag{81}$$

I.1.1 Gain-bandwidth product bound
For a single-pole op-amp with gain-bandwidth product GBW in a closed-loop configuration with closed-loop gain $A_{CL}$, the −3 dB bandwidth is $f_c \approx \text{GBW}/A_{CL}$. In order for the neuron to faithfully track with a time constant $\tau$, we require $f_c \gtrsim 1/(2\pi\tau)$ for every stage in the signal path. Closed-loop gains for each of the op-amps are: $A_{CL}(\text{U1}) = 1$ because it is a unity-gain buffer, $A_{CL}(\text{U2}) = A_{\text{self}}$ because it needs to realize the self-term gain, and $A_{CL}(\text{U3}) \approx 1$ because it is a unity-gain summer. Assuming the same op-amp design for U1, U2, and U3, and taking the worst case,

$$\tau_{\text{GBW}} = \frac{\max(1, A_{\text{self}})}{2\pi\,\text{GBW}} \tag{82}$$

I.1.2 Slew rate bound
The slew-rate limits cap the maximum output slope of each op-amp stage:
 • U1: activation buffer. $|\dot f_\mu| \le L_f A_h/\tau$, which gives $\tau \ge (L_f A_h)/\text{SR}_{\text{U1}}$.
 • U2: self-term. $s_\mu = A_{\text{self}} f_\mu$, so $|\dot s_\mu| = A_{\text{self}} |\dot f_\mu| \le (A_{\text{self}} L_f A_h)/\tau$, which gives $\tau \ge (A_{\text{self}} L_f A_h)/\text{SR}_{\text{U2}}$.
 • U3: internal state drive. The time-varying portion of the RC circuit drive $d_\mu$ is a linear combination of $f_\mu$ and $g_i$, with coefficients that have a maximum magnitude of $A_{\text{self}}$. Using the bounds on the slopes of those inputs, we get the following bound on $|\dot d_\mu|$ and subsequently the time constant bound:

$$|\dot d_\mu| \lesssim \frac{A_{\text{self}}}{\tau}\max\{L_f A_h, L_g A_v\} \quad\Rightarrow\quad \tau \ge \frac{A_{\text{self}}\max(L_f A_h, L_g A_v)}{\text{SR}_{\text{U3}}} \tag{83}$$

All together, the combined constraint is

$$\tau_{\text{SR}} = \max\left\{\frac{L_f A_h}{\text{SR}_{\text{U1}}},\ \frac{A_{\text{self}} L_f A_h}{\text{SR}_{\text{U2}}},\ \frac{A_{\text{self}}\max(L_f A_h, L_g A_v)}{\text{SR}_{\text{U3}}}\right\} \tag{84}$$

I.1.3 Current / headroom limit
U3 must provide the current through $R_2$ to charge $C_1$. The RC circuit dynamics dictate $R_2 C_1 \dot h_\mu = -h_\mu + d_\mu$, so the instantaneous current needed by U3 is

$$I_{\text{U3,out}} = \frac{d_\mu - h_\mu}{R_2} = C_1 \dot h_\mu \tag{85}$$

We must respect $|I_{\text{U3,out}}| \le I_{\max,\text{U3}}$. With $|\dot h_\mu| \lesssim A_h/\tau$,

$$\tau_{\text{I-limit}} \ge \frac{C_1 A_h}{I_{\max,\text{U3}}} \tag{86}$$

I.1.4 Combined bound on minimum time constant
Taken together, the minimum time constant must satisfy the bounds (82), (84), and (86):

$$\tau_{\min} \ge \max\{\tau_{\text{GBW}}, \tau_{\text{SR}}, \tau_{\text{I-limit}}\} \tag{87}$$

Table 6: Estimated neuron time constants and conservative convergence times with $A_v = A_h = 1$ V, $L_g = 1$, $A_{\text{self}} = 1$ for representative amplifiers in literature. GBW bound $\tau_{\text{GBW}} = \frac{1}{2\pi\,\text{GBW}}$; SR bound $\tau_{\text{SR}} = \frac{L_g A_v}{\text{SR}}$ (visible path). Overall $\tau_{\min} = \max\{\tau_{\text{GBW}}, \tau_{\text{SR}}\}$; we report $T_{\text{conv}} = 10\,\tau_{\min}$.

| CMOS amplifier (ref.) | SR (V/µs) | GBW (MHz) | $\tau_{\text{SR}}$ (ns) | $\tau_{\text{GBW}}$ (ns) | $T_{\text{conv}}$ (ns) |
|---|---|---|---|---|---|
| Perez and Maloberti [36] | 84.50 | 321.50 | 11.83 | 0.50 | 118.34 |
| Assaad and Silva-Martinez [37] | 94.10 | 134.20 | 10.63 | 1.19 | 106.27 |
| Yen and Blalock [38] | 202.00 | 10.70 | 4.95 | 14.87 | 148.74 |
| Naderi, Prakash, and Silva-Martinez [39] | 1250.00 | 3600.00 | 0.80 | 0.04 | 8.00 |
| Schlögl and Zimmermann [40] | 1650.00 | 2510.00 | 0.61 | 0.06 | 6.06 |

Notes. (i) $\tau_{\text{SR}}$ values assume the visible path dominates the summer's SR (low/moderate-β). If softmax dominates at U3 in the high-β regime, multiply SR-limited values by $\kappa = (\beta/2)(A_h/A_v)$ (with $A_h = A_v = 1$ V, simply $\beta/2$). (ii) The current-limit bound $\tau_{\text{I-limit}} = C A_v/I_{\max}$ is typically $\ll$ all reported values for $C \sim 50$ fF and $I_{\max} \sim$ mA, so it is omitted from the table but must still be respected in circuit sizing.

I.2 Estimates of inference times with existing hardware
Under standard assumptions for DenseAMs (symmetric couplings and monotone activations), the Lyapunov energy decreases monotonically and the dynamics converge without oscillations. The settling time is therefore on the order of a few multiples of the largest neuronal time constant, which we bound by amplifier non-idealities.
In this section we take some representative examples of op-amps from the literature and estimate the inference speeds from reasonable and representative design parameters.

Minimum time constant. For illustration purposes, we choose three reasonable hardware constraints:
 • Activation slopes. Take the slope of the visible activation to be $L_g = 1$, such as would occur in an identity visible neuron activation. Take the worst-case (maximum) slope of the hidden activation to be that of the softmax with fixed $\beta$, whose Jacobian is $\beta G(f)$ with $\|G(f)\|_2 \le \frac{1}{2}$, so a safe global bound is $L_f \le \frac{\beta}{2}$.
 • Signal swing. Use the voltage scaling invariance (see Appendix F) to rescale $v$, $\xi$, and $\beta$ together to pick a swing that is slew-rate friendly but well above component noise limits. Take both $A_v = A_h = 1$ V.
 • Self-term gain. With row/column budgets, use $A_{\text{self}} = 1$ as a worst-case bound.

With those choices, the three lower bounds per neuron are:

 1. GBW Bound: $\tau_{\text{GBW}} = \frac{\max(1, A_{\text{self}})}{2\pi\,\text{GBW}} = \frac{1}{2\pi\,\text{GBW}}$.
 2. SR Bound: The U1/U2 path gives $\tau_{\text{SR,vis}} = \frac{L_g A_v}{\text{SR}} = \frac{1}{\text{SR}}$ µs, with SR expressed in V/µs. In the U3 (summer) path, equation (84) has two cases. In the low-β regime where $L_g A_v \ge L_f A_h$, the U3 bound reduces to $1/\text{SR}$ µs. In the high-β regime where $L_f A_h = \beta/2$ dominates, scale the slew-rate limited bound by $\beta/2$.
 3. Output Current Bound: In practice, this bound generally does not limit the op-amp choice: even with a large capacitor $C = 50$ fF, $A_v = 1$ V, $I_{\max} = 2$ mA, $\tau_{\text{I-limit}} \approx 0.025$ ns, which is negligible compared to the bounds from SR and GBW.

To quantify realistic inference speeds, Table 6 lists representative CMOS operational transconductance amplifiers (OTAs)³ drawn from recent literature, together with their corresponding lower bounds on neuronal time constants under the GBW and slew-rate limits.
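The Table 6 entries follow directly from the two bounds above. The short sketch below recomputes them from the listed SR and GBW values, assuming $A_v = 1$ V, $L_g = 1$, $A_{\text{self}} = 1$ as stated in the table caption; the function and variable names are illustrative.

```python
import math

# (SR in V/us, GBW in MHz), copied from Table 6.
AMPLIFIERS = {
    "Perez and Maloberti [36]": (84.50, 321.50),
    "Assaad and Silva-Martinez [37]": (94.10, 134.20),
    "Yen and Blalock [38]": (202.00, 10.70),
    "Naderi, Prakash, and Silva-Martinez [39]": (1250.00, 3600.00),
    "Schlögl and Zimmermann [40]": (1650.00, 2510.00),
}


def time_constants_ns(sr_v_per_us, gbw_mhz, a_v=1.0, l_g=1.0, a_self=1.0):
    """Return (tau_SR, tau_GBW, T_conv) in nanoseconds."""
    tau_sr = l_g * a_v / sr_v_per_us * 1e3                               # us -> ns
    tau_gbw = max(1.0, a_self) / (2.0 * math.pi * gbw_mhz * 1e6) * 1e9   # s -> ns
    t_conv = 10.0 * max(tau_sr, tau_gbw)                                 # T_conv = 10 * tau_min
    return tau_sr, tau_gbw, t_conv


ROWS = {name: time_constants_ns(sr, gbw) for name, (sr, gbw) in AMPLIFIERS.items()}
for name, (tau_sr, tau_gbw, t_conv) in ROWS.items():
    print(f"{name:42s} {tau_sr:7.2f} {tau_gbw:7.2f} {t_conv:8.2f}")
```

For four of the five amplifiers the slew-rate bound dominates; only the Yen and Blalock design is limited by GBW.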
Even using conservative assumptions with existing amplifier designs, the analysis shows that modern high-speed OTAs can achieve sub-10 ns neuronal convergence times (8.00 ns and 6.06 ns for the two fastest amplifiers in Table 6), corresponding to inference rates above 100 MHz.

³ Many high-speed CMOS "op-amps" are reported as OTAs (transconductors). In our neuron, these OTA cores operate in closed-loop (unity/non-inverting) configurations, so the literature SR and GBW directly constrain $\tau$ via Eqs. (82)–(84).

J Connection between analog and canonical Energy Transformer
In this section we show that in the adiabatic limit, our Analog Energy Transformer (Analog ET) reduces to the canonical Energy Transformer. Begin with the dynamics for the Analog Energy Transformer implemented by our circuit designs.

$$\tau_v \dot v = -\frac{\partial E}{\partial v} = \xi^{\text{attn}\top} f^{\text{attn}} + \xi^{\text{hopf}\top} f^{\text{hopf}} + a - v \tag{88}$$

$$\tau_h \dot h^{\text{attn}} = -\frac{\partial E}{\partial f^{\text{attn}}} = \xi^{\text{attn}} v + b - h^{\text{attn}} \tag{89}$$

$$\tau_h \dot h^{\text{hopf}} = -\frac{\partial E}{\partial f^{\text{hopf}}} = \xi^{\text{hopf}} v + c - h^{\text{hopf}} \tag{90}$$

Integrating out hidden neurons in the adiabatic limit where $\tau_h \to 0$, we see the relations

$$h^{\text{attn}}(v) = \xi^{\text{attn}} v + b \tag{91}$$

$$h^{\text{hopf}}(v) = \xi^{\text{hopf}} v + c \tag{92}$$

which we can use to integrate out the hidden neuron activations as

$$f^{\text{attn}}(v) = \mathrm{softmax}\bigl(\beta(\xi^{\text{attn}} v + b)\bigr) \tag{93}$$

$$f^{\text{hopf}}(v) = \mathrm{ReLU}\bigl(\xi^{\text{hopf}} v + c\bigr) \tag{94}$$

Substituting into the visible dynamics:

$$\tau_v \dot v = \xi^{\text{attn}\top} f^{\text{attn}}(v) + \xi^{\text{hopf}\top} f^{\text{hopf}}(v) + a - v \tag{95}$$

What scalar energy produces this ODE? We seek an energy $E_{\text{eff}}(v)$ such that $\tau_v \dot v = -\frac{\partial E_{\text{eff}}}{\partial v}$. Equivalently,

$$\nabla_v E_{\text{eff}}(v) = v - a - \xi^{\text{attn}\top} f^{\text{attn}}(v) - \xi^{\text{hopf}\top} f^{\text{hopf}}(v) \tag{96}$$

We can construct $E_{\text{eff}}(v)$ as a sum of three pieces whose gradients match each term: $E_{\text{eff}}(v) = E_{\text{quad}}(v) + E_{\text{attn}}(v) + E_{\text{hopf}}(v)$. By inspection we see that $E_{\text{quad}}(v) = \frac{1}{2}\|v - a\|^2$.

Attention term. The energy function

$$E_{\text{attn}}(v) = -\frac{1}{\beta}\log\sum_A \exp\bigl(\beta(\xi_A^{\text{attn}} v + b_A)\bigr) \tag{97}$$

satisfies our requirement.
We can see this by differentiating with respect to $v_i$:

$$\frac{\partial E_{\text{attn}}}{\partial v_i} = -\sum_A \mathrm{softmax}\bigl(\beta(\xi^{\text{attn}} v + b)\bigr)_A \cdot \xi_{Ai}^{\text{attn}} \tag{98}$$

$$= -\sum_A \xi_{Ai}^{\text{attn}} f_A^{\text{attn}}(v) \tag{99}$$

which yields our desired dynamics of $\nabla_v E_{\text{attn}}(v) = -\xi^{\text{attn}\top} f^{\text{attn}}(v)$.

Hopfield term. A simple way to achieve the desired dynamics is with a Hopfield-type energy function

$$E_{\text{hopf}}(v) = -\sum_\mu \frac{1}{2}\,\mathrm{ReLU}\bigl(\xi_\mu^{\text{hopf}} v + c_\mu\bigr)^2 \tag{100}$$

whose derivative with respect to $v_i$ yields

$$\frac{\partial E_{\text{hopf}}}{\partial v_i} = -\sum_\mu \mathrm{ReLU}\bigl(\xi_\mu^{\text{hopf}} v + c_\mu\bigr) \cdot \xi_{\mu i}^{\text{hopf}} \tag{101}$$

$$= -\sum_\mu \xi_{\mu i}^{\text{hopf}} f_\mu^{\text{hopf}}(v) \tag{102}$$

which yields our desired dynamics of $\nabla_v E_{\text{hopf}}(v) = -\xi^{\text{hopf}\top} f^{\text{hopf}}(v)$.

Effective energy function of analog energy transformer. All together, the effective scalar energy over the visible state $v$ after integrating out hidden neurons is

$$E_{\text{eff}}(v) = \underbrace{\frac{1}{2}\|v - a\|_2^2}_{E_{\text{quad}}} \underbrace{-\ \frac{1}{\beta}\log\sum_A \exp\bigl(\beta(\xi_A^{\text{attn}} v + b_A)\bigr)}_{E_{\text{attn}}} \underbrace{-\ \sum_\mu \frac{1}{2}\,\mathrm{ReLU}\bigl(\xi_\mu^{\text{hopf}} v + c_\mu\bigr)^2}_{E_{\text{hopf}}} \tag{103}$$

This effective energy aligns with the canonical Energy Transformer's energy function. Because our effective dynamics use hidden neurons, the energy function written in the main text reflects the contributions of the hidden neurons. When $\tau_h \ll \tau_v$, this regime converges to the behavior when the hidden neurons are integrated out. Hence, the effective expressibility and behavior of our system is equivalent to that of the original Energy Transformer.
In our model we omit the layer normalization activation that the original Energy Transformer applies to the visible neurons. This keeps the circuit design simple, while still enabling models with high expressibility. This choice does not modify the structure of the attention or the Hopfield parts of the energy; only the self-energy of $v$ differs. From a modeling perspective, layer normalization mainly improves conditioning and learning of deep networks rather than changing the computational primitive and expressibility.
We empirically observe that the resulting models without layer normalization remain expressive enough to solve the problems we present. In principle, a layer normalization-type visible activation function could be implemented in analog hardware (e.g. by subtracting the mean voltage and normalizing by an on-chip variance estimate), but this would add distracting complications to the minimalist neuron and circuit designs we show in this paper.
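The claim that the effective energy reproduces the visible dynamics can be checked with finite differences. The sketch below is an illustration, not the authors' code: it uses the Table 5 sizes with random parameters, and it assumes the attention softmax carries the factor $\beta$, as in Table 5 and the log-sum-exp energy of the attention term.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax of a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


def effective_energy(v, a, xi_attn, b, xi_hopf, c, beta):
    """E_eff = E_quad + E_attn + E_hopf over the visible state v."""
    quad = 0.5 * np.sum((v - a) ** 2)
    z = beta * (xi_attn @ v + b)
    attn = -(z.max() + np.log(np.exp(z - z.max()).sum())) / beta   # stable log-sum-exp
    hopf = -0.5 * np.sum(np.maximum(xi_hopf @ v + c, 0.0) ** 2)
    return quad + attn + hopf


def closed_form_gradient(v, a, xi_attn, b, xi_hopf, c, beta):
    """v - a - xi_attn^T f_attn(v) - xi_hopf^T f_hopf(v)."""
    f_attn = softmax(beta * (xi_attn @ v + b))
    f_hopf = np.maximum(xi_hopf @ v + c, 0.0)
    return v - a - xi_attn.T @ f_attn - xi_hopf.T @ f_hopf


def numerical_gradient(fn, v, step=1e-6):
    """Central finite differences, one coordinate at a time."""
    grad = np.zeros_like(v)
    for i in range(v.size):
        e = np.zeros_like(v)
        e[i] = step
        grad[i] = (fn(v + e) - fn(v - e)) / (2.0 * step)
    return grad


rng = np.random.default_rng(0)
D, L, M = 16, 8, 16                       # sizes from Table 5; values below are random
params = dict(
    a=rng.normal(size=D),
    xi_attn=rng.normal(size=(L, D)), b=rng.normal(size=L),
    xi_hopf=0.1 * rng.normal(size=(M, D)), c=rng.normal(size=M),
    beta=2.0,
)
v = rng.normal(size=D)
g_exact = closed_form_gradient(v, **params)
g_num = numerical_gradient(lambda u: effective_energy(u, **params), v)
max_err = np.abs(g_exact - g_num).max()
print("max |closed form - finite difference| =", max_err)
```

A small discrepancy, limited by the finite-difference step, indicates that $-\nabla_v E_{\text{eff}}$ matches the right-hand side of the visible ODE.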
\ No newline at end of file
diff --git a/ep_run/analyze.py b/ep_run/analyze.py
new file mode 100644
index 0000000..9514773
--- /dev/null
+++ b/ep_run/analyze.py
@@ -0,0 +1,95 @@
+"""Load runs/*/log.jsonl, print comparison table, and save loss curves as ASCII/PNG."""
+import json
+import sys
+from pathlib import Path
+
+import numpy as np
+
+
+def load_run(run_dir: Path):
+    log_path = run_dir / "log.jsonl"
+    if not log_path.exists():
+        return None
+    steps, step_losses, evals = [], [], []
+    for line in log_path.read_text().splitlines():
+        rec = json.loads(line)
+        if rec.get("event") == "step":
+            steps.append(rec["iter"])
+            step_losses.append(rec["train_loss"])
+        elif rec.get("event") == "eval":
+            evals.append((rec["iter"], rec["train_loss"], rec["val_loss"]))
+    return {
+        "name": run_dir.name,
+        "steps": np.array(steps),
+        "step_losses": np.array(step_losses),
+        "evals": np.array(evals) if evals else np.zeros((0, 3)),
+    }
+
+
+def ascii_plot(runs, key_idx=2, width=60, height=15, title="val loss"):
+    """key_idx: 1 = train loss (from eval), 2 = val loss (from eval)."""
+    lines = [title]
+    all_y = np.concatenate([r["evals"][:, key_idx] for r in runs if len(r["evals"]) > 0])
+    all_x = np.concatenate([r["evals"][:, 0] for r in runs if len(r["evals"]) > 0])
+    if len(all_y) == 0:
+        return "(no eval data)"
+    ymin, ymax = float(all_y.min()), float(all_y.max())
+    xmin, xmax = float(all_x.min()), float(all_x.max())
+    ymin -= 0.02 * (ymax - ymin + 1e-9)
+    ymax += 0.02 * (ymax - ymin + 1e-9)
+    grid = [[" "] * width for _ in range(height)]
+    markers = {0: "o", 1: "x", 2: "+", 3: "*"}
+    for i, r in enumerate(runs):
+        if len(r["evals"]) == 0:
+            continue
+        mk = markers.get(i, "#")
+        for x, _tl, vl in r["evals"]:
+            col = int((x - xmin) / (xmax - xmin + 1e-9) * (width - 1))
+            row = height - 1 - int((vl - ymin) / (ymax - ymin + 1e-9) * (height - 1))
+            row = max(0, min(height - 1, row))
+            col = max(0, min(width - 1, col))
+            grid[row][col] = mk
+    lines.append(f" y: [{ymin:.3f} .. {ymax:.3f}] x: [{int(xmin)} .. {int(xmax)}]")
+    for i, row in enumerate(grid):
+        lines.append(" |" + "".join(row) + "|")
+    lines.append(" +" + "-" * width + "+")
+    legend = " legend: " + " ".join(
+        f"{markers.get(i, '#')}={r['name']}" for i, r in enumerate(runs)
+    )
+    lines.append(legend)
+    return "\n".join(lines)
+
+
+def main():
+    runs_dir = Path("runs")
+    run_names = sys.argv[1:] if len(sys.argv) > 1 else [
+        "softmax_baseline", "sigmoid_b0", "sigmoid_blogn",
+    ]
+    runs = []
+    for name in run_names:
+        r = load_run(runs_dir / name)
+        if r is None:
+            print(f"WARNING: {name}/log.jsonl missing")
+            continue
+        runs.append(r)
+
+    print("\n=== final losses ===")
+    print(f"{'run':<20s} {'final train':>12s} {'final val':>10s} {'best val':>10s} {'iter':>6s}")
+    for r in runs:
+        if len(r["evals"]) == 0:
+            print(f"{r['name']:<20s} (no evals)")
+            continue
+        last = r["evals"][-1]
+        best_idx = int(np.argmin(r["evals"][:, 2]))
+        best = r["evals"][best_idx]
+        print(
+            f"{r['name']:<20s} {last[1]:>12.4f} {last[2]:>10.4f} "
+            f"{best[2]:>10.4f} {int(best[0]):>6d}"
+        )
+
+    print("\n" + ascii_plot(runs, key_idx=2, title="val loss vs iter"))
+    print("\n" + ascii_plot(runs, key_idx=1, title="eval train loss vs iter"))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ep_run/analyze_all.py b/ep_run/analyze_all.py
new file mode 100644
index 0000000..f309190
--- /dev/null
+++ b/ep_run/analyze_all.py
@@ -0,0 +1,86 @@
+"""Analyze all runs in runs_local/ — produces a summary table sorted by val_loss.
+
+For each run: final val_loss, per-projection grad_cos breakdown, STE flags used.
+"""
+import json
+from pathlib import Path
+
+def get_final(arm_dir):
+    p = arm_dir / "log.jsonl"
+    if not p.exists():
+        return None
+    lines = p.read_text().strip().split("\n")
+    evals = [json.loads(l) for l in lines if l and json.loads(l).get("event") == "eval"]
+    if not evals:
+        return None
+    return evals[-1]
+
+def group_cos(gc):
+    groups = {}
+    for name, val in gc.items():
+        if "q_proj" in name: key = "q"
+        elif "k_proj" in name: key = "k"
+        elif "v_proj" in name: key = "v"
+        elif "o_proj" in name: key = "o"
+        elif "mlp.fc" in name: key = "fc"
+        elif "mlp.proj" in name: key = "pr"
+        elif "head" in name: key = "hd"
+        else: key = "?"
+        groups.setdefault(key, []).append(val)
+    out = {}
+    for k, v in groups.items():
+        valid = [x for x in v if x == x]
+        if valid:
+            out[k] = sum(valid) / len(valid)
+    return out
+
+def main():
+    runs_dir = Path("runs_local")
+    rows = []
+    for d in sorted(runs_dir.iterdir()):
+        if not d.is_dir():
+            continue
+        ev = get_final(d)
+        if ev is None:
+            continue
+        cfg_path = d / "config.json"
+        cfg = json.loads(cfg_path.read_text()) if cfg_path.exists() else {}
+        gc = ev.get("grad_cos") or {}
+        groups = group_cos(gc)
+        valid = [v for v in gc.values() if v == v]
+        mean_cos = sum(valid) / len(valid) if valid else float("nan")
+        flags = []
+        if cfg.get("ste_sigmoid"): flags.append("σSTE")
+        if cfg.get("ste_gelu"): flags.append("gSTE")
+        if cfg.get("ste_ln"): flags.append("lSTE")
+        if cfg.get("freeze_emb"): flags.append("frzE")
+        method = cfg.get("method", "?")
+        attn = cfg.get("attn_mode", "?")
+        nl = cfg.get("n_layer", "?")
+        rows.append({
+            "name": d.name,
+            "method": method,
+            "attn": attn[:3],
+            "L": nl,
+            "flags": "+".join(flags) if flags else "-",
+            "val": ev.get("val_loss", float("nan")),
+            "μcos": mean_cos,
+            "groups": groups,
+        })
+
+    rows.sort(key=lambda r: r["val"] if r["val"] == r["val"] else 999)
+
+    hdr = f"{'name':24s} {'meth':4s} {'attn':3s} {'L':>2s} {'flags':20s} {'val':>8s} {'μcos':>6s} {'hd':>5s} {'o':>5s} {'v':>5s} {'q':>5s} {'k':>5s} {'fc':>5s} {'pr':>5s}"
+    print(hdr)
+    print("-" * len(hdr))
+    for r in rows:
+        g = r["groups"]
+        def fmt(k):
+            v = g.get(k, float("nan"))
+            return f"{v:>5.2f}" if v == v else " nan"
+        val_s = f"{r['val']:>8.4f}" if r["val"] == r["val"] else " nan"
+        cos_s = f"{r['μcos']:>6.3f}" if r["μcos"] == r["μcos"] else " nan"
+        print(f"{r['name']:24s} {r['method']:4s} {r['attn']:3s} {str(r['L']):>2s} {r['flags']:20s} {val_s} {cos_s} {fmt('hd')} {fmt('o')} {fmt('v')} {fmt('q')} {fmt('k')} {fmt('fc')} {fmt('pr')}")
+
+if __name__ == "__main__":
+    main()
diff --git a/ep_run/analyze_ln_jacobian.py b/ep_run/analyze_ln_jacobian.py
new file mode 100644
index 0000000..dcf2149
--- /dev/null
+++ b/ep_run/analyze_ln_jacobian.py
@@ -0,0 +1,166 @@
+"""Analyze LN Jacobian decomposition: how much does each component (scaling, mean-center,
+radial removal) contribute to the gradient at each LN layer?
+
+Trains a small FA model for 250 steps, then on one diagnostic batch:
+1. Forward with hooks to capture each LN's (x, z, sigma)
+2. Backward to get g_tilde = dL/dz (gradient wrt LN output)
+3. Decompose: true J_LN @ g_tilde vs center_scale vs projected vs identity(STE)
+4. Report per-layer cosines and energy fractions
+
+Run for both softmax and sigmoid to explain why center_scale costs more on softmax.
+"""
+import json
+import pickle
+from pathlib import Path
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from model_local import LocalGPT, LocalGPTConfig
+from local_layers import initialize_dfa_targets
+import numpy as np
+
+
+def get_batch(data_dir, block_size, batch_size, device):
+    data = np.memmap(data_dir / "train.bin", dtype=np.uint16, mode="r")
+    ix = torch.randint(len(data) - block_size - 1, (batch_size,))
+    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
+    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
+    return x.to(device), y.to(device)
+
+
+def analyze_one_config(attn_mode, device, data_dir):
+    """Train FA model for 250 steps, then analyze LN Jacobian on one batch."""
+    torch.manual_seed(1337)
+    with open(data_dir / "meta.pkl", "rb") as f:
+        meta = pickle.load(f)
+
+    cfg = LocalGPTConfig(
+        block_size=64, vocab_size=meta["vocab_size"],
+        n_layer=4, n_head=4, n_embd=128, dropout=0.0,
+        attn_mode=attn_mode, method="fa",
+    )
+    model = LocalGPT(cfg).to(device)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
+    model.train()
+    for step in range(250):
+        X, Y = get_batch(data_dir, cfg.block_size, 32, device)
+        _, loss = model(X, Y)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+    # Now diagnostic: hook into LN layers to capture forward quantities
+    ln_data = {}  # name -> {x, z, sigma, g_tilde}
+
+    def make_forward_hook(name):
+        def hook(module, input, output):
+            x = input[0].detach()
+            mu = x.mean(dim=-1, keepdim=True)
+            xc = x - mu
+            var = (xc * xc).mean(dim=-1, keepdim=True)
+            sigma = torch.sqrt(var + 1e-5)
+            z = xc / sigma
+            ln_data[name] = {"x": x, "z": z, "sigma": sigma}
+            output.retain_grad()
+            ln_data[name]["output_ref"] = output
+        return hook
+
+    hooks = []
+    for name, module in model.named_modules():
+        if isinstance(module, nn.LayerNorm):
+            hooks.append(module.register_forward_hook(make_forward_hook(name)))
+
+    # Forward + backward on diagnostic batch
+    model.eval()
+    X, Y = get_batch(data_dir, cfg.block_size, 32, device)
+    logits, loss = model(X, Y)
+    loss.backward()
+
+    # Collect g_tilde for each LN
+    for name in ln_data:
+        out_ref = ln_data[name]["output_ref"]
+        if out_ref.grad is not None:
+            ln_data[name]["g_tilde"] = out_ref.grad.detach()
+
+    for h in hooks:
+        h.remove()
+
+    # Analyze decomposition
+    results = {}
+    for name, d in ln_data.items():
+        if "g_tilde" not in d:
+            continue
+        g = d["g_tilde"]  # (B, T, dim)
+        z = d["z"]
+        sigma = d["sigma"]
+        dim = g.shape[-1]
+
+        # True LN Jacobian action: g_x = (1/sigma) * (g - mean(g) - z*mean(g*z))
+        g_mean = g.mean(dim=-1, keepdim=True)
+        gz_mean = (g * z).mean(dim=-1, keepdim=True)
+
+        g_true = (g - g_mean - z * gz_mean) / sigma  # full LN backward
+        g_center = (g - g_mean) / sigma  # center_scale only
+        g_ste = g  # identity STE
+
+        # Energy fractions: what fraction of ||g||^2 is in each removed component?
+        g_norm_sq = (g * g).sum(-1).mean()
+        mean_component = g_mean.expand_as(g)
+        radial_component = z * gz_mean
+        r_mean = (mean_component * mean_component).sum(-1).mean() / (g_norm_sq + 1e-12)
+        r_radial = (radial_component * radial_component).sum(-1).mean() / (g_norm_sq + 1e-12)
+
+        # Cosines: how well does each surrogate match the true LN backward?
+        def batch_cos(a, b):
+            a_flat = a.reshape(-1, dim)
+            b_flat = b.reshape(-1, dim)
+            cos = F.cosine_similarity(a_flat, b_flat, dim=-1)
+            return cos.mean().item()
+
+        cos_center = batch_cos(g_center, g_true)
+        cos_ste = batch_cos(g_ste, g_true)
+        cos_center_to_ste = batch_cos(g_center, g_ste)
+
+        # Sigma statistics
+        sigma_mean = sigma.mean().item()
+        sigma_std = sigma.std().item()
+
+        results[name] = {
+            "r_mean": r_mean.item(),
+            "r_radial": r_radial.item(),
+            "sigma_mean": sigma_mean,
+            "sigma_std": sigma_std,
+            "cos_center_vs_true": cos_center,
+            "cos_ste_vs_true": cos_ste,
+        }
+
+    return results
+
+
+def main():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    data_dir = Path("data/shakespeare_char")
+
+    for attn in ["softmax", "sigmoid"]:
+        print(f"\n{'='*60}")
+        print(f" Attention: {attn}")
+        print(f"{'='*60}")
+        results = analyze_one_config(attn, device, data_dir)
+        print(f"{'name':30s} {'r_mean':>8s} {'r_rad':>8s} {'σ_μ':>8s} {'cos_c/t':>8s} {'cos_s/t':>8s}")
+        print("-" * 80)
+        for name, r in sorted(results.items()):
+            print(f"{name:30s} {r['r_mean']:8.4f} {r['r_radial']:8.4f} "
+                  f"{r['sigma_mean']:8.3f} {r['cos_center_vs_true']:8.4f} {r['cos_ste_vs_true']:8.4f}")
+
+        # Summary
+        r_means = [r["r_mean"] for r in results.values()]
+        r_rads = [r["r_radial"] for r in results.values()]
+        cos_cs = [r["cos_center_vs_true"] for r in results.values()]
+        cos_ss = [r["cos_ste_vs_true"] for r in results.values()]
+        print(f"\n AVG r_mean={sum(r_means)/len(r_means):.4f} r_radial={sum(r_rads)/len(r_rads):.4f} "
+              f"cos_center={sum(cos_cs)/len(cos_cs):.4f} cos_ste={sum(cos_ss)/len(cos_ss):.4f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ep_run/analyze_softmax_jacobian.py b/ep_run/analyze_softmax_jacobian.py
new file mode 100644
index 0000000..91ebd70
--- /dev/null
+++ b/ep_run/analyze_softmax_jacobian.py
@@ -0,0 +1,168 @@
+"""Analyze softmax attention Jacobian: decompose into diagonal (local) vs off-diagonal (lateral).
+ +The softmax Jacobian J = diag(A) - AA^T acts on gradient g as: + g_S = A ⊙ g - A * (A^T g) (full, has lateral sum) + g_S_diag = A ⊙ (1-A) ⊙ g (diagonal-only, element-wise, same formula as sigmoid) + g_S_ste = g (identity STE) + +This script measures: +1. How much energy is in diagonal vs off-diagonal components +2. Cosine between full vs diagonal-only vs STE on real FA training data +3. Per-head, per-layer breakdown +4. Whether removing the lateral sum is catastrophic or tolerable +""" +import pickle +from pathlib import Path +import torch +import torch.nn as nn +import torch.nn.functional as F +from model_local import LocalGPT, LocalGPTConfig +import numpy as np + + +def get_batch(data_dir, block_size, batch_size, device): + data = np.memmap(data_dir / "train.bin", dtype=np.uint16, mode="r") + ix = torch.randint(len(data) - block_size - 1, (batch_size,)) + x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix]) + return x.to(device), y.to(device) + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + data_dir = Path("data/shakespeare_char") + torch.manual_seed(1337) + with open(data_dir / "meta.pkl", "rb") as f: + meta = pickle.load(f) + + # Train a softmax FA model for 500 steps to get meaningful attention patterns + cfg = LocalGPTConfig( + block_size=64, vocab_size=meta["vocab_size"], + n_layer=4, n_head=4, n_embd=128, dropout=0.0, + attn_mode="softmax", method="fa", + ) + model = LocalGPT(cfg).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + model.train() + for step in range(500): + X, Y = get_batch(data_dir, cfg.block_size, 32, device) + _, loss = model(X, Y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + print(f"Trained 500 steps, final loss: {loss.item():.3f}") + + # Hook into attention forward to capture scores and attention weights + attn_data = {} + + def 
make_attn_hook(name, module): + original_forward = module.forward + + def hooked_forward(x): + B, T, C = x.shape + q = module.q_proj(x).view(B, T, module.n_head, module.head_dim).transpose(1, 2) + k = module.k_proj(x).view(B, T, module.n_head, module.head_dim).transpose(1, 2) + v = module.v_proj(x).view(B, T, module.n_head, module.head_dim).transpose(1, 2) + + scores = (q @ k.transpose(-2, -1)) * (module.head_dim ** -0.5) + mask = module.causal_mask[:T, :T] + scores = scores.masked_fill(~mask, float("-inf")) + + attn = F.softmax(scores, dim=-1) + attn_data[name] = { + "scores": scores.detach(), + "attn": attn.detach(), + } + + # Need dL/dA (grad wrt attention weights); detach first so attn_for_grad is a leaf and .grad gets populated + attn_for_grad = attn.detach().clone().requires_grad_(True) + out = (attn_for_grad @ v).transpose(1, 2).contiguous().view(B, T, C) + out = module.resid_drop(module.o_proj(out)) + + attn_data[name]["attn_for_grad"] = attn_for_grad + return out + + module.forward = hooked_forward + return module + + # Install hooks + for name, module in model.named_modules(): + if hasattr(module, "q_proj") and hasattr(module, "k_proj"): + make_attn_hook(name, module) + + # Forward + backward on diagnostic batch + model.eval() + X, Y = get_batch(data_dir, cfg.block_size, 32, device) + logits, loss = model(X, Y) + loss.backward() + + # Analyze each attention layer + print(f"\n{'layer':30s} {'A_mean':>8s} {'A_entropy':>10s} {'r_diag':>8s} {'r_offdiag':>10s} " + f"{'cos_diag':>9s} {'cos_ste':>8s}") + print("-" * 100) + + for name, d in sorted(attn_data.items()): + A = d["attn"] # (B, n_head, T, T) + attn_ref = d.get("attn_for_grad") + + if attn_ref is None or attn_ref.grad is None: + print(f"{name:30s} (no grad captured)") + continue + + g = attn_ref.grad.detach() # (B, n_head, T, T) = dL/dA + B_size, n_head, T, _ = A.shape + + # Per-head analysis + for h in range(n_head): + A_h = A[:, h, :, :] # (B, T, T) + g_h = g[:, h, :, :] # (B, T, T) + + # Full softmax backward: g_S = A * (g - A @ g sum along last dim) + Ag_sum = 
(A_h * g_h).sum(dim=-1, keepdim=True) # (B, T, 1) = sum_j A_j g_j per query + g_full = A_h * (g_h - Ag_sum) # (B, T, T) + + # Diagonal-only (element-wise, sigmoid-like): g_diag = A*(1-A)*g + g_diag = A_h * (1 - A_h) * g_h # (B, T, T) + + # STE: g_ste = g + g_ste = g_h + + # Energy fractions + g_full_norm = (g_full * g_full).sum((-2, -1)).mean() + g_diag_norm = (g_diag * g_diag).sum((-2, -1)).mean() + diff_norm = ((g_full - g_diag) * (g_full - g_diag)).sum((-2, -1)).mean() + + # Cosines (flatten per-sample) + def cos(a, b): + af = a.reshape(B_size, -1) + bf = b.reshape(B_size, -1) + return F.cosine_similarity(af, bf, dim=-1).mean().item() + + cos_diag = cos(g_diag, g_full) + cos_ste = cos(g_ste, g_full) + + # Attention statistics + # Mask out -inf positions for stats + valid_mask = A_h > 0 + A_valid = A_h[valid_mask] + A_mean = A_valid.mean().item() + + # Entropy per query row + entropy = -(A_h * (A_h + 1e-10).log()).sum(-1).mean().item() + + r_diag = g_diag_norm / (g_full_norm + 1e-12) + + print(f"{name}.head{h:d} " + f" {A_mean:8.4f} {entropy:10.3f} {r_diag.item():8.3f} " + f"{(diff_norm / (g_full_norm + 1e-12)).item():10.3f} {cos_diag:9.4f} {cos_ste:8.4f}") + + # Summary + print(f"\nKey: r_diag = ||g_diag||^2 / ||g_full||^2, r_offdiag = ||g_full - g_diag||^2 / ||g_full||^2 (energy in diagonal vs lateral part)") + print(f" cos_diag = cosine(diagonal-only, full softmax backward)") + print(f" cos_ste = cosine(identity STE, full softmax backward)") + print(f"\nIf cos_diag ≈ 1: diagonal-only (sigmoid-like) approximation is good → lateral sum not needed") + print(f"If cos_diag << 1: off-diagonal (lateral sum) is critical → need to keep or find local surrogate") + + +if __name__ == "__main__": + main() diff --git a/ep_run/anderson_control.py b/ep_run/anderson_control.py new file mode 100644 index 0000000..10147c3 --- /dev/null +++ b/ep_run/anderson_control.py @@ -0,0 +1,67 @@ +import torch, pickle, math, numpy as np +from pathlib import Path +from scipy.sparse.linalg import LinearOperator, eigs +import lt_ep_train as L +from lt_ep_train 
import EQBlock +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; B=2; T=256; eps=0.1 +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +def load_blk(path): + ck=torch.load(path,map_location=dev) + blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True + with torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + return blk, ck.get('best','?') +def gmap(blk,xin,z): + with torch.no_grad(): return z+eps*blk.force(z,xin).detach() +def plain(blk,xin,z0,steps=300): + z=z0.clone() + for _ in range(steps): z=gmap(blk,xin,z) + return ((gmap(blk,xin,z)-z).norm()/(z.norm()+1e-9)).item() +def anderson(blk,xin,z0,m=6,max_iter=400,tol=1e-7,lam=1e-4): + Bs,d=z0.shape[0],z0[0].numel() + X=torch.zeros(Bs,m,d,device=dev); Fb=torch.zeros(Bs,m,d,device=dev) + X[:,0]=z0.reshape(Bs,d); Fb[:,0]=gmap(blk,xin,z0).reshape(Bs,d) + X[:,1]=Fb[:,0]; Fb[:,1]=gmap(blk,xin,X[:,1].view_as(z0)).reshape(Bs,d) + Hm=torch.zeros(Bs,m+1,m+1,device=dev); Hm[:,0,1:]=1; Hm[:,1:,0]=1 + yv=torch.zeros(Bs,m+1,1,device=dev); yv[:,0]=1 + r=1.0; best_r=9.0; z_best=z0.clone() + for k in range(2,max_iter): + n=min(k,m); Gm=Fb[:,:n]-X[:,:n] + Hm[:,1:n+1,1:n+1]=torch.bmm(Gm,Gm.transpose(1,2))+lam*torch.eye(n,device=dev)[None] + alpha=torch.linalg.solve(Hm[:,:n+1,:n+1],yv[:,:n+1])[:,1:n+1,0] + X[:,k%m]=torch.bmm(alpha[:,None],Fb[:,:n])[:,0] + Fb[:,k%m]=gmap(blk,xin,X[:,k%m].view_as(z0)).reshape(Bs,d) + r=((Fb[:,k%m]-X[:,k%m]).norm()/(Fb[:,k%m].norm()+1e-9)).item() + if r<best_r: best_r=r; z_best=X[:,k%m].view_as(z0).clone() + if r<tol or not math.isfinite(r): break + return best_r,k+1,z_best +def eig_at(blk,xin,z,k=6,hrel=1e-3): + z=z.detach(); F0=blk.force(z,xin).detach(); zn=z.norm().item(); h=hrel*zn; shp=z.shape; N=z.numel() + def Jv(vt): + nv=vt.norm().item() + if nv<1e-20: return torch.zeros_like(vt) + return (blk.force(z+h*(vt/nv),xin).detach()-F0)/h*nv + def 
matvec(v): + vt=torch.from_numpy(np.ascontiguousarray(v).astype('float32')).reshape(shp).to(dev) + return (vt+eps*Jv(vt)).double().cpu().numpy().reshape(-1) + op=LinearOperator((N,N),matvec=matvec,dtype='float64') + vals=eigs(op,k=k,which='LM',return_eigenvectors=False,maxiter=4000,tol=1e-5) + return F0.norm().item()/(zn+1e-9), sorted(vals,key=lambda x:-abs(x)) +for tag,path in [("s2000 (healthy)","runs/redx_traj/s2000.pt"), + ("s3200 (blew@2.74)","runs/redx_traj/s3200.pt")]: + blk,best=load_blk(path); xin=blk.embed(idx).detach() + with torch.no_grad(): + pr=plain(blk,xin,xin.clone(),300) + ar,ak,zst=anderson(blk,xin,xin.clone(),max_iter=400) + print(f"=== {tag} best={best} ===") + print(f" plain relax(300) res={pr:.3e} | Anderson best_res={ar:.3e} in {ak} iters") + if ar<2e-3: + g,vals=eig_at(blk,xin,zst) + print(f" -> Anderson FOUND a root (force g={g:.4f}); eigenvalues of M=I+eps*J at the root:") + for lam in vals[:4]: + mu=(lam-1)/eps + print(f" |lam|={abs(lam):.5f} mu={mu.real:+.4f}{mu.imag:+.4f}j [{'UNSTABLE' if abs(lam)>1+1e-4 else 'STABLE'}{' rot' if abs(lam.imag)>1e-3 else ' real'}]") + else: + print(f" -> Anderson did NOT converge to a root (best_res={ar:.3e}) => no reachable fixed point") +print("=== DONE === key: s3200 root + ReMu<0 = Euler-artifact(integration fixes it); root + ReMu>0 = unstable fixed pt(true instab); no root = true instab") diff --git a/ep_run/asym_probe.py b/ep_run/asym_probe.py new file mode 100644 index 0000000..1b61354 --- /dev/null +++ b/ep_run/asym_probe.py @@ -0,0 +1,922 @@ +"""Matrix-free asymmetry probe for the equilibrium-transformer block Jacobian. + +The state Jacobian J = dF/dz is never materialized. We estimate the growth of +T = (S + mu I)^-1 A, where S=(J+J^T)/2 and A=(J-J^T)/2, using autograd JVP/VJP +products at the relaxed fixed point. 
+""" +import argparse +import glob +import math +import os +import pickle +import time +import warnings + +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", message=".*cuBLAS.*", category=UserWarning) +warnings.filterwarnings("ignore", message=".*CUBLAS.*", category=UserWarning) + +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.sparse.linalg import LinearOperator, gmres, minres + +import lt_ep_train as L +from lt_ep_train import EQBlock, bptt_step, ce, ep_step, relax + + +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("--ckpt", default="runs/ep_clean.pt") + ap.add_argument("--data", default="data/tinystories_bpe") + ap.add_argument("--gelu", default="erf") + ap.add_argument("--C", type=int, default=512) + ap.add_argument("--H", type=int, default=16) + ap.add_argument("--Mm", type=int, default=256) + ap.add_argument("--T", type=int, default=256) + ap.add_argument("--B", type=int, default=8) + ap.add_argument("--T1", type=int, default=150) + ap.add_argument("--T2", type=int, default=20) + ap.add_argument("--eps", type=float, default=0.1) + ap.add_argument("--beta", type=float, default=0.02) + ap.add_argument("--t1max", type=int, default=2000) + ap.add_argument("--relax-chunk", type=int, default=50) + ap.add_argument("--res-est", type=float, default=1e-4) + ap.add_argument("--t2sel", type=int, default=40) + ap.add_argument("--holo", type=int, default=2) + ap.add_argument("--hr", type=float, default=0.02) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--trace-probes", type=int, default=4) + ap.add_argument("--mu-scale", type=float, default=1e-3) + ap.add_argument("--mu", type=float, default=-1.0, help="override mu; negative means estimate from trace") + ap.add_argument("--solve-iters", type=int, default=80) + ap.add_argument("--solve-tol", type=float, default=1e-5) + 
ap.add_argument("--adjoint-iters", type=int, default=200) + ap.add_argument("--adjoint-tol", type=float, default=1e-5) + ap.add_argument("--adjoint-mu", type=float, default=1e-4, help="fallback Tikhonov mu for J^T+muI if GMRES stalls") + ap.add_argument("--rho-iters", type=int, default=20) + ap.add_argument("--rho-restarts", type=int, default=3) + ap.add_argument("--sigma-iters", type=int, default=8) + ap.add_argument("--sigma-restarts", type=int, default=2) + ap.add_argument("--arnoldi-k", type=int, default=12) + ap.add_argument("--skiprho", action=argparse.BooleanOptionalAction, default=True, + help="skip rho/sigma spectral probes and run only exact-adjoint gradient comparison") + ap.add_argument("--diag", action="store_true", help="run EP/exact-adjoint diagnostic suite and exit") + ap.add_argument("--noplot", action="store_true", help=argparse.SUPPRESS) + ap.add_argument("--lr", type=float, default=None, help=argparse.SUPPRESS) + ap.add_argument("--device", default="cuda", choices=["cuda", "cpu"]) + ap.add_argument("--tf32", action="store_true") + return ap.parse_args() + + +def resolve_ckpt_path(path): + p = Path(path) + if p.is_absolute(): + return str(p) + cwd_path = Path.cwd() / p + if cwd_path.exists(): + return str(cwd_path) + return str(Path(__file__).resolve().parent / p) + + +def require_cuda_if_requested(device): + if device != "cuda": + return + visible = os.environ.get("CUDA_VISIBLE_DEVICES") + ok = torch.cuda.is_available() and torch.cuda.device_count() > 0 + if ok: + torch.cuda.set_device(0) + return + print("ERROR: CUDA unavailable; requested GPU0 run cannot start.", flush=True) + print(f"CUDA_VISIBLE_DEVICES={visible!r}", flush=True) + print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda}", flush=True) + print(f"torch.cuda.is_available()={torch.cuda.is_available()} device_count={torch.cuda.device_count()}", flush=True) + nodes = glob.glob("/dev/nvidia*") + required = ["/dev/nvidiactl", "/dev/nvidia-uvm", "/dev/nvidia0"] + 
missing = [p for p in required if not os.path.exists(p)] + print(f"/dev/nvidia*={' '.join(nodes) if nodes else 'MISSING'}", flush=True) + print(f"missing CUDA device nodes={' '.join(missing) if missing else 'none'}", flush=True) + raise SystemExit(2) + + +def build_block(cfg, dev): + # Same construction and checkpoint-copy path as resreg_probe.py. + L.dev = dev + L.DD = Path(cfg.data) + L.vocab = pickle.load(open(L.DD / "meta.pkl", "rb"))["vocab_size"] + torch.manual_seed(cfg.seed) + blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=1.0, attn_mode="thick") + blk.qknorm = True + blk.fnoise = 0.0 + blk._cstep = None + blk.navg = 1 + blk.li_avg = 0 + blk.track = True + blk.nbrake = 0.0 + blk.gelu = cfg.gelu + ck = torch.load(cfg.ckpt, map_location=dev) + with torch.no_grad(): + for p, w in zip(blk.allp, ck["allp"]): + p.copy_(w.to(dev)) + return blk, ck + + +@torch.no_grad() +def residuals(blk, z, xin, eps): + z1 = relax(blk, z, xin, 1, eps) + zn = z.norm().item() + 1e-12 + step_rel = (z1 - z).norm().item() / zn + force_rel = blk.tforce(z, xin).norm().item() / zn + return step_rel, force_rel + + +def relax_to_fixed_point(blk, xin, cfg): + z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps) + step_rel, force_rel = residuals(blk, z, xin, cfg.eps) + steps = cfg.T1 + while steps < cfg.t1max and step_rel > cfg.res_est: + chunk = min(cfg.relax_chunk, cfg.t1max - steps) + z = relax(blk, z, xin, chunk, cfg.eps) + steps += chunk + step_rel, force_rel = residuals(blk, z, xin, cfg.eps) + print(f"relax steps={steps:4d} step_res={step_rel:.3e} force_res={force_rel:.3e}", flush=True) + return z.detach(), steps, step_rel, force_rel + + +def dot(a, b): + return torch.dot(a.reshape(-1), b.reshape(-1)) + + +def norm(a): + return torch.linalg.vector_norm(a.reshape(-1)) + + +def block_param_list(blk): + if hasattr(blk.block, "parameters"): + return list(blk.block.parameters()) + return list(blk.block) + + +def flat_grad_by_param_id(grads, params): + flat = [] + for p in params: + g = 
grads.get(id(p)) if grads is not None else None + if g is None: + g = torch.zeros_like(p, device="cpu", dtype=torch.float64) + else: + g = g.detach().to(device="cpu", dtype=torch.float64) + flat.append(g.reshape(-1)) + return torch.cat(flat) + + +def set_param_requires_grad(blk, value): + for p in blk.allp: + p.requires_grad_(value) + + +def cos(a, b): + return (torch.dot(a, b) / (norm(a) * norm(b) + 1e-20)).item() + + +def rel_diff(a, b): + return (norm(a - b) / (norm(b) + 1e-20)).item() + + +def unit_rand(shape, dev, dtype): + v = torch.randn(shape, device=dev, dtype=dtype) + return v / (norm(v) + 1e-30) + + +@dataclass +class SolveLog: + residuals: list = field(default_factory=list) + infos: list = field(default_factory=list) + iters: list = field(default_factory=list) + + def add(self, rel_res, info, nit): + self.residuals.append(float(rel_res)) + self.infos.append(int(info)) + self.iters.append(int(nit)) + + def summary(self): + if not self.residuals: + return "solve residuals: none" + r = np.asarray(self.residuals, dtype=np.float64) + it = np.asarray(self.iters, dtype=np.int64) + bad = sum(1 for x in self.infos if x != 0) + return ( + f"solve residuals: count={len(r)} min={r.min():.3e} " + f"median={np.median(r):.3e} max={r.max():.3e} " + f"iters median={np.median(it):.0f} max={it.max()} nonzero_info={bad}" + ) + + +class Operators: + def __init__(self, blk, zstar, xin, cfg, mu): + self.blk = blk + self.zstar = zstar.detach() + self.xin = xin.detach() + self.shape = tuple(zstar.shape) + self.n = zstar.numel() + self.dev = zstar.device + self.dtype = zstar.dtype + self.cfg = cfg + self.mu = float(mu) + self.solve_log = SolveLog() + + def f(self, z): + return self.blk.tforce(z, self.xin) + + def jv(self, v): + with torch.enable_grad(): + _, out = torch.autograd.functional.jvp( + self.f, self.zstar, v.contiguous(), create_graph=False, strict=False + ) + return out.detach() + + def jtv(self, v): + with torch.enable_grad(): + z = 
self.zstar.detach().requires_grad_(True) + fz = self.f(z) + (g,) = torch.autograd.grad(fz, z, grad_outputs=v.contiguous(), create_graph=False, retain_graph=False) + return g.detach() + + def sv(self, v): + jv = self.jv(v) + jtv = self.jtv(v) + return 0.5 * (jv + jtv) + + def av(self, v): + jv = self.jv(v) + jtv = self.jtv(v) + return 0.5 * (jv - jtv) + + def smu(self, v, mu=None): + m = self.mu if mu is None else float(mu) + return self.sv(v) + m * v + + def _from_numpy(self, x): + x = np.asarray(x, dtype=np.float32) + return torch.from_numpy(x).to(device=self.dev, dtype=self.dtype).view(self.shape) + + def _to_numpy(self, x): + return x.detach().reshape(-1).float().cpu().numpy() + + def solve_s(self, rhs, mu=None, tag=""): + m = self.mu if mu is None else float(mu) + b = self._to_numpy(rhs) + counter = {"n": 0} + + def matvec(x_np): + x = self._from_numpy(x_np) + y = self.smu(x, m) + return self._to_numpy(y) + + def cb(_x): + counter["n"] += 1 + + Aop = LinearOperator((self.n, self.n), matvec=matvec, dtype=np.dtype("float32")) + x_np, info = minres(Aop, b, rtol=self.cfg.solve_tol, maxiter=self.cfg.solve_iters, callback=cb, check=False) + x = self._from_numpy(x_np).detach() + rel = (norm(self.smu(x, m) - rhs) / (norm(rhs) + 1e-30)).item() + self.solve_log.add(rel, info, counter["n"]) + if tag: + print(f"solve {tag}: mu={m:.3e} rel_res={rel:.3e} iters={counter['n']} info={info}", flush=True) + return x, rel, info, counter["n"] + + def solve_jt_gmres(self, rhs, tol, maxiter, mu=0.0, tag="adjoint"): + m = float(mu) + b = self._to_numpy(rhs) + counter = {"n": 0} + restart = max(1, min(50, int(maxiter))) + + def matvec(x_np): + x = self._from_numpy(x_np) + y = self.jtv(x) + if m != 0.0: + y = y + m * x + return self._to_numpy(y) + + def cb(_arg): + counter["n"] += 1 + + Aop = LinearOperator((self.n, self.n), matvec=matvec, dtype=np.dtype("float32")) + try: + x_np, info = gmres( + Aop, + b, + rtol=tol, + atol=0.0, + restart=restart, + maxiter=int(maxiter), + callback=cb, 
+ callback_type="legacy", + ) + except TypeError: + x_np, info = gmres(Aop, b, tol=tol, restart=restart, maxiter=int(maxiter), callback=cb) + x = self._from_numpy(x_np).detach() + rel = (norm(self.jtv(x) + m * x - rhs) / (norm(rhs) + 1e-30)).item() + print(f"GMRES {tag}: mu={m:.3e} rel_res={rel:.3e} iters={counter['n']} info={info}", flush=True) + return x, rel, info, counter["n"] + + def t(self, v): + rhs = self.av(v) + x, _, _, _ = self.solve_s(rhs) + return x + + def tt(self, u): + y, _, _, _ = self.solve_s(u) + return -self.av(y) + + +def estimate_trace_s(op, probes): + vals = [] + for i in range(probes): + r = torch.randint(0, 2, op.shape, device=op.dev, dtype=torch.int8).to(op.dtype) + r = r.mul_(2).sub_(1) + sr = op.sv(r) + vals.append((dot(r, sr) / op.n).item()) + print(f"trace probe {i}: tr(S)/n={vals[-1]:+.6e}", flush=True) + return float(np.mean(vals)), float(np.std(vals) if len(vals) > 1 else 0.0) + + +def sensitivity_probe(op, mu): + v = unit_rand(op.shape, op.dev, op.dtype) + rhs = op.av(v) + xb, rb, _, _ = op.solve_s(rhs, mu=mu, tag="sensitivity/base") + rows = [] + for scale in (0.1, 10.0): + ms = max(mu * scale, 0.0) + xa, ra, _, _ = op.solve_s(rhs, mu=ms, tag=f"sensitivity/mu_x{scale:g}") + rel_dx = (norm(xa - xb) / (norm(xb) + 1e-30)).item() + rows.append((scale, ms, rel_dx, ra)) + print( + "solve sensitivity: " + + " ".join(f"mu_x{scale:g}: rel_dx={dx:.3e} rel_res={rr:.3e}" for scale, _, dx, rr in rows), + flush=True, + ) + return rb, rows + + +def power_rho(op, cfg): + best = 0.0 + best_hist = None + for r in range(cfg.rho_restarts): + v = unit_rand(op.shape, op.dev, op.dtype) + hist = [] + for i in range(cfg.rho_iters): + w = op.t(v) + growth = norm(w).item() + rq = (dot(v, w) / (dot(v, v) + 1e-30)).item() + hist.append((growth, rq)) + if growth <= 1e-30 or not math.isfinite(growth): + break + v = (w / growth).detach() + print(f"rho restart={r} iter={i + 1:02d} growth={growth:.6e} rayleigh={rq:+.6e}", flush=True) + if hist and hist[-1][0] > 
best: + best = hist[-1][0] + best_hist = hist + if best_hist: + trend = " ".join(f"{g:.3g}" for g, _ in best_hist[-min(6, len(best_hist)):]) + rtrend = " ".join(f"{rq:+.3g}" for _, rq in best_hist[-min(6, len(best_hist)):]) + print(f"rho power trend last={trend}", flush=True) + print(f"rho Rayleigh trend last={rtrend}", flush=True) + return best + + +def arnoldi_rho(op, k): + if k <= 0: + return None + q = unit_rand(op.shape, op.dev, op.dtype) + Q = [q] + H = np.zeros((k + 1, k), dtype=np.float64) + m = 0 + for j in range(k): + w = op.t(Q[j]) + for i in range(j + 1): + hij = dot(Q[i], w).item() + H[i, j] = hij + w = w - hij * Q[i] + hnext = norm(w).item() + H[j + 1, j] = hnext + m = j + 1 + print(f"arnoldi iter={j + 1:02d} h_next={hnext:.6e}", flush=True) + if hnext < 1e-12: + break + if j + 1 < k: + Q.append((w / hnext).detach()) + eig = np.linalg.eigvals(H[:m, :m]) + rho = float(np.max(np.abs(eig))) if eig.size else float("nan") + print(f"rho Arnoldi(k={m})={rho:.6e}", flush=True) + return rho + + +def power_sigma(op, cfg): + best = 0.0 + for r in range(cfg.sigma_restarts): + v = unit_rand(op.shape, op.dev, op.dtype) + sigma = 0.0 + for i in range(cfg.sigma_iters): + u = op.t(v) + sigma = norm(u).item() + w = op.tt(u) + wn = norm(w).item() + if wn <= 1e-30 or not math.isfinite(wn): + break + v = (w / wn).detach() + print(f"sigma restart={r} iter={i + 1:02d} sigma={sigma:.6e}", flush=True) + best = max(best, sigma) + return best + + +def ce_state_grad(blk, zstar, y): + with torch.enable_grad(): + z = zstar.detach().requires_grad_(True) + loss = ce(blk, z, y) + (ell,) = torch.autograd.grad(loss, z) + return ell.detach(), float(loss.detach()) + + +def solve_exact_adjoint(op, ell, cfg): + rhs = -ell.detach() + lam, rel, info, nit = op.solve_jt_gmres(rhs, cfg.adjoint_tol, cfg.adjoint_iters, mu=0.0, tag="J^T lambda=-ell") + stalled = (info != 0) or (not math.isfinite(rel)) or (rel > max(10.0 * cfg.adjoint_tol, 1e-4)) + mu_used = 0.0 + if stalled: + mu_used = 
max(float(cfg.adjoint_mu), 1e-8) + print(f"GMRES stalled; retrying exact-adjoint solve with Tikhonov J^T+muI, mu={mu_used:.3e}", flush=True) + lam, rel, info, nit = op.solve_jt_gmres( + rhs, cfg.adjoint_tol, cfg.adjoint_iters, mu=mu_used, tag="(J^T+muI) lambda=-ell" + ) + return lam.detach(), rel, info, nit, mu_used + + +def exact_transpose_grad(blk, idx, zstar, xin0, lam, params): + for p in blk.allp: + p.requires_grad_(True) + with torch.enable_grad(): + # Value stays at the relaxed clamp xin0, while tok/pos receive the same clamp-gradient path as the trainer. + xin = xin0 + (blk.embed(idx) - blk.embed(idx).detach()) + force = blk.tforce(zstar.detach(), xin) + grads = torch.autograd.grad((force * lam.detach()).sum(), params, allow_unused=True) + return {id(p): g for p, g in zip(params, grads)} + + +def run_ep_step_flat(blk, idx, y, cfg, params, *, beta=None, holo=None, hr=None, t2sel=None, track=None, T2=None): + saved_track = getattr(blk, "track", None) + if track is not None: + blk.track = bool(track) + try: + set_param_requires_grad(blk, True) + # Mirrors lt_ep_train.ep_step: + # (blk, idx, y, T1, T2, eps, beta, jacreg, holo, hr, t1max, res_est, t2sel, corr_every, res_gate, resreg). 
+ grads, ep_res = ep_step( + blk, + idx, + y, + cfg.T1, + cfg.T2 if T2 is None else int(T2), + cfg.eps, + cfg.beta if beta is None else float(beta), + 0.0, + cfg.holo if holo is None else int(holo), + cfg.hr if hr is None else float(hr), + cfg.t1max, + cfg.res_est, + cfg.t2sel if t2sel is None else int(t2sel), + 1, + 0.0, + 0.0, + ) + return flat_grad_by_param_id(grads, params), float(ep_res) + finally: + if track is not None and saved_track is not None: + blk.track = saved_track + + +@torch.no_grad() +def fixed_point_step_abs(blk, zstar, xin, eps): + return (relax(blk, zstar, xin, 1, eps) - zstar).norm().item() + + +def exact_reference_for_batch(blk, idx, y, cfg, label, compute_bptt=True): + print(f"--- exact reference: {label} ---", flush=True) + xin0 = blk.embed(idx).detach() + zstar, steps, step_res, force_res = relax_to_fixed_point(blk, xin0, cfg) + step_abs = fixed_point_step_abs(blk, zstar, xin0, cfg.eps) + print( + f"{label}: z* residual step_abs={step_abs:.6e} step_rel={step_res:.6e} " + f"force_rel={force_res:.6e} relax_steps={steps}", + flush=True, + ) + if step_res > cfg.res_est: + print(f"{label}: WARNING step_res={step_res:.3e} > res_est={cfg.res_est:.3e}", flush=True) + + set_param_requires_grad(blk, False) + op = Operators(blk, zstar, xin0, cfg, mu=0.0) + ell, ce_loss = ce_state_grad(blk, zstar, y) + print(f"{label}: CE(z*)={ce_loss:.6f} ||ell||={norm(ell).item():.6e}", flush=True) + lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg) + print( + f"{label}: adjoint residual={gmres_rel:.3e} iters={gmres_iters} info={gmres_info} " + f"tikhonov_mu={adj_mu:.3e}", + flush=True, + ) + + params = block_param_list(blk) + gt = flat_grad_by_param_id(exact_transpose_grad(blk, idx, zstar, xin0, lam, params), params) + out = { + "idx": idx, + "y": y, + "params": params, + "gt": gt, + "z_step_abs": step_abs, + "z_step_rel": step_res, + "z_force_rel": force_res, + "relax_steps": steps, + "gmres_rel": gmres_rel, + "gmres_info": 
gmres_info, + "gmres_iters": gmres_iters, + "adj_mu": adj_mu, + "ce_loss": ce_loss, + } + if compute_bptt: + set_param_requires_grad(blk, True) + gB = bptt_step(blk, idx, y, cfg.T1, cfg.eps, 0.0) + out["gBv"] = flat_grad_by_param_id(gB, params) + set_param_requires_grad(blk, True) + return out + + +def draw_seeded_train_batch(cfg, seed): + torch.manual_seed(int(seed)) + return L.get_batch("train", cfg.B, cfg.T) + + +def finite_range(vals): + ok = [float(v) for v in vals if v is not None and math.isfinite(float(v))] + if not ok: + return None + return float(np.mean(ok)), float(np.min(ok)), float(np.max(ok)) + + +def read_multi_batch(rows): + vals = [r.get("cos_ep_t") for r in rows if r.get("ok")] + stats = finite_range(vals) + if stats is None: + return "no successful batches" + mean, mn, mx = stats + spread = mx - mn + if mn > 0.95: + return "consistently aligned across batches" + if mx < 0.80: + return "systematically low across batches" + if spread > 0.20: + return "batch-variance/outlier behavior is material" + return "mostly systematic with moderate batch variance" + + +def read_beta_sweep(rows): + ok = [(r["beta"], r["cos"]) for r in rows if r.get("ok")] + if not ok: + return "no successful beta points" + first_beta, first_cos = ok[0] + last_beta, last_cos = ok[-1] + best_cos = max(c for _, c in ok) + if last_cos > 0.95 and last_cos - first_cos > 0.10: + return f"finite-beta bias likely: cos improves from beta={first_beta:g} to beta={last_beta:g}" + if best_cos < 0.80: + return "cos stays low as beta shrinks: structural/bug more likely than finite-beta bias" + if last_cos > first_cos + 0.05: + return "some finite-beta sensitivity, but not a clean convergence-to-1 result" + return "no strong beta-to-zero improvement" + + +def read_ablation(rows): + ok = [r for r in rows if r.get("ok")] + if not ok: + return "no successful ablations" + full = next((r for r in ok if r["key"] == "full"), None) + best = max(ok, key=lambda r: r["cos"]) + if full is None: + return 
f"best successful config is {best['label']}" + delta = best["cos"] - full["cos"] + if delta <= 0.05: + return "no ablation materially improves over FULL" + if best["key"] == "track_off": + return "tracking path is suspect: disabling blk.track improved cos" + if best["key"] == "plain": + return "holomorphic/adaptive path is suspect: plain real EP improved cos" + if best["key"] == "fixed_t2": + return "adaptive-T2 selection/tracking is suspect: fixed T2 improved cos" + return f"{best['label']} is the strongest improvement over FULL" + + +def print_diagnostic_summary(multi_rows, beta_rows, ablation_rows): + print("", flush=True) + print("================ DIAGNOSTIC SUMMARY ================", flush=True) + + multi_stats_t = finite_range([r.get("cos_ep_t") for r in multi_rows if r.get("ok")]) + multi_stats_b = finite_range([r.get("cos_ep_b") for r in multi_rows if r.get("ok")]) + if multi_stats_t is None: + print("Multi-batch: no successful batches", flush=True) + else: + mean, mn, mx = multi_stats_t + print(f"Multi-batch: mean cos(g_EP,g_transpose)={mean:+.6f} range=[{mn:+.6f}, {mx:+.6f}]", flush=True) + if multi_stats_b is not None: + mean, mn, mx = multi_stats_b + print(f"Multi-batch: mean cos(g_EP,g_BPTT)={mean:+.6f} range=[{mn:+.6f}, {mx:+.6f}]", flush=True) + print(f"Multi-batch read: {read_multi_batch(multi_rows)}", flush=True) + + print("Beta sweep (beta | cos(g_EP,g_transpose)):", flush=True) + if beta_rows: + for row in beta_rows: + if row.get("ok"): + print(f" {row['beta']:<8g} | {row['cos']:+.6f}", flush=True) + else: + print(f" {row.get('beta', 'n/a')!s:<8} | failed: {row.get('error')}", flush=True) + else: + print(" none", flush=True) + print(f"Beta sweep read: {read_beta_sweep(beta_rows)}", flush=True) + + print("Ablation (config | cos(g_EP,g_transpose)):", flush=True) + if ablation_rows: + for row in ablation_rows: + if row.get("ok"): + print(f" {row['label']} | {row['cos']:+.6f}", flush=True) + else: + print(f" {row.get('label', 'unknown')} | failed: 
{row.get('error')}", flush=True) + else: + print(" none", flush=True) + print(f"Ablation read: {read_ablation(ablation_rows)}", flush=True) + print("============== END DIAGNOSTIC SUMMARY ==============", flush=True) + + +def run_diagnostics(blk, cfg, ck): + print("=== DIAGNOSTIC MODE ===", flush=True) + print(f"# ckpt step {ck.get('step')} best {ck.get('best')}", flush=True) + print( + "ep_step paths: holo=2,t2sel>0,track=True -> holo_a_track; " + "holo=2,t2sel>0,track=False -> holo_a_select2; holo>0,t2sel=0 -> holo_a; holo=0 -> plain EP", + flush=True, + ) + print("gradient comparison scope: blk.block parameters; readout Wh is excluded", flush=True) + + multi_rows = [] + beta_rows = [] + ablation_rows = [] + seed1000_ref = None + + print("=== DIAGNOSTIC 1: MULTI-BATCH ===", flush=True) + for i in range(6): + seed = 1000 + i + label = f"diag1 batch={i} seed={seed}" + try: + idx, y = draw_seeded_train_batch(cfg, seed) + ref = exact_reference_for_batch(blk, idx, y, cfg, label, compute_bptt=True) + torch.manual_seed(seed) + gEPv, ep_res = run_ep_step_flat(blk, idx, y, cfg, ref["params"]) + row = { + "ok": True, + "batch": i, + "seed": seed, + "cos_ep_t": cos(gEPv, ref["gt"]), + "cos_ep_b": cos(gEPv, ref["gBv"]), + "cos_t_b": cos(ref["gt"], ref["gBv"]), + "z_step_abs": ref["z_step_abs"], + "z_step_rel": ref["z_step_rel"], + "z_force_rel": ref["z_force_rel"], + "ep_res": ep_res, + } + multi_rows.append(row) + print( + f"{label}: cos(g_EP,g_transpose)={row['cos_ep_t']:+.6f} " + f"cos(g_EP,g_BPTT)={row['cos_ep_b']:+.6f} " + f"cos(g_transpose,g_BPTT)={row['cos_t_b']:+.6f} " + f"z_res_abs={row['z_step_abs']:.6e} z_res_rel={row['z_step_rel']:.6e} ep_res={ep_res:.6e}", + flush=True, + ) + if seed == 1000: + seed1000_ref = ref + except Exception as err: + row = {"ok": False, "batch": i, "seed": seed, "error": repr(err)} + multi_rows.append(row) + print(f"{label} failed: {err!r}", flush=True) + + multi_stats_t = finite_range([r.get("cos_ep_t") for r in multi_rows if 
r.get("ok")]) + multi_stats_b = finite_range([r.get("cos_ep_b") for r in multi_rows if r.get("ok")]) + if multi_stats_t is not None and multi_stats_b is not None: + mt, mint, maxt = multi_stats_t + mb, minb, maxb = multi_stats_b + print( + f"DIAG1 aggregate: cos(g_EP,g_transpose) mean={mt:+.6f} min={mint:+.6f} max={maxt:+.6f}; " + f"cos(g_EP,g_BPTT) mean={mb:+.6f} min={minb:+.6f} max={maxb:+.6f}", + flush=True, + ) + + if seed1000_ref is None: + try: + idx, y = draw_seeded_train_batch(cfg, 1000) + seed1000_ref = exact_reference_for_batch(blk, idx, y, cfg, "diag seed=1000 fallback", compute_bptt=True) + except Exception as err: + print(f"seed=1000 reference failed; beta sweep and ablation cannot run: {err!r}", flush=True) + + print("=== DIAGNOSTIC 2: BETA SWEEP ===", flush=True) + if seed1000_ref is not None: + for beta in [0.04, 0.02, 0.01, 0.005, 0.002]: + try: + torch.manual_seed(1000) + gEPv, ep_res = run_ep_step_flat( + blk, + seed1000_ref["idx"], + seed1000_ref["y"], + cfg, + seed1000_ref["params"], + beta=beta, + hr=beta, + ) + row = {"ok": True, "beta": beta, "cos": cos(gEPv, seed1000_ref["gt"]), "ep_res": ep_res} + beta_rows.append(row) + print(f"beta={beta:g} hr={beta:g}: cos(g_EP,g_transpose)={row['cos']:+.6f} ep_res={ep_res:.6e}", flush=True) + except Exception as err: + beta_rows.append({"ok": False, "beta": beta, "error": repr(err)}) + print(f"beta={beta:g} failed: {err!r}", flush=True) + else: + print("DIAG2 skipped: seed=1000 reference unavailable", flush=True) + + print("=== DIAGNOSTIC 3: COMPONENT ABLATION ===", flush=True) + if seed1000_ref is not None: + ablations = [ + { + "key": "full", + "label": "FULL holo=2 track=True t2sel=40", + "kwargs": {"holo": 2, "track": True, "t2sel": 40, "hr": cfg.hr, "beta": cfg.beta}, + }, + { + "key": "track_off", + "label": "holo=2 track=False t2sel=40", + "kwargs": {"holo": 2, "track": False, "t2sel": 40, "hr": cfg.hr, "beta": cfg.beta}, + }, + { + "key": "plain", + "label": "plain EP holo=0 track ignored 
t2sel=0", + "kwargs": {"holo": 0, "track": getattr(blk, "track", False), "t2sel": 0, "hr": cfg.hr, "beta": cfg.beta}, + }, + { + "key": "fixed_t2", + "label": f"holo=2 track=True t2sel=0 fixed T2={cfg.T2}", + "kwargs": {"holo": 2, "track": True, "t2sel": 0, "hr": cfg.hr, "beta": cfg.beta, "T2": cfg.T2}, + }, + ] + for item in ablations: + try: + torch.manual_seed(1000) + gEPv, ep_res = run_ep_step_flat( + blk, + seed1000_ref["idx"], + seed1000_ref["y"], + cfg, + seed1000_ref["params"], + **item["kwargs"], + ) + row = { + "ok": True, + "key": item["key"], + "label": item["label"], + "cos": cos(gEPv, seed1000_ref["gt"]), + "ep_res": ep_res, + } + ablation_rows.append(row) + print(f"{item['label']}: cos(g_EP,g_transpose)={row['cos']:+.6f} ep_res={ep_res:.6e}", flush=True) + except Exception as err: + row = {"ok": False, "key": item["key"], "label": item["label"], "error": repr(err)} + ablation_rows.append(row) + print(f"config {item['label']} failed: {err!r}", flush=True) + else: + print("DIAG3 skipped: seed=1000 reference unavailable", flush=True) + + print_diagnostic_summary(multi_rows, beta_rows, ablation_rows) + + +def compare_exact_adjoint(blk, idx, y, zstar, xin0, op, cfg): + print("=== exact-adjoint gradient comparison ===", flush=True) + ell, ce_loss = ce_state_grad(blk, zstar, y) + print(f"CE(z*)={ce_loss:.6f} ||ell||={norm(ell).item():.6e}", flush=True) + lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg) + print( + f"adjoint solve summary: residual={gmres_rel:.3e} iters={gmres_iters} info={gmres_info} " + f"tikhonov_mu={adj_mu:.3e}", + flush=True, + ) + + params = block_param_list(blk) + gt = flat_grad_by_param_id(exact_transpose_grad(blk, idx, zstar, xin0, lam, params), params) + print("gradient comparison scope: blk.block parameters; readout Wh is excluded", flush=True) + + gB = bptt_step(blk, idx, y, cfg.T1, cfg.eps, 0.0) + gEP, ep_res = ep_step( + blk, + idx, + y, + cfg.T1, + cfg.T2, + cfg.eps, + cfg.beta, + 0.0, + 
cfg.holo, + cfg.hr, + cfg.t1max, + cfg.res_est, + cfg.t2sel, + 1, + 0.0, + ) + gBv = flat_grad_by_param_id(gB, params) + gEPv = flat_grad_by_param_id(gEP, params) + + print(f"EP estimator free-phase residual from ep_step={ep_res:.6e}", flush=True) + print(f"||g_transpose||={norm(gt).item():.6e} ||g_BPTT||={norm(gBv).item():.6e} ||g_EP||={norm(gEPv).item():.6e}", flush=True) + c_t_b = cos(gt, gBv) + d_t_b = rel_diff(gt, gBv) + c_ep_t = cos(gEPv, gt) + d_ep_t = rel_diff(gEPv, gt) + c_ep_b = cos(gEPv, gBv) + d_ep_b = rel_diff(gEPv, gBv) + print(f"cos(g_transpose, g_BPTT)={c_t_b:+.6f} ||g_transpose-g_BPTT||/||g_BPTT||={d_t_b:.6e}", flush=True) + print(f"cos(g_EP, g_transpose)={c_ep_t:+.6f} ||g_EP-g_transpose||/||g_transpose||={d_ep_t:.6e}", flush=True) + print(f"cos(g_EP, g_BPTT)={c_ep_b:+.6f} ||g_EP-g_BPTT||/||g_BPTT||={d_ep_b:.6e}", flush=True) + print("interpretation:", flush=True) + print(" cos(g_transpose,g_BPTT)~1 AND cos(g_EP,g_transpose)~1 -> our EP IS the exact adjoint; failure is convergence/contraction", flush=True) + print(" cos(g_transpose,g_BPTT)~1 AND cos(g_EP,g_transpose)<1 -> exact adjoint works, our EP falls short -> implement exact/dyadic", flush=True) + print(" cos(g_transpose,g_BPTT)<1 -> even exact adjoint != BPTT -> finite-time/convergence, not the adjoint", flush=True) + + +def main(): + cfg = parse_args() + cfg.ckpt = resolve_ckpt_path(cfg.ckpt) + require_cuda_if_requested(cfg.device) + dev = torch.device("cuda:0" if cfg.device == "cuda" else "cpu") + torch.backends.cuda.matmul.allow_tf32 = bool(cfg.tf32) + torch.backends.cudnn.allow_tf32 = bool(cfg.tf32) + print(f"# asym_probe device={dev} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r}", flush=True) + print( + f"# ckpt={cfg.ckpt} B={cfg.B} T={cfg.T} C={cfg.C} H={cfg.H} Mm={cfg.Mm} " + f"attn_mode=thick qknorm=True gelu={cfg.gelu}", + flush=True, + ) + blk, ck = build_block(cfg, dev) + if cfg.diag: + run_diagnostics(blk, cfg, ck) + return + + idx, y = L.get_batch("train", 
cfg.B, cfg.T) + xin0 = blk.embed(idx).detach() + zstar, steps, step_res, force_res = relax_to_fixed_point(blk, xin0, cfg) + print(f"# ckpt step {ck.get('step')} best {ck.get('best')}", flush=True) + print(f"z* residual: step_res={step_res:.6e} force_res={force_res:.6e} relax_steps={steps}", flush=True) + if step_res > cfg.res_est: + print(f"WARNING: fixed-point target not reached: step_res={step_res:.3e} > {cfg.res_est:.3e}", flush=True) + if step_res > 1e-3 or force_res > 1e-3: + print("WARNING: relaxed z* residual exceeds 1e-3; do not trust exact-adjoint solves until convergence improves", flush=True) + + # Freeze parameters for state Jacobian products. tforce is out-of-place; each + # VJP re-leafs z* to avoid stale graphs, and xin0 is held detached/fixed. + set_param_requires_grad(blk, False) + print("autograd note: using blk.tforce directly; no in-place tforce ops patched; z* is re-leafed per VJP/JVP", flush=True) + + op0 = Operators(blk, zstar, xin0, cfg, mu=0.0) + if cfg.skiprho: + compare_exact_adjoint(blk, idx, y, zstar, xin0, op0, cfg) + return + + tr_mean, tr_std = estimate_trace_s(op0, cfg.trace_probes) + if cfg.mu >= 0: + mu = float(cfg.mu) + else: + mu = cfg.mu_scale * max(abs(tr_mean), 1e-12) + print(f"trace(S)/n estimate={tr_mean:+.6e} std={tr_std:.3e}", flush=True) + print(f"mu used={mu:.6e} (mu_scale={cfg.mu_scale:g}, solve operator S+muI)", flush=True) + + op = Operators(blk, zstar, xin0, cfg, mu=mu) + sensitivity_probe(op, mu) + t0 = time.time() + rho_power = power_rho(op, cfg) + rho_arnoldi = arnoldi_rho(op, cfg.arnoldi_k) + sigma = power_sigma(op, cfg) + elapsed = time.time() - t0 + rho = max(rho_power, rho_arnoldi if rho_arnoldi is not None else 0.0) + print(op.solve_log.summary(), flush=True) + print("non-normal note: power iteration reports dominant growth; Rayleigh trend may be small/oscillatory for skew modes", flush=True) + print(f"rho(S^-1 A)={rho:.6e} power={rho_power:.6e} arnoldi={rho_arnoldi if rho_arnoldi is not None else 
float('nan'):.6e}", flush=True) + print(f"||S^-1 A||_2={sigma:.6e}", flush=True) + print(f"elapsed_operator_seconds={elapsed:.1f}", flush=True) + verdict = "higher-order AEP viable" if rho < 1.0 else "higher-order AEP not viable" + print(f"VERDICT: rho {'<' if rho < 1.0 else '>='} 1 => {verdict}", flush=True) + compare_exact_adjoint(blk, idx, y, zstar, xin0, op, cfg) + + +if __name__ == "__main__": + main() diff --git a/ep_run/auto_probe.py b/ep_run/auto_probe.py new file mode 100644 index 0000000..90b969e --- /dev/null +++ b/ep_run/auto_probe.py @@ -0,0 +1,25 @@ +"""Wait for the first converged clean-EP ckpt, run the fixed oracle-adjoint probe on it, report g_EP vs g_transpose.""" +import time, os, subprocess, shutil +WD = "/home/yurenh2/ept/ep_run"; os.chdir(WD) +CK, FROZEN = "runs/ep_clean.pt", "runs/ep_clean_probe.pt" +got = False +for _ in range(45): # up to ~67 min + time.sleep(90) + if os.path.exists(CK) and os.path.getsize(CK) > 1_000_000: + got = True; break +if not got: + print("=== AUTO-PROBE: ep_clean.pt never appeared in ~67min ==="); raise SystemExit +shutil.copy2(CK, FROZEN) # freeze (avoid write race) +env = dict(os.environ, CUDA_VISIBLE_DEVICES="0", PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True") +try: + r = subprocess.run(["python3", "asym_probe.py", "--ckpt", FROZEN, "--B", "8"], + env=env, capture_output=True, text=True, timeout=1200) + txt = r.stdout + "\n" + r.stderr +except Exception as e: + txt = f"probe run error: {e}" +print("=== AUTO ORACLE PROBE on first clean-EP ckpt ===") +KEEP = ("z* res", "GMRES", "resid", "cos(", "g_transpose", "g_EP", "g_BPTT", "interpret", "->", "exact", "AsymEP", "# ckpt", "step ") +DROP = ("UserWarning", "cuBLAS", "warnings.warn", "Triggered", "return Variable", "FutureWarning") +for line in txt.splitlines(): + if any(k in line for k in KEEP) and not any(b in line for b in DROP): + print(line) diff --git a/ep_run/bench_gpu2.py b/ep_run/bench_gpu2.py new file mode 100644 index 0000000..ef66ef4 --- /dev/null 
+++ b/ep_run/bench_gpu2.py @@ -0,0 +1,33 @@ +"""Safe GPU-2 benchmark wrapper for the a-select speed test. + +Runs test_aselect_deepdive.py's main() on GPU 2 (shared with japardi2's NV-Embed +server) with a HARD allocator cap + a start-time free-memory guard so we can never +OOM the neighbour. Forwards all CLI args to the underlying test. + +Usage: + CUDA_VISIBLE_DEVICES=2 MEMFRAC=0.010 python3 bench_gpu2.py --B 1 --T2 80 --T1 2 +""" +import os, sys, torch, runpy + +torch.cuda.init() +free0, total = torch.cuda.mem_get_info() +f0 = free0 / 1024**2 +tot = total / 1024**2 +print(f"[guard] GPU free at start = {f0:.0f} MiB (of {tot:.0f})", flush=True) + +MIN_FREE = float(os.environ.get("MINFREE", "1100")) +if f0 < MIN_FREE: + sys.exit(f"[guard] ABORT: free {f0:.0f} < {MIN_FREE:.0f} MiB — too risky for the neighbour, back off.") + +frac = float(os.environ.get("MEMFRAC", "0.010")) +torch.cuda.set_per_process_memory_fraction(frac) +cap = frac * tot +print(f"[guard] allocator hard-capped at {cap:.0f} MiB (frac={frac}); leaving >= {f0-cap-700:.0f} MiB headroom after ~700 MiB ctx", flush=True) + +# forward remaining argv to the test's main() +sys.argv = ["test_aselect_deepdive.py"] + sys.argv[1:] +try: + runpy.run_path("/home/yurenh2/ept/ep_run/test_aselect_deepdive.py", run_name="__main__") +finally: + free1, _ = torch.cuda.mem_get_info() + print(f"[guard] GPU free at end = {free1/1024**2:.0f} MiB; my peak reserved = {torch.cuda.max_memory_reserved()/1024**2:.0f} MiB", flush=True) diff --git a/ep_run/bf16_dbg.py b/ep_run/bf16_dbg.py new file mode 100644 index 0000000..de32aed --- /dev/null +++ b/ep_run/bf16_dbg.py @@ -0,0 +1,29 @@ +import torch, time, math, traceback +import lt_ep_train as LT +torch.manual_seed(0) +blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick'); blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0 +ck=torch.load('runs/ep_resreg_warm.pt',map_location='cuda') +with torch.no_grad(): + for p,s in zip(blk.allp,ck['allp']): p.copy_(s.to('cuda')) 
+idx,y=LT.get_batch('train',8,256) +base=dict(T1=150,T2=20,eps=0.1,beta=0.02,jacreg=0.1,holo=2,hr=0.2,t2sel=80,t1max=150,res_est=1e-4,resreg=0.2) +g32,_=LT.ep_step(blk,idx,y,**base) +def cos(ga): + n=da=db=0.0 + for p in blk.block: + a=ga.get(id(p)); b=g32.get(id(p)) + if a is None or b is None: continue + a=a.float(); b=b.float(); n+=float((a*b).sum()); da+=float((a*a).sum()); db+=float((b*b).sum()) + return n/(math.sqrt(da*db)+1e-20) +def T(fn,reps=2): + fn(); torch.cuda.synchronize(); t0=time.time() + for _ in range(reps): fn() + torch.cuda.synchronize(); return round((time.time()-t0)/reps*1000) +print("fp32 step ms:", T(lambda: LT.ep_step(blk,idx,y,**base)),flush=True) +print("=== A: blanket autocast (locate the break) ===",flush=True) +try: + with torch.autocast('cuda',dtype=torch.bfloat16): gA,_=LT.ep_step(blk,idx,y,**base) + print("A OK cos",round(cos(gA),4),"ms",T(lambda: (lambda: [LT.ep_step(blk,idx,y,**base) for _ in '1'])() )) +except Exception: + traceback.print_exc() +print("DONE",flush=True) diff --git a/ep_run/bf16_dbg2.py b/ep_run/bf16_dbg2.py new file mode 100644 index 0000000..517642d --- /dev/null +++ b/ep_run/bf16_dbg2.py @@ -0,0 +1,30 @@ +import torch, time, math +import lt_ep_train as LT +torch.manual_seed(0) +blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick'); blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0 +ck=torch.load('runs/ep_resreg_warm.pt',map_location='cuda') +with torch.no_grad(): + for p,s in zip(blk.allp,ck['allp']): p.copy_(s.to('cuda')) +idx,y=LT.get_batch('train',8,256) +base=dict(T1=150,T2=20,eps=0.1,beta=0.02,jacreg=0.1,holo=2,hr=0.2,t2sel=80,t1max=150,res_est=1e-4,resreg=0.2) +g32,_=LT.ep_step(blk,idx,y,**base) +def cos(ga): + n=da=db=0.0 + for p in blk.block: + a=ga.get(id(p)); b=g32.get(id(p)) + if a is None or b is None: continue + a=a.float(); b=b.float(); n+=float((a*b).sum()); da+=float((a*a).sum()); db+=float((b*b).sum()) + return round(n/(math.sqrt(da*db)+1e-20),4) +def T(fn,reps=2): + fn(); 
torch.cuda.synchronize(); t0=time.time() + for _ in range(reps): fn() + torch.cuda.synchronize(); return round((time.time()-t0)/reps*1000) +fp32ms=T(lambda: LT.ep_step(blk,idx,y,**base)); print("fp32:",fp32ms,"ms",flush=True) +def trial(name,**ac): + try: + def run(): + with torch.autocast('cuda',dtype=torch.bfloat16,**ac): return LT.ep_step(blk,idx,y,**base)[0] + g=run(); print(f"{name:22s} OK cos={cos(g)} ms={T(run)} (fp32={fp32ms})",flush=True) + except Exception as e: print(f"{name:22s} FAIL: {type(e).__name__}: {str(e)[:80]}",flush=True) +trial("cache_enabled=False", cache_enabled=False) +print("DONE",flush=True) diff --git a/ep_run/bias_var.py b/ep_run/bias_var.py new file mode 100644 index 0000000..a616ce4 --- /dev/null +++ b/ep_run/bias_var.py @@ -0,0 +1,63 @@ +"""Decisive 'why is EP far at S1' diagnostic: separate estimator BIAS from VARIANCE at the +converged v4b checkpoint. Over N batches compute EP grad, BPTT-400 grad, BPTT-150 control. + mean-cos = mean_b cos(g_EP^b, g_BPTT^b) -> per-step quality (noisy) + cos-means = cos(sum_b g_EP, sum_b g_BPTT) -> if >> mean-cos: errors AVERAGE OUT = VARIANCE + if ~ mean-cos: systematic = BIAS (the real wall) +BPTT-150-vs-400 gives the same two metrics as the slow-mixing horizon baseline.""" +import torch +import lt_ep_train as M +from pathlib import Path +import pickle +M.DD = Path('/tmp/lt_ep/data/tinystories') +M.vocab = pickle.load(open(M.DD / 'meta.pkl', 'rb'))['vocab_size'] +from lt_ep_train import EQBlock, get_batch, bptt_step, ep_step +dev = 'cuda' +torch.manual_seed(0) +B, T, C, H = 8, 256, 256, 8 +blk = EQBlock(C, H, 256, T, attn_mode='thick') +blk.qknorm = False; blk.track = False; blk.li_avg = 0; blk.navg = 1; blk.fnoise = 0.0; blk.nbrake = 0.0; blk._cstep = None +ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt') +for p, w in zip(blk.allp, ck['allp']): + with torch.no_grad(): + p.copy_(w.to(dev)) +print(f"v4b ckpt best {ck['best']:.4f}", flush=True) +groups = {'all': blk.block, 'attn': [blk.WQ, blk.WK, 
blk.WV, blk.WO], + 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb], 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b], + 'emb': [blk.tok, blk.pos]} +N = 16 +sEP, s400, s150 = {}, {}, {} +cos_b = {k: [] for k in groups} +bctl_b = {k: [] for k in groups} + +def flat(g, ps): + v = [g[id(p)].reshape(-1) for p in ps if g.get(id(p)) is not None] + return torch.cat(v) if v else None + +def cos(a, b): + return (a @ b / (a.norm() * b.norm() + 1e-12)).item() + +for i in range(N): + idx, y = get_batch('train', B, T) + gE, _ = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, 0.0, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120) + g4 = bptt_step(blk, idx, y, 400, 0.1) + g1 = bptt_step(blk, idx, y, 150, 0.1) + for k, ps in groups.items(): + a, b, c = flat(gE, ps), flat(g4, ps), flat(g1, ps) + if a is not None and b is not None: + cos_b[k].append(cos(a, b)) + if c is not None and b is not None: + bctl_b[k].append(cos(c, b)) + for src, acc in ((gE, sEP), (g4, s400), (g1, s150)): + for p in blk.block: + if src.get(id(p)) is not None: + acc[id(p)] = src[id(p)].detach().clone() if id(p) not in acc else acc[id(p)] + src[id(p)].detach() + print(f" batch {i+1}/{N} done", flush=True) + +print(f"\n{'group':>5} {'EP mean-cos':>12} {'EP cos-means':>13} {'BPTT mean-cos':>14} {'BPTT cos-means':>15}") +for k, ps in groups.items(): + mc = sum(cos_b[k]) / len(cos_b[k]) + bmc = sum(bctl_b[k]) / len(bctl_b[k]) + aE, a4, a1 = flat(sEP, ps), flat(s400, ps), flat(s150, ps) + cm = cos(aE, a4) + bcm = cos(a1, a4) + print(f"{k:>5} {mc:>12.3f} {cm:>13.3f} {bmc:>14.3f} {bcm:>15.3f}", flush=True) diff --git a/ep_run/bp_charlm.py b/ep_run/bp_charlm.py new file mode 100644 index 0000000..410d812 --- /dev/null +++ b/ep_run/bp_charlm.py @@ -0,0 +1,78 @@ +"""Same-param standard BP transformer char-LM (reference ceiling). 
+Standard pre-LN block: MHA + FFN, trained with normal backprop.""" +import argparse, math, pickle, time, numpy as np, torch, torch.nn as nn, torch.nn.functional as F +from pathlib import Path +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +ap = argparse.ArgumentParser() +ap.add_argument('--data', default='/tmp/lt_ep/data/shakespeare_char') +ap.add_argument('--B', type=int, default=32); ap.add_argument('--T', type=int, default=64) +ap.add_argument('--C', type=int, default=128); ap.add_argument('--H', type=int, default=4) +ap.add_argument('--depth', type=int, default=1); ap.add_argument('--mlp', type=int, default=4) +ap.add_argument('--steps', type=int, default=3000); ap.add_argument('--lr', type=float, default=3e-3) +ap.add_argument('--seed', type=int, default=0) +cfg = ap.parse_args() +torch.manual_seed(cfg.seed) +DD = Path(cfg.data) +vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] +B, T, C, H, DEPTH, MLP = cfg.B, cfg.T, cfg.C, cfg.H, cfg.depth, cfg.mlp + + +def get_batch(split): + data = np.memmap(DD / ('train.bin' if split == 'train' else 'val.bin'), dtype=np.uint16, mode='r') + ix = torch.randint(len(data) - T - 1, (B,)) + x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix]) + return x.to(dev), y.to(dev) + + +class Block(nn.Module): + def __init__(s): + super().__init__() + s.ln1 = nn.LayerNorm(C); s.ln2 = nn.LayerNorm(C) + s.attn = nn.MultiheadAttention(C, H, batch_first=True) + s.mlp = nn.Sequential(nn.Linear(C, MLP * C), nn.GELU(), nn.Linear(MLP * C, C)) + s.register_buffer('m', torch.triu(torch.ones(T, T) * float('-inf'), 1)) + + def forward(s, x): + h = s.ln1(x) + x = x + s.attn(h, h, h, attn_mask=s.m[:x.size(1), :x.size(1)], need_weights=False)[0] + return x + s.mlp(s.ln2(x)) + + +class GPT(nn.Module): + def __init__(s): + super().__init__() + s.tok = nn.Embedding(vocab, C); s.pos = nn.Embedding(T, C) + s.blocks = 
nn.ModuleList([Block() for _ in range(DEPTH)]) + s.lnf = nn.LayerNorm(C); s.head = nn.Linear(C, vocab, bias=False) + + def forward(s, idx, y=None): + x = s.tok(idx) + s.pos(torch.arange(idx.size(1), device=dev)) + for b in s.blocks: + x = b(x) + logits = s.head(s.lnf(x)) + loss = None if y is None else F.cross_entropy(logits.reshape(-1, vocab), y.reshape(-1)) + return logits, loss + + +m = GPT().to(dev) +opt = torch.optim.AdamW(m.parameters(), lr=cfg.lr, weight_decay=1e-4) +STEPS = cfg.steps +sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, STEPS, eta_min=cfg.lr * 0.1) +np_ = sum(p.numel() for p in m.parameters()) +print(f"[bp-charlm] params={np_/1e3:.1f}K depth={DEPTH} C={C} H={H} mlp={MLP}", flush=True) + + +@torch.no_grad() +def ev(): + m.eval(); t = sum(m(*get_batch('val'))[1].item() for _ in range(20)) / 20; m.train(); return t + + +best, t0 = 9.9, time.time() +for step in range(1, STEPS + 1): + _, loss = m(*get_batch('train')) + opt.zero_grad(set_to_none=True); loss.backward(); opt.step(); sched.step() + if step % 200 == 0 or step == STEPS: + v = ev(); best = min(best, v) + print(f"step {step:4d}/{STEPS} | val CE {v:.4f} (best {best:.4f}) | {step/(time.time()-t0):.1f} it/s", flush=True) +print(f"[bp-charlm] DONE best val CE {best:.4f} (random ln({vocab})={math.log(vocab):.3f})", flush=True) diff --git a/ep_run/compile_bench.py b/ep_run/compile_bench.py new file mode 100644 index 0000000..dc730dd --- /dev/null +++ b/ep_run/compile_bench.py @@ -0,0 +1,44 @@ +import torch, pickle, time +from pathlib import Path +import lt_ep_train as L +from lt_ep_train import EQBlock +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; B=24; T=256; eps=0.1; T1=150 +torch.manual_seed(0); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +ck=torch.load('runs/redx_traj/s3200.pt',map_location=dev) +blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True +with 
torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + xin=blk.embed(idx).detach() +print("torch", torch.__version__) +def relax(step,n): + z=xin.clone() + for _ in range(n): z=step(z) + return z +@torch.no_grad() +def bench(step,label,reps=8): + for _ in range(3): relax(step,T1); torch.cuda.synchronize() # warmup + torch.cuda.synchronize(); t0=time.time() + for _ in range(reps): relax(step,T1) + torch.cuda.synchronize(); return (time.time()-t0)/reps*1000 +def fstep(z): + with torch.no_grad(): return z+eps*blk.force(z,xin).detach() +te=bench(fstep,"eager") +print(f"FP32 eager 150-relax: {te:.1f} ms/relax") +try: + cstep=torch.compile(fstep) + tc=bench(cstep,"compiled") + print(f"FP32 compile : {tc:.1f} ms/relax -> {te/tc:.2f}x") +except Exception as e: print("compile default ERR:", str(e)[:200]) +try: + cstep2=torch.compile(fstep, mode="reduce-overhead") + tc2=bench(cstep2,"reduce-overhead") + print(f"FP32 reduce-overhead : {tc2:.1f} ms/relax -> {te/tc2:.2f}x (CUDA graphs)") +except Exception as e: print("reduce-overhead ERR:", str(e)[:200]) +# bf16 potential (timing only; precision caveat for actual use) +def fstep_bf(z): + with torch.no_grad(), torch.autocast('cuda',dtype=torch.bfloat16): return z+eps*blk.force(z,xin).detach() +try: + tb=bench(fstep_bf,"bf16"); print(f"BF16 eager (autocast): {tb:.1f} ms/relax -> {te/tb:.2f}x (precision caveat)") +except Exception as e: print("bf16 ERR:", str(e)[:200]) +print("=== DONE ===") diff --git a/ep_run/cos_monitor.py b/ep_run/cos_monitor.py new file mode 100644 index 0000000..12a4fe9 --- /dev/null +++ b/ep_run/cos_monitor.py @@ -0,0 +1,49 @@ +"""Lightweight cos monitor for ep_hr02: probe each new ckpt, log step->cos(g_EP,exact-adjoint). 
+Fire on: cos degrades <0.82 (gradient going bad) / survived to step>=9500 (cleared old danger zone) / death.""" +import time, os, re, subprocess, shutil +WD = "/home/yurenh2/ept/ep_run"; os.chdir(WD) +LOG, CK, FROZEN, COSLOG, PID = "runs/ep_hr02.log", "runs/ep_hr02.pt", "runs/ep_hr02_cosprobe.pt", "runs/cos_monitor.log", 1684249 +def alive(): + try: os.kill(PID, 0); return True + except Exception: return False +def cur_step(): + try: + ls = [l for l in open(LOG) if l.startswith("step")] + if ls: return int(re.search(r"step (\d+)", ls[-1]).group(1)) + except Exception: pass + return 0 +open(COSLOG, "a").write(f"# cos monitor start (ep_hr02, hr=0.2)\n") +last = -1; fired = None; traj = []; t0 = time.time() +while fired is None and time.time() - t0 < 18 * 3600: + time.sleep(120) + if not alive(): fired = f"ep_hr02 EXITED at step {cur_step()}"; break + step = cur_step() + if step >= last + 450 and os.path.exists(CK) and os.path.getsize(CK) > 1e6: + try: shutil.copy2(CK, FROZEN) + except Exception: continue + last = step + env = dict(os.environ, CUDA_VISIBLE_DEVICES="0", PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True") + cosv, zres = None, "?" + try: + r = subprocess.run(["python3", "asym_probe.py", "--ckpt", FROZEN, "--B", "8"], + env=env, capture_output=True, text=True, timeout=600) + out = r.stdout + r.stderr + m = re.search(r"cos\(g_EP, ?g_transpose\)=([+-][0-9.]+)", out) + zr = re.search(r"z\* residual.*?step_res=([0-9.eE+-]+)", out) # asym_probe prints "z* residual: step_res=..." + cosv = float(m.group(1)) if m else None + zres = zr.group(1) if zr else "?" + except Exception as e: + zres = f"probe-err:{e}" + # also grab current val from the training log + val = "?" 
+ try: val = re.search(r"val CE ([\d.eE+-]+)", [l for l in open(LOG) if l.startswith("step")][-1]).group(1) + except Exception: pass + line = f"step {step}: cos={cosv} val={val} z_res={zres}" + traj.append((step, cosv, val)); open(COSLOG, "a").write(line + "\n"); print(line, flush=True) + if cosv is not None and cosv < 0.82: # below the historical per-batch floor (~0.85) => real degradation, not variance + fired = f"COS DEGRADED to {cosv:.3f} at step {step} (val {val}) — below historical floor, real gradient degradation"; break + if step >= 9500: + fired = f"ep_hr02 SURVIVED to step {step} (val {val}) with cos staying high — cleared the old ~9200 danger zone"; break +print("=== COS MONITOR FIRED ==="); print("trigger:", fired or "18h timeout") +print("cos trajectory (step | cos | val):") +for s, c, v in traj: print(f" {s:6d} | {c} | {v}") diff --git a/ep_run/cos_sweep.log b/ep_run/cos_sweep.log new file mode 100644 index 0000000..c3a37c2 --- /dev/null +++ b/ep_run/cos_sweep.log @@ -0,0 +1,15 @@ +loaded resreg_warm allp (step 27550, best 1.9704) +batch 0 done +batch 1 done +batch 2 done + +setting cos(block,vs BPTT) ms/step(B8) +baseline 0.9709 4641 +t2sel80 0.9633 2600 +t2sel40 0.9323 1578 +corr_every2 0.9708 4649 +corr_every4 0.9709 4654 +plain_nudge_holo0 0.2985 1127 +no_t1max_refine 0.9689 4398 +T1=80 0.9640 4721 +DONE diff --git a/ep_run/cos_sweep.py b/ep_run/cos_sweep.py new file mode 100644 index 0000000..cd628fe --- /dev/null +++ b/ep_run/cos_sweep.py @@ -0,0 +1,35 @@ +import torch, time, math, collections +import lt_ep_train as LT +torch.manual_seed(0) +blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick') +blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0 +ck=torch.load('runs/ep_resreg_warm.pt',map_location='cuda') +with torch.no_grad(): + for p,s in zip(blk.allp, ck['allp']): p.copy_(s.to('cuda')) +print(f"loaded resreg_warm allp (step {ck['step']}, best {ck['best']:.4f})",flush=True) +def cosine(ga,gb,params): + num=da=db=0.0 + for p in params: 
+ a=ga.get(id(p)); b=gb.get(id(p)) + if a is None or b is None: continue + num+=float((a*b).sum()); da+=float((a*a).sum()); db+=float((b*b).sum()) + return num/(math.sqrt(da*db)+1e-20) +base=dict(T1=150,T2=20,eps=0.1,beta=0.02,jacreg=0.1,holo=2,hr=0.2,t2sel=160,t1max=300,res_est=1e-4,resreg=0.2) +settings={'baseline':{}, 't2sel80':dict(t2sel=80), 't2sel40':dict(t2sel=40), + 'corr_every2':dict(corr_every=2), 'corr_every4':dict(corr_every=4), + 'plain_nudge_holo0':dict(holo=0,t2sel=0), 'no_t1max_refine':dict(t1max=0), 'T1=80':dict(T1=80)} +cos=collections.defaultdict(list); NB=3; B=8 +for bi in range(NB): + idx,y=LT.get_batch('train',B,256) + gref=LT.bptt_step(blk,idx,y,300,0.1) + for n,kw in settings.items(): + g,_=LT.ep_step(blk,idx,y,**{**base,**kw}); cos[n].append(cosine(g,gref,blk.block)) + print(f"batch {bi} done",flush=True) +idx,y=LT.get_batch('train',B,256) +def T(fn): + fn(); torch.cuda.synchronize(); t0=time.time(); fn(); torch.cuda.synchronize(); return round((time.time()-t0)*1000) +print(f"\n{'setting':22s}{'cos(block,vs BPTT)':>20s}{'ms/step(B8)':>13s}",flush=True) +for n,kw in settings.items(): + c=sum(cos[n])/len(cos[n]); t=T(lambda: LT.ep_step(blk,idx,y,**{**base,**kw})) + print(f"{n:22s}{c:>20.4f}{t:>13d}",flush=True) +print("DONE",flush=True) diff --git a/ep_run/data_prep.log b/ep_run/data_prep.log new file mode 100644 index 0000000..3dac45a --- /dev/null +++ b/ep_run/data_prep.log @@ -0,0 +1,4 @@ + + + +trained BPE vocab=4096 diff --git a/ep_run/diag_cos.py b/ep_run/diag_cos.py new file mode 100644 index 0000000..37e7257 --- /dev/null +++ b/ep_run/diag_cos.py @@ -0,0 +1,45 @@ +"""#1 — from-scratch-plateau diagnostic: cos(EP gradient, exact BPTT gradient) over training, plus an +operator FINGERPRINT for comparing checkpoints. 
+ +Hypothesis (from the resreg_probe): the scratch run spends its formative high-lr phase at free-phase +residual ~1e-2, where the EP estimate is only ~0.72-aligned with the exact BPTT gradient -> it descends +on a mediocre gradient and plateaus above BPTT's floor; a warm start from a conditioned operator +(res~1e-4, cos~0.98) skips that phase. This logs (step, resT1, cos, val) so scratch vs warm trajectories +can be laid side by side, and fingerprints any checkpoint (res, cos, numerical abscissa, val) so we can +see WHAT distinguishes s2000 from other 2000-step checkpoints (conditioning? alignment? abscissa?).""" +import math, torch +from lt_ep_train import ep_step, bptt_step, relax, evaluate, get_batch +from eig_control import num_abscissa + + +def _cos(ge, gb, params): # cosine over the shared block params (where EP != BPTT) + dot = ne = nb = 0.0 + for p in params: + a, b = ge.get(id(p)), gb.get(id(p)) + if a is None or b is None: continue + dot += float((a * b).sum()); ne += float((a * a).sum()); nb += float((b * b).sum()) + return dot / (math.sqrt(ne * nb) + 1e-20) + + +def cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, bsub=6): + """cos(EP grad, exact BPTT grad) on a SMALL sub-batch (the exact-BPTT unroll graph is memory-heavy at + C512/T1=150, so we slice to `bsub` rows). Both computed jacreg-free for a clean comparison.""" + idx, y = idx[:bsub], y[:bsub] + ge, res = ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=holo, hr=hr, t1max=t1max, + res_est=res_est, t2sel=t2sel, corr_every=1, res_gate=0.0, resreg=0.0) + gb = bptt_step(blk, idx, y, T1, eps, jacreg=0.0) + return _cos(ge, gb, blk.block), res + + +def fingerprint(blk, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, nb=4, B=6): + """Median (res, cos-to-BPTT, numerical abscissa) over nb small batches + val CE — the operator's 4-D fingerprint. 
+ B kept small: the exact-BPTT reference gradient unrolls T1 steps and is memory-heavy at C512.""" + cache = {}; res_l, cos_l, om_l = [], [], [] + for _ in range(nb): + idx, y = get_batch('train', B, blk.T) + c, r = cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo, hr, t1max, res_est, t2sel) + xin = blk.embed(idx).detach(); zs = relax(blk, xin.clone(), xin, T1, eps) + _, om = num_abscissa(blk, zs, cache) + res_l.append(r); cos_l.append(c); om_l.append(om) + md = lambda a: sorted(a)[len(a) // 2] + return dict(res=md(res_l), cos=md(cos_l), num_abscissa=md(om_l), val=evaluate(blk, T1, eps)) diff --git a/ep_run/drift_diag.py b/ep_run/drift_diag.py new file mode 100644 index 0000000..4180104 --- /dev/null +++ b/ep_run/drift_diag.py @@ -0,0 +1,87 @@ +"""Late-drift diagnostic. Every stable EP/BPTT recipe peaks mid-run then val CE drifts up 0.1-0.3. +Train the champion recipe on S0 (fast) and log, every 200 steps, quantities that SEPARATE the +competing hypotheses: + - train_ce vs val_raw : train down + val up => OVERFIT ; both up => OPTIMIZATION INSTABILITY + - val_raw vs val_deep (T1=150 vs 400) : diverge => DYNAMICAL (fixed point degrading off train depth) + - res (free-phase) over time : climbing => contraction lost + - jr (lambda) over time : climbing => CONTROLLER FIGHT + - cos(EP-grad, BPTT-400) on a fixed probe batch : dropping => ESTIMATOR DEGRADATION + - |W|/|W_init| per group + cap-bind frac : group-specific => PARAMETRIC RUNAWAY + - ||ema - raw|| : how far the averaged weights sit from the wandering raw ones +""" +import math, time, torch +import lt_ep_train as M +from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step, relax, evaluate, ce +dev = 'cuda' +torch.manual_seed(0) +B, T, C, H = 32, 64, 128, 4 +blk = EQBlock(C, H, 256, T, attn_mode='thick', c=1.0) +for w in blk.capw: + blk.caps[id(w)] = w.detach().norm().item() * 3.0 +opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4) +STEPS = 9000 +sched = 
torch.optim.lr_scheduler.CosineAnnealingLR(opt, STEPS, eta_min=5e-5) +pema = [p.detach().clone() for p in blk.allp] +W0 = {id(p): p.detach().norm().item() for p in blk.allp} +groups = {'emb': [blk.tok, blk.pos], 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO], + 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb], 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b]} +blk.track = False; blk.li_avg = 0; blk.navg = 1; blk.fnoise = 0.0; blk.nbrake = 0.0 + +# fixed probe batch for gradient-cosine-over-training +pidx, py = get_batch('train', 16, T) +def grad_cos(): + ref = bptt_step(blk, pidx, py, 400, 0.1) + g, _ = ep_step(blk, pidx, py, 150, 20, 0.1, 0.02, 0.0, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120) + keep = [p for p in blk.block if g.get(id(p)) is not None and ref.get(id(p)) is not None] + va = torch.cat([g[id(p)].reshape(-1) for p in keep]); vb = torch.cat([ref[id(p)].reshape(-1) for p in keep]) + return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item() + +@torch.no_grad() +def deep_val(T1): + tot = 0.0 + for _ in range(8): + ix, yy = get_batch('val', 32, T) + xin = blk.embed(ix).detach() + z = relax(blk, xin.clone(), xin, T1, 0.1) + tot += ce(blk, z, yy).item() + return tot / 8 + +jr, rs = 1.0, None +print(f"{'step':>5} {'train':>6} {'val150':>6} {'val400':>6} {'ema150':>6} {'res':>8} {'jr':>5} " + f"{'cos':>5} {'|emb|':>5} {'|attn|':>6} {'|ffn|':>5} {'|ln|':>5} {'emaΔ':>6}", flush=True) +for step in range(1, STEPS + 1): + idx, y = get_batch('train', B, T) + grads, res = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, jr, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120, res_gate=5e-3) + flo = 0.1 + rs = res if rs is None else 0.9 * rs + 0.1 * res + jr = min(16.0, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / 1.5e-3)))) + if all((g is None) or torch.isfinite(g).all() for g in grads.values()): + opt.zero_grad(set_to_none=True) + for p in blk.allp: + p.grad = grads.get(id(p)) + torch.nn.utils.clip_grad_norm_(blk.allp, 5.0) + opt.step(); sched.step() + with torch.no_grad(): 
+ for p in blk.capw: + pn = p.norm(); cap = blk.caps[id(p)] + if pn > cap: + p.mul_(cap / pn) + for s, p in zip(pema, blk.allp): + s.mul_(0.999).add_(p.detach(), alpha=1e-3) + if step % 200 == 0: + with torch.no_grad(): + tc = ce(blk, relax(blk, blk.embed(idx).detach().clone(), blk.embed(idx).detach(), 150, 0.1), y).item() + v150, v400 = deep_val(150), deep_val(400) + raw = [p.detach().clone() for p in blk.allp] + with torch.no_grad(): + for p, s in zip(blk.allp, pema): + p.copy_(s) + ve = deep_val(150) + emad = math.sqrt(sum((r - s).pow(2).sum().item() for r, s in zip(raw, pema))) + for p, r in zip(blk.allp, raw): + p.copy_(r) + gn = {k: math.sqrt(sum(p.detach().norm().item()**2 for p in ps)) / + math.sqrt(sum(W0[id(p)]**2 for p in ps)) for k, ps in groups.items()} + cs = grad_cos() + print(f"{step:>5} {tc:>6.3f} {v150:>6.3f} {v400:>6.3f} {ve:>6.3f} {res:>8.1e} {jr:>5.1f} " + f"{cs:>5.2f} {gn['emb']:>5.2f} {gn['attn']:>6.2f} {gn['ffn']:>5.2f} {gn['ln']:>5.2f} {emad:>6.2f}", flush=True) diff --git a/ep_run/eig_control.py b/ep_run/eig_control.py new file mode 100644 index 0000000..6e3f598 --- /dev/null +++ b/ep_run/eig_control.py @@ -0,0 +1,50 @@ +"""#2 — leading-abscissa control for the ept non-conservative operator (ports the aep-dynamics +'control the LEADING spectral signal, surgically' finding to C512). + +Why not jacreg: jacreg penalizes ||J_nc||_F^2 (Hutchinson) — the WHOLE Jacobian norm. That is blunt: +it over-constrains directions that never threaten stability, and when the controller ramps it high it +HIJACKS the task gradient (the known jr-hijack failure). The aep leading-vs-lagging result says the +right knob is the leading SPECTRAL ABSCISSA, not the norm. + +What we control: the numerical abscissa omega(J_nc) = lambda_max( (J_nc + J_nc^T)/2 ) = the 2-norm +log-norm mu_2(J_nc) = one-sided Lipschitz constant. 
It upper-bounds the spectral abscissa, governs the +transient growth ||e^{Jt}||, and its crossing past (1+c) IS the free-phase Hopf (J_F = J_nc - (1+c)I). +Power iteration on the SYMMETRIC PART -> matvec-only (jvp+vjp of nc_force, the same primitives jacreg +and the AEP correction already call), so it scales and is analog-compatible (no eigendecomposition). +One-sided ReLU penalty = a LEADING signal: acts only as the abscissa nears the margin, so unlike jacreg +it does not over-contract / hijack when the operator is already safe. +""" +import torch +from torch.autograd.functional import jvp, vjp + + +def num_abscissa(blk, zs, cache, iters=3): + """Power-iterate Sym(J_nc)=(J_nc+J_nc^T)/2 at zs for the leading eigenpair. Returns (v_detached, lambda_float). + lambda = v^T Sym(J_nc) v = v^T J_nc v (Rayleigh quotient at the leading eigenvector) = numerical abscissa.""" + z = zs.detach() + v = cache.get('v') + if v is None or v.shape != z.shape or v.dtype != z.dtype or v.device != z.device: + v = torch.randn_like(z) + v = v / (v.norm() + 1e-12) + with torch.no_grad(): + for _ in range(iters): + Sv = 0.5 * (jvp(blk.nc_force, z, v)[1] + vjp(blk.nc_force, z, v)[1]) # Sym(J_nc) v + v = Sv / (Sv.norm() + 1e-12) + lam = float((v * jvp(blk.nc_force, z, v)[1]).sum() / (v * v).sum()) # v^T J_nc v + cache['v'] = v + return v, lam + + +def eig_penalty(blk, zs, eigreg, margin, cache, iters=3): + """Grads of the one-sided leading-abscissa penalty R = eigreg * relu(omega(J_nc) - margin)^2. + Returns ({id(p): grad}, omega) — omega logged as the controller signal. 
Empty grads when below margin.""" + v, lam0 = num_abscissa(blk, zs, cache, iters) + if lam0 <= margin: # below the stability margin: leading signal off + return {}, lam0 + z = zs.detach() + with torch.enable_grad(): + Jv = jvp(blk.nc_force, z, v, create_graph=True)[1] # differentiable in theta (nc_force params) + lam = (v * Jv).sum() / (v * v).sum() # numerical abscissa, v fixed + R = eigreg * torch.relu(lam - margin) ** 2 + gr = torch.autograd.grad(R, blk.block, allow_unused=True) + return {id(p): g for p, g in zip(blk.block, gr) if g is not None}, lam0 diff --git a/ep_run/eig_jacreg.py b/ep_run/eig_jacreg.py new file mode 100644 index 0000000..14aaedf --- /dev/null +++ b/ep_run/eig_jacreg.py @@ -0,0 +1,38 @@ +import torch, pickle, numpy as np +from pathlib import Path +from scipy.sparse.linalg import LinearOperator, eigs +import lt_ep_train as L +from lt_ep_train import EQBlock, relax +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; B=2; T=256; eps=0.1 +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +def load_blk(path): + ck=torch.load(path,map_location=dev) + blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True + with torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + return blk, ck.get('best','?') +@torch.no_grad() +def lead_eigs(blk, relax_steps=250, k=6, hrel=1e-3): + xin=blk.embed(idx).detach(); z=relax(blk,xin.clone(),xin,relax_steps,eps).detach() + F0=blk.force(z,xin).detach(); zn=z.norm().item(); g=F0.norm().item()/(zn+1e-9); h=hrel*zn; shp=z.shape; N=z.numel() + def Jv(vt): + nv=vt.norm().item() + if nv<1e-20: return torch.zeros_like(vt) + return (blk.force(z+h*(vt/nv),xin).detach()-F0)/h*nv + def matvec(v): + vt=torch.from_numpy(np.ascontiguousarray(v).astype('float32')).reshape(shp).to(dev) + return (vt+eps*Jv(vt)).double().cpu().numpy().reshape(-1) + 
op=LinearOperator((N,N),matvec=matvec,dtype='float64') + vals=eigs(op,k=k,which='LM',return_eigenvectors=False,maxiter=4000,tol=1e-5) + return g, sorted(vals,key=lambda x:-abs(x)) +for tag,path in [("ep_jacreg ~2.75 (ADAPTIVE jacreg)","runs/ep_jacreg.pt"), + ("redx s3200 2.74 (FROZEN jacreg, BLEW)","runs/redx_traj/s3200.pt")]: + try: + blk,best=load_blk(path); g,vals=lead_eigs(blk) + print(f"=== {tag} best={best} g_floor={g:.4f} ===") + for lam in vals[:4]: + mu=(lam-1)/eps + print(f" |lam|={abs(lam):.5f} mu={mu.real:+.4f}{mu.imag:+.4f}j [{'UNSTABLE' if abs(lam)>1+1e-4 else 'STABLE'}{' rot' if abs(lam.imag)>1e-3 else ' real'}{' ReMu<0' if mu.real<-1e-4 else ' ReMu>=0'}]") + except Exception as e: print(f"=== {tag}: ERR {e} ===") +print("=== DONE === ep_jacreg ReMu<0 => jacreg pushed it STABLE where frozen-jacreg redx is ReMu>0 (mechanism confirmed)") diff --git a/ep_run/eig_probe.py b/ep_run/eig_probe.py new file mode 100644 index 0000000..a87864b --- /dev/null +++ b/ep_run/eig_probe.py @@ -0,0 +1,43 @@ +import torch, pickle, numpy as np +from pathlib import Path +from scipy.sparse.linalg import LinearOperator, eigs +import lt_ep_train as L +from lt_ep_train import EQBlock, relax +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; B=2; T=256; eps=0.1 +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +def load_blk(path): + ck=torch.load(path,map_location=dev) + blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True + with torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + return blk, ck.get('best','?'), ck.get('step','?') +@torch.no_grad() +def lead_eigs(blk, relax_steps=250, k=8, hrel=1e-3): + xin=blk.embed(idx).detach(); z=relax(blk,xin.clone(),xin,relax_steps,eps).detach() + F0=blk.force(z,xin).detach(); zn=z.norm().item(); gfloor=F0.norm().item()/(zn+1e-9) + shp=z.shape; N=z.numel(); h=hrel*zn + def Jv(vt): 
# finite-difference JVP: J@vt + nv=vt.norm().item() + if nv<1e-20: return torch.zeros_like(vt) + pert=h*(vt/nv) + return (blk.force(z+pert,xin).detach()-F0)/h*nv + # sanity: random-direction ||Jv|| + vr=torch.randn(shp,device=dev); jr=Jv(vr); sane=jr.norm().item()/(vr.norm().item()+1e-9) + def matvec(v): + vt=torch.from_numpy(np.ascontiguousarray(v).astype('float32')).reshape(shp).to(dev) + return (vt+eps*Jv(vt)).double().cpu().numpy().reshape(-1) + op=LinearOperator((N,N),matvec=matvec,dtype='float64') + vals=eigs(op,k=k,which='LM',return_eigenvectors=False,maxiter=4000,tol=1e-5) + return gfloor, sane, sorted(vals,key=lambda x:-abs(x)) +for tag,path in [("s2000 (healthy)","runs/redx_traj/s2000.pt"), + ("s3200 (blew@2.74)","runs/redx_traj/s3200.pt"), + ("ep_eps05 (blew@2.41)","runs/ep_eps05.pt")]: + blk,best,step=load_blk(path) + g,sane,vals=lead_eigs(blk) + print(f"=== {tag} best={best} step={step} g_floor={g:.4f} ||Jv_rand||/||v||={sane:.3f} ===") + for lam in vals[:5]: + mu=(lam-1)/eps + fl=("UNSTABLE" if abs(lam)>1.0+1e-4 else "stable")+(" CMPLX" if abs(lam.imag)>1e-3 else " real")+(" ReMu<0=Euler-artifact" if mu.real<-1e-4 else " ReMu>=0=TRUE-instab") + print(f" |lam|={abs(lam):.5f} lam={lam.real:+.4f}{lam.imag:+.4f}j mu={mu.real:+.4f}{mu.imag:+.4f}j [{fl}]") +print("=== DONE ===") diff --git a/ep_run/ep_ajr_check.py b/ep_run/ep_ajr_check.py new file mode 100644 index 0000000..5d8143e --- /dev/null +++ b/ep_run/ep_ajr_check.py @@ -0,0 +1,22 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_rr_ajr.log"; TARGET=2.70 +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_rr_ajr.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + ms=re.search(r"step\s+(\d+)",l); mv=re.search(r"val CE ([\d.eE+-]+)",l); mb=re.search(r"best ([\d.eE+-]+)",l) + mr=re.search(r"res=([\d.eE+-]+)",l); 
mj=re.search(r"jr=([\d.eE+-]+)",l) + if not (ms and mv and mb and mr): return None + return (int(ms.group(1)),float(mv.group(1)),float(mb.group(1)),float(mr.group(1)), mj.group(1) if mj else "?") +fired=None; t0=time.time() +while fired is None and time.time()-t0<18*3600: + time.sleep(120) + d=latest() + if not alive(): fired=f"ep_rr_ajr EXITED; last {d}"; break + if not d: continue + step,val,best,res,jr=d + if (val>12 or res>0.3) and step>200: fired=f"ep_rr_ajr BLEW @step{step} val{val:.2f} res{res:.1e} best{best:.4f} jr={jr}"; break + if best<=TARGET: fired=f"ep_rr_ajr reached {TARGET}: step{step} best{best:.4f} res{res:.1e} jr={jr} -> time to run eig-probe and check whether g is ~0"; break +print(f"=== EP_RR_AJR -> {TARGET} ==="); print(fired or "18h timeout"); print("last:",latest()) diff --git a/ep_run/ep_c3_watch.py b/ep_run/ep_c3_watch.py new file mode 100644 index 0000000..a81fd44 --- /dev/null +++ b/ep_run/ep_c3_watch.py @@ -0,0 +1,19 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_c3_scratch.log" +def alive(): return subprocess.run(["pgrep","-f","ep_c3_scratch.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + m=re.search(r"step\s+(\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)", ls[-1]) + return (int(m.group(1)),float(m.group(2)),float(m.group(3))) if m else None +fired=None; t0=time.time() +while fired is None and time.time()-t0<24*3600: + time.sleep(180) + d=latest() + if not alive(): fired=f"ep_c3 exited; last {d}"; break + if not d: continue + step,val,res=d + if res>0.2 or val>15: fired=f"ep_c3(c=3) DIVERGED step {step} val {val:.2f} res {res:.2e} -> trained-damping did NOT prevent it"; break + if val<2.5: fired=f"ep_c3(c=3) reached val {val:.4f} step {step} -> past redx's ~2.7 blow zone, damping fix HOLDING"; break +print("=== EP_C3 WATCHER ==="); print(fired or "24h timeout"); print("last:", latest()) diff --git 
a/ep_run/ep_c_check.py b/ep_run/ep_c_check.py new file mode 100644 index 0000000..90892d9 --- /dev/null +++ b/ep_run/ep_c_check.py @@ -0,0 +1,22 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_resreg_c.log"; TARGET=2.30 +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_resreg_c.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + ms=re.search(r"step\s+(\d+)",l); mv=re.search(r"val CE ([\d.eE+-]+)",l); mb=re.search(r"best ([\d.eE+-]+)",l) + mr=re.search(r"res=([\d.eE+-]+)",l); mi=re.search(r"([\d.]+) it/s",l) + if not (ms and mv and mb and mr): return None + return (int(ms.group(1)),float(mv.group(1)),float(mb.group(1)),float(mr.group(1)), mi.group(1) if mi else "?") +fired=None; t0=time.time() +while fired is None and time.time()-t0<12*3600: + time.sleep(90) + d=latest() + if not alive(): fired=f"ep_resreg_c EXITED; last {d}"; break + if not d: continue + step,val,best,res,its=d + if val>12 or res>0.3: fired=f"ep_resreg_c BLEW @step{step} val{val:.2f} res{res:.1e} best{best:.4f} ({its} it/s)"; break + if best<=TARGET: fired=f"ep_resreg_c reached {TARGET}: step{step} best{best:.4f} val{val:.4f} res{res:.1e} -> {its} it/s (vs eager 0.07) = compile confirmed safe and faster"; break +print(f"=== EP_RESREG_C -> {TARGET} (+it/s) ==="); print(fired or "12h timeout"); print("last:",latest()) diff --git a/ep_run/ep_eps05_grid.py b/ep_run/ep_eps05_grid.py new file mode 100644 index 0000000..56ea9e9 --- /dev/null +++ b/ep_run/ep_eps05_grid.py @@ -0,0 +1,29 @@ +import time, os, re, subprocess, math +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_eps05.log" +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_eps05.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + 
ms=re.search(r"step\s+(\d+)", l); mv=re.search(r"val CE ([\d.eE+-]+)", l) + mb=re.search(r"best ([\d.eE+-]+)", l); mr=re.search(r"res=([\d.eE+-]+)", l) + if not (ms and mv and mb and mr): return None + return (int(ms.group(1)), float(mv.group(1)), float(mb.group(1)), float(mr.group(1))) +d0=latest(); B0 = d0[2] if d0 else 2.60 +nb = (math.ceil(B0*10)-1)/10.0 +CLIMB=1.5e-1 +fired=None; best_seen=99.0; best_step=0; hi_res=0.0; t0=time.time() +while fired is None and time.time()-t0 < 12*3600: + time.sleep(60) + d=latest() + if not alive(): fired=f"ep_eps05 EXITED; last {d}"; break + if not d: continue + step,val,best,res = d + if best<best_seen: best_seen=best; best_step=step + if res>hi_res: hi_res=res + if val>10 or res>0.3: fired=f"ep_eps05 DIVERGED/BLEW @step{step} val{val:.2f} res{res:.1e} -> BLOW POINT best={best:.4f} (redx blew 2.74; wall moved to ~{best:.2f})"; break + if res>CLIMB: fired=f"ep_eps05 res IMMINENT-BLOW res{res:.1e} peak{hi_res:.1e} @step{step} best{best:.4f} (climbing toward 0.3 abort)"; break + if best <= nb+1e-9: fired=f"ep_eps05 reached {nb:.2f}: step{step} best{best:.4f} val{val:.4f} res{res:.1e} peak_res{hi_res:.2f}"; break + if step-best_step >= 1800 and best>nb: fired=f"ep_eps05 STALLED above {nb:.2f}: best{best:.4f} no-improve {step-best_step} steps res{res:.1e}"; break +print(f"=== EP_EPS05 GRID (target {nb:.2f}, climb>{CLIMB}) ==="); print(fired or "12h timeout"); print("last:", latest(), "peak_res", round(hi_res,4)) diff --git a/ep_run/ep_eps05_track.py b/ep_run/ep_eps05_track.py new file mode 100644 index 0000000..b888e24 --- /dev/null +++ b/ep_run/ep_eps05_track.py @@ -0,0 +1,24 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_eps05.log" +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_eps05.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + 
ms=re.search(r"step\s+(\d+)", l); mv=re.search(r"val CE ([\d.eE+-]+)", l) + mb=re.search(r"best ([\d.eE+-]+)", l); mr=re.search(r"res=([\d.eE+-]+)", l) + if not (ms and mv and mb and mr): return None + return (int(ms.group(1)), float(mv.group(1)), float(mb.group(1)), float(mr.group(1))) +fired=None; best_seen=99.0; best_step=0; t0=time.time() +while fired is None and time.time()-t0 < 8*3600: + time.sleep(120) + d=latest() + if not alive(): fired=f"ep_eps05 EXITED; last {d}"; break + if not d: continue + step,val,best,res = d + if best < best_seen: best_seen=best; best_step=step + if res>0.3 or val>15: fired=f"ep_eps05 DIVERGED @step{step} val{val:.2f} res{res:.1e} -> blew BETWEEN 3.13 and 2.74; smaller eps did NOT fully fix it"; break + if best <= 2.74: fired=f"ep_eps05 CROSSED 2.74 (redx blow point): step{step} best{best:.4f} val{val:.4f} res{res:.1e} -> integration fix carried it PAST redx wall"; break + if step-best_step >= 1200 and best > 2.74: fired=f"ep_eps05 STALLED: best{best:.4f}, no improve {step-best_step} steps, res{res:.1e} (not blowing) -> eps=0.05 wall ~{best:.2f}?"; break +print("=== EP_EPS05 TRACK (->2.74) ==="); print(fired or "8h timeout"); print("last:", latest()) diff --git a/ep_run/ep_eps05_track2.py b/ep_run/ep_eps05_track2.py new file mode 100644 index 0000000..721b712 --- /dev/null +++ b/ep_run/ep_eps05_track2.py @@ -0,0 +1,26 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_eps05.log" +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_eps05.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + ms=re.search(r"step\s+(\d+)", l); mv=re.search(r"val CE ([\d.eE+-]+)", l) + mb=re.search(r"best ([\d.eE+-]+)", l); mr=re.search(r"res=([\d.eE+-]+)", l) + if not (ms and mv and mb and mr): return None + return (int(ms.group(1)), float(mv.group(1)), float(mb.group(1)), 
float(mr.group(1))) +fired=None; best_seen=99.0; best_step=0; hi_res=0.0; t0=time.time() +while fired is None and time.time()-t0 < 12*3600: + time.sleep(120) + d=latest() + if not alive(): fired=f"ep_eps05 EXITED; last {d}"; break + if not d: continue + step,val,best,res = d + if best < best_seen: best_seen=best; best_step=step + if res>hi_res: hi_res=res + if val>10 or res>0.3: fired=f"ep_eps05 DIVERGED @step{step} val{val:.2f} res{res:.1e} -> wall MOVED LOWER: eps=0.05 blew ~{best:.2f} (between 2.74 and 2.09)"; break + if best <= 2.09: fired=f"ep_eps05 PASSED 2.09 (early wall)! step{step} best{best:.4f} val{val:.4f} res{res:.1e} -> integration fix beats the early wall"; break + if res>4e-2: fired=f"ep_eps05 res CLIMBING: res{res:.1e} (peak {hi_res:.1e}) @step{step} best{best:.4f} -> early warning, wall may be moving lower (NOT blown yet)"; break + if step-best_step >= 1500 and best>2.09: fired=f"ep_eps05 STALLED: best{best:.4f} no-improve {step-best_step} steps res{res:.1e} -> eps=0.05 floor ~{best:.2f}"; break +print("=== EP_EPS05 TRACK2 (2.74->2.09) ==="); print(fired or "12h timeout"); print("last:", latest(), "peak_res", round(hi_res,4)) diff --git a/ep_run/ep_eps05_watch.py b/ep_run/ep_eps05_watch.py new file mode 100644 index 0000000..075e631 --- /dev/null +++ b/ep_run/ep_eps05_watch.py @@ -0,0 +1,20 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_eps05.log" +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_eps05.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + m=re.search(r"step\s+(\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)", ls[-1]) + return (int(m.group(1)),float(m.group(2)),float(m.group(3))) if m else None +fired=None; t0=time.time() +while fired is None and time.time()-t0<18*3600: + time.sleep(120) + d=latest() + if not alive(): fired=f"ep_eps05 EXITED; last {d}"; break + if 
not d: continue + step,val,res=d + if res>0.2 or val>15: fired=f"ep_eps05 DIVERGED step {step} val {val:.2f} res {res:.2e} -> smaller eps did NOT fix it (true continuous instability, not just Euler)"; break + if val<2.5: fired=f"ep_eps05 reached val {val:.4f} step {step} res {res:.2e} -> CLEARED past redx's 2.74 blow point (integration fix WORKS)"; break + if step>=3500: fired=f"ep_eps05 SURVIVED to step {step} val {val:.4f} res {res:.2e} -> past the blow zone (fix holding)"; break +print("=== EP_EPS05 WATCHER ==="); print(fired or "18h timeout"); print("last:", latest()) diff --git a/ep_run/ep_fast_check.py b/ep_run/ep_fast_check.py new file mode 100644 index 0000000..fe82955 --- /dev/null +++ b/ep_run/ep_fast_check.py @@ -0,0 +1,23 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_resreg_fast.log"; TARGET=2.30 +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_resreg_fast.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + ms=re.search(r"step\s+(\d+)",l); mv=re.search(r"val CE ([\d.eE+-]+)",l) + mb=re.search(r"best ([\d.eE+-]+)",l); mr=re.search(r"res=([\d.eE+-]+)",l); mj=re.search(r"jr=([\d.eE+-]+)",l) + if not (ms and mv and mb and mr): return None + return (int(ms.group(1)),float(mv.group(1)),float(mb.group(1)),float(mr.group(1)),float(mj.group(1)) if mj else 0.0) +fired=None; t0=time.time(); hires=0.0 +while fired is None and time.time()-t0<12*3600: + time.sleep(90) + d=latest() + if not alive(): fired=f"ep_resreg_fast EXITED; last {d}"; break + if not d: continue + step,val,best,res,jr=d + hires=max(hires,res) + if val>12 or res>0.3: fired=f"ep_resreg_fast BLEW @step{step} val{val:.2f} res{res:.1e} best{best:.4f} peakres{hires:.1e}"; break + if best<=TARGET: fired=f"ep_resreg_fast reached {TARGET}: step{step} best{best:.4f} val{val:.4f} res{res:.1e} peakres{hires:.1e} -> resreg 
broke past the jacreg-stall region"; break +print(f"=== EP_RESREG_FAST -> {TARGET} ==="); print(fired or "12h timeout"); print("last:",latest()) diff --git a/ep_run/ep_fast_timing.py b/ep_run/ep_fast_timing.py new file mode 100644 index 0000000..942f81d --- /dev/null +++ b/ep_run/ep_fast_timing.py @@ -0,0 +1,20 @@ +import time, os, re +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_resreg_fast.log" +def step(): + try: ls=[l for l in open(LOG) if l.startswith("step")]; return int(re.search(r"step\s+(\d+)",ls[-1]).group(1)) if ls else 0 + except Exception: return 0 +def vrb(): + try: l=[x for x in open(LOG) if x.startswith("step")][-1] + except Exception: return None + v=re.search(r"val CE ([\d.]+)",l); r=re.search(r"res=([\d.eE+-]+)",l); b=re.search(r"best ([\d.]+)",l) + return (float(v.group(1)) if v else 0, float(r.group(1)) if r else 0, float(b.group(1)) if b else 0) +t0=time.time() +while step()<100 and time.time()-t0<900: time.sleep(15) # wait out the compile warm-up +s1=step(); t1=time.time(); time.sleep(180); s2=step(); t2=time.time() +ds=s2-s1; dt=t2-t1 +print("=== EP_RESREG_FAST compile+TF32 timing ===") +if ds>0: + print(f"{ds} steps / {dt:.0f}s = {dt/ds:.2f} s/step = {60*ds/dt:.1f} steps/min") + print(f"reference: other eager runs ~13.6 s/step (but those used t2sel160; this one is t2sel40+compile+tf32)") + print(f"sanity (val,res,best): {vrb()} <- res not blown up and decreasing = TF32 did not break stability") +else: print("no progress (still compiling or stuck), step", step()) diff --git a/ep_run/ep_jacreg_binary.py b/ep_run/ep_jacreg_binary.py new file mode 100644 index 0000000..f5b0b1e --- /dev/null +++ b/ep_run/ep_jacreg_binary.py @@ -0,0 +1,22 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_jacreg.log"; TARGET=2.30 +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_jacreg.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + ms=re.search(r"step\s+(\d+)",l); mv=re.search(r"val CE 
([\d.]+)",l); mb=re.search(r"best ([\d.]+)",l) + mr=re.search(r"res=([\d.eE+-]+)",l); mj=re.search(r"jr=([\d.eE+-]+)",l) + if not(ms and mv and mb and mr): return None + return int(ms.group(1)),float(mv.group(1)),float(mb.group(1)),float(mr.group(1)),float(mj.group(1)) if mj else 0.0 +fired=None; t0=time.time() +while fired is None and time.time()-t0<12*3600: + time.sleep(90) + if not alive(): fired=f"EXITED last={latest()}"; break + d=latest() + if not d: continue + step,val,best,res,jr=d + if res>0.3 or val>12: fired=f"ep_jacreg BLEW @step{step} val{val:.2f} res{res:.1e} jr{jr:.1f} (best had been {best:.4f})"; break + if best<=TARGET: fired=f"ep_jacreg reached {TARGET} @step{step} best {best:.4f} val{val:.3f} res{res:.1e} jr{jr:.1f}"; break +print(f"=== EP_JACREG -> {TARGET} (blow or descend) ==="); print(fired or "12h timeout still between 2.40-2.30"); print("last:",latest()) diff --git a/ep_run/ep_jacreg_grid.py b/ep_run/ep_jacreg_grid.py new file mode 100644 index 0000000..8be0c63 --- /dev/null +++ b/ep_run/ep_jacreg_grid.py @@ -0,0 +1,27 @@ +import time, os, re, subprocess, math +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_jacreg.log" +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_jacreg.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + ms=re.search(r"step\s+(\d+)",l); mv=re.search(r"val CE ([\d.eE+-]+)",l) + mb=re.search(r"best ([\d.eE+-]+)",l); mr=re.search(r"res=([\d.eE+-]+)",l); mj=re.search(r"jr=([\d.eE+-]+)",l) + if not (ms and mv and mb and mr): return None + return (int(ms.group(1)),float(mv.group(1)),float(mb.group(1)),float(mr.group(1)),float(mj.group(1)) if mj else 0.0) +d0=latest(); B0=d0[2] if d0 else 3.13 +nb=(math.ceil(B0*10)-1)/10.0 +fired=None; best_seen=99.0; best_step=0; t0=time.time() +while fired is None and time.time()-t0<12*3600: + time.sleep(90) + d=latest() + if not alive(): 
fired=f"ep_jacreg EXITED; last {d}"; break + if not d: continue + step,val,best,res,jr=d + if best<best_seen: best_seen=best; best_step=step + if val>10 or res>0.3: fired=f"ep_jacreg DIVERGED @step{step} val{val:.2f} res{res:.1e} jr={jr:.1f} best{best:.4f} -> adaptive jacreg did NOT hold"; break + if jr>=15.5 and res>5e-3: fired=f"ep_jacreg jr SATURATED @step{step} jr={jr:.1f}(max16) res={res:.1e} best{best:.4f} -> controller maxed, can't hold (jacreg insufficient/over-tax)"; break + if best<=nb+1e-9: fired=f"ep_jacreg reached {nb:.2f}: step{step} best{best:.4f} val{val:.4f} res{res:.1e} jr={jr:.2f}"; break + if step-best_step>=1500 and best>nb: fired=f"ep_jacreg STALLED above {nb:.2f}: best{best:.4f} no-improve {step-best_step} res{res:.1e} jr={jr:.2f}"; break +print(f"=== EP_JACREG GRID (target {nb:.2f}) ==="); print(fired or "12h timeout"); print("last:",latest()) diff --git a/ep_run/ep_jacreg_spike.py b/ep_run/ep_jacreg_spike.py new file mode 100644 index 0000000..a72b89c --- /dev/null +++ b/ep_run/ep_jacreg_spike.py @@ -0,0 +1,26 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_jacreg.log" +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_jacreg.pt"],capture_output=True).returncode==0 +def steps(): + out=[] + try: + for l in open(LOG): + if not l.startswith("step"): continue + ms=re.search(r"step\s+(\d+)",l); mv=re.search(r"val CE ([\d.]+)",l) + mj=re.search(r"jr=([\d.eE+-]+)",l); mr=re.search(r"res=([\d.eE+-]+)",l) + if ms and mv and mj: out.append((int(ms.group(1)),float(mv.group(1)),float(mj.group(1)),float(mr.group(1)) if mr else 0)) + except Exception: pass + return out +fired=None; t0=time.time(); hi=0; prev=None; seen=set([r[0] for r in steps()]) # ignore already-seen (incl the 6250 spike) +while fired is None and time.time()-t0<4*3600: + time.sleep(60) + if not alive(): fired=f"EXITED last={steps()[-1] if steps() else None}"; break + for r in steps(): + if r[0] in seen: continue + seen.add(r[0]); 
step,val,jr,res=r + if val>15 or res>0.3: fired=f"DIVERGED @{step} val{val:.2f} res{res:.1e} jr{jr:.1f}"; break + hi = hi+1 if jr>=8 else 0 + if jr<2.0 and val<2.55: fired=f"SUPPRESSED @{step}: jr relaxed to {jr:.1f}, CE recovered {val:.3f} (best2.4381) res{res:.1e} -> controller WON the spike"; break + if hi>=3: fired=f"jr SATURATING @{step}: jr>=8 for {hi} logged-steps (now {jr:.1f}), val{val:.3f} res{res:.1e} -> controller maxed, not relaxing (early hijack/saturation)"; break + if fired: break +print("=== EP_JACREG SPIKE-RECOVERY ==="); print(fired or "4h timeout"); print("last5:", steps()[-5:]) diff --git a/ep_run/ep_resreg_check.py b/ep_run/ep_resreg_check.py new file mode 100644 index 0000000..cad761b --- /dev/null +++ b/ep_run/ep_resreg_check.py @@ -0,0 +1,23 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_resreg_warm.log"; TARGET=2.05 +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_resreg_warm.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + ms=re.search(r"step\s+(\d+)",l); mv=re.search(r"val CE ([\d.eE+-]+)",l) + mb=re.search(r"best ([\d.eE+-]+)",l); mr=re.search(r"res=([\d.eE+-]+)",l); mj=re.search(r"jr=([\d.eE+-]+)",l) + if not (ms and mv and mb and mr): return None + return (int(ms.group(1)),float(mv.group(1)),float(mb.group(1)),float(mr.group(1)),float(mj.group(1)) if mj else 0.0) +fired=None; t0=time.time(); hires=0.0 +while fired is None and time.time()-t0<14*3600: + time.sleep(90) + d=latest() + if not alive(): fired=f"ep_resreg_warm EXITED; last {d}"; break + if not d: continue + step,val,best,res,jr=d + hires=max(hires,res) + if val>12 or res>0.3: fired=f"ep_resreg_warm BLEW @step{step} val{val:.2f} res{res:.1e} best{best:.4f} peakres{hires:.1e}"; break + if best<=TARGET: fired=f"ep_resreg_warm reached {TARGET}: step{step} best{best:.4f} val{val:.4f} 
res{res:.1e} jr{jr:.2f} peakres_sofar{hires:.1e}"; break +print(f"=== EP_RESREG_WARM -> {TARGET} ==="); print(fired or "14h timeout"); print("last:",latest()) diff --git a/ep_run/ep_resreg_grid.py b/ep_run/ep_resreg_grid.py new file mode 100644 index 0000000..fc9fb1c --- /dev/null +++ b/ep_run/ep_resreg_grid.py @@ -0,0 +1,24 @@ +import time, os, re, subprocess, math +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_resreg_warm.log" +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_resreg_warm.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + ms=re.search(r"step\s+(\d+)",l); mv=re.search(r"val CE ([\d.eE+-]+)",l) + mb=re.search(r"best ([\d.eE+-]+)",l); mr=re.search(r"res=([\d.eE+-]+)",l); mj=re.search(r"jr=([\d.eE+-]+)",l) + if not (ms and mv and mb and mr): return None + return (int(ms.group(1)),float(mv.group(1)),float(mb.group(1)),float(mr.group(1)),float(mj.group(1)) if mj else 0.0) +d0=latest(); B0=d0[2] if d0 else 3.13 +nb=(math.ceil(B0*10)-1)/10.0 +fired=None; t0=time.time() +while fired is None and time.time()-t0<14*3600: + time.sleep(90) + d=latest() + if not alive(): fired=f"ep_resreg_warm EXITED; last {d}"; break + if not d: continue + step,val,best,res,jr=d + if val>12 or res>0.3: fired=f"ep_resreg_warm BLEW @step{step} val{val:.2f} res{res:.1e} best{best:.4f}"; break + if best<=nb+1e-9: fired=f"ep_resreg_warm reached {nb:.2f}: step{step} best{best:.4f} val{val:.4f} res{res:.1e} jr{jr:.2f}"; break +print(f"=== EP_RESREG_WARM GRID (target {nb:.2f}) ==="); print(fired or "14h timeout"); print("last:",latest()) diff --git a/ep_run/ep_rr_check.py b/ep_run/ep_rr_check.py new file mode 100644 index 0000000..a820545 --- /dev/null +++ b/ep_run/ep_rr_check.py @@ -0,0 +1,22 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_rr_scratch.log"; TARGET=2.70 +def alive(): return 
subprocess.run(["pgrep","-f","ckpt runs/ep_rr_scratch.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + l=ls[-1] + ms=re.search(r"step\s+(\d+)",l); mv=re.search(r"val CE ([\d.eE+-]+)",l); mb=re.search(r"best ([\d.eE+-]+)",l) + mr=re.search(r"res=([\d.eE+-]+)",l); mi=re.search(r"([\d.]+) it/s",l) + if not (ms and mv and mb and mr): return None + return (int(ms.group(1)),float(mv.group(1)),float(mb.group(1)),float(mr.group(1)), mi.group(1) if mi else "?") +fired=None; t0=time.time() +while fired is None and time.time()-t0<23*3600: + time.sleep(120) + d=latest() + if not alive(): fired=f"ep_rr_scratch EXITED; last {d}"; break + if not d: continue + step,val,best,res,its=d + if (val>12 or res>0.3) and step>400: fired=f"ep_rr_scratch BLEW @step{step} val{val:.2f} res{res:.1e} best{best:.4f} ({its} it/s)"; break + if best<=TARGET: fired=f"ep_rr_scratch reached {TARGET} (from-scratch, clean descent into the wall region): step{step} best{best:.4f} val{val:.4f} res{res:.1e} {its}it/s"; break +print(f"=== EP_RR_SCRATCH -> {TARGET} ==="); print(fired or "23h timeout"); print("last:",latest()) diff --git a/ep_run/ep_sn_monitor.py b/ep_run/ep_sn_monitor.py new file mode 100644 index 0000000..523fd6d --- /dev/null +++ b/ep_run/ep_sn_monitor.py @@ -0,0 +1,43 @@ +"""Combined monitor for ep_sn (hr=0.2 + specnorm 0.9): watch loss + probe cos per ckpt. 
+Fire on: diverge / val<2.0 (cleared wall) / cos<0.82 (specnorm not holding gradient) / exit.""" +import time, os, re, subprocess, shutil +WD = "/home/yurenh2/ept/ep_run"; os.chdir(WD) +LOG, CK, FROZEN, COSLOG, PID = "runs/ep_sn.log", "runs/ep_sn.pt", "runs/ep_sn_cosprobe.pt", "runs/cos_monitor_sn.log", 2428946 +BLOG = "runs/bptt_clean.log" +def alive(p): + try: os.kill(p, 0); return True + except Exception: return False +def latest(log): + try: ls = [l for l in open(log) if l.startswith("step")] + except Exception: return None + if not ls: return None + m = re.search(r"step (\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)", ls[-1]) + return (int(m.group(1)), float(m.group(2)), float(m.group(3))) if m else None +open(COSLOG, "a").write("# ep_sn monitor (hr=0.2 + specnorm 0.9)\n") +last = -1; fired = None; t0 = time.time() +while fired is None and time.time() - t0 < 18 * 3600: + time.sleep(120) + if not alive(PID): fired = f"ep_sn EXITED; last {latest(LOG)}"; break + d = latest(LOG) + if not d: continue + step, val, res = d + if res > 0.2 or val > 15: fired = f"ep_sn DIVERGED step {step} val {val:.2f} res {res:.2e} — specnorm did NOT prevent it"; break + if val < 2.0: fired = f"ep_sn reached val {val:.4f} step {step} res {res:.2e} — CLEARED the wall (hr+specnorm worked)"; break + if step >= last + 450 and os.path.exists(CK) and os.path.getsize(CK) > 1e6: + try: shutil.copy2(CK, FROZEN) + except Exception: continue + last = step + env = dict(os.environ, CUDA_VISIBLE_DEVICES="0", PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True") + cosv = "?" + try: + r = subprocess.run(["python3", "asym_probe.py", "--ckpt", FROZEN, "--B", "8"], + env=env, capture_output=True, text=True, timeout=600) + m = re.search(r"cos\(g_EP, ?g_transpose\)=([+-][0-9.]+)", r.stdout + r.stderr) + cosv = float(m.group(1)) if m else "?" 
+ except Exception as e: cosv = f"err:{e}" + line = f"step {step}: cos={cosv} val={val:.4f} res={res:.2e}" + open(COSLOG, "a").write(line + "\n"); print(line, flush=True) + if isinstance(cosv, float) and cosv < 0.82: + fired = f"ep_sn COS DEGRADED to {cosv:.3f} step {step} (res {res:.2e}) — specnorm not holding the gradient"; break +print("=== EP_SN MONITOR FIRED ==="); print("trigger:", fired or "18h timeout") +print("ep_sn:", latest(LOG), "| bptt:", latest(BLOG)) diff --git a/ep_run/ep_t2fix_watch.py b/ep_run/ep_t2fix_watch.py new file mode 100644 index 0000000..32346f1 --- /dev/null +++ b/ep_run/ep_t2fix_watch.py @@ -0,0 +1,20 @@ +import time, os, re, subprocess +os.chdir("/home/yurenh2/ept/ep_run"); LOG="runs/ep_t2fix.log" +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_t2fix.pt"],capture_output=True).returncode==0 +def latest(): + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: return None + if not ls: return None + m=re.search(r"step\s+(\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)", ls[-1]) + return (int(m.group(1)),float(m.group(2)),float(m.group(3))) if m else None +fired=None; t0=time.time() +while fired is None and time.time()-t0<18*3600: + time.sleep(120) + d=latest() + if not alive(): fired=f"ep_t2fix EXITED; last {d}"; break + if not d: continue + step,val,res=d + if res>0.2 or val>15: fired=f"ep_t2fix DIVERGED step {step} val {val:.2f} res {res:.2e} — t2sel=160 did NOT prevent it"; break + if val<2.0: fired=f"ep_t2fix reached val {val:.4f} step {step} res {res:.2e} — CLEARED the 2.0 wall (t2sel fix WORKED)"; break + if step>=3200: fired=f"ep_t2fix SURVIVED to step {step} val {val:.4f} res {res:.2e} — past the redx/ep_hr02 blowup zone (fix holding)"; break +print("=== EP_T2FIX WATCHER FIRED ==="); print("trigger:", fired or "18h timeout"); print("last:", latest()) diff --git a/ep_run/epmc.json b/ep_run/epmc.json new file mode 100644 index 0000000..9237c6e --- /dev/null +++ b/ep_run/epmc.json @@ -0,0 +1 @@ 
+{"version":"6.9","hitCount":1,"request":{"queryString":"TITLE:\"Memristor Crossbar Circuits Implementing Equilibrium Propagation for On-Device Learning\"","resultType":"core","cursorMark":"*","pageSize":25,"sort":"","synonym":false},"resultList":{"result":[{"id":"37512678","source":"MED","pmid":"37512678","pmcid":"PMC10384638","fullTextIdList":{"fullTextId":["PMC10384638"]},"doi":"10.3390/mi14071367","title":"Memristor Crossbar Circuits Implementing Equilibrium Propagation for On-Device Learning.","authorString":"Oh S, An J, Cho S, Yoon R, Min KS.","authorList":{"author":[{"fullName":"Oh S","firstName":"Seokjin","lastName":"Oh","initials":"S","authorAffiliationDetailsList":{"authorAffiliation":[{"affiliation":"School of Electrical Engineering, Kookmin University, Seoul 02707, Republic of Korea."}]}},{"fullName":"An J","firstName":"Jiyong","lastName":"An","initials":"J","authorId":{"type":"ORCID","value":"0000-0003-3793-028X"},"authorAffiliationDetailsList":{"authorAffiliation":[{"affiliation":"School of Electrical Engineering, Kookmin University, Seoul 02707, Republic of Korea."}]}},{"fullName":"Cho S","firstName":"Seungmyeong","lastName":"Cho","initials":"S","authorId":{"type":"ORCID","value":"0009-0008-4577-0128"},"authorAffiliationDetailsList":{"authorAffiliation":[{"affiliation":"School of Electrical Engineering, Kookmin University, Seoul 02707, Republic of Korea."}]}},{"fullName":"Yoon R","firstName":"Rina","lastName":"Yoon","initials":"R","authorAffiliationDetailsList":{"authorAffiliation":[{"affiliation":"School of Electrical Engineering, Kookmin University, Seoul 02707, Republic of Korea."}]}},{"fullName":"Min KS","firstName":"Kyeong-Sik","lastName":"Min","initials":"KS","authorId":{"type":"ORCID","value":"0000-0002-1518-7037"},"authorAffiliationDetailsList":{"authorAffiliation":[{"affiliation":"School of Electrical Engineering, Kookmin University, Seoul 02707, Republic of 
Korea."}]}}]},"authorIdList":{"authorId":[{"type":"ORCID","value":"0000-0002-1518-7037"},{"type":"ORCID","value":"0000-0003-3793-028X"},{"type":"ORCID","value":"0009-0008-4577-0128"}]},"journalInfo":{"issue":"7","volume":"14","journalIssueId":3607342,"dateOfPublication":"2023 Jul","monthOfPublication":7,"yearOfPublication":2023,"printPublicationDate":"2023-07-01","journal":{"title":"Micromachines","medlineAbbreviation":"Micromachines (Basel)","issn":"2072-666X","nlmid":"101640903","essn":"2072-666X","isoabbreviation":"Micromachines (Basel)"}},"pubYear":"2023","pageInfo":"1367","abstractText":"Equilibrium propagation (EP) has been proposed recently as a new neural network training algorithm based on a local learning concept, where only local information is used to calculate the weight update of the neural network. Despite the advantages of local learning, numerical iteration for solving the EP dynamic equations makes the EP algorithm less practical for realizing edge intelligence hardware. Some analog circuits have been suggested to solve the EP dynamic equations physically, not numerically, using the original EP algorithm. However, there are still a few problems in terms of circuit implementation: for example, the need for storing the free-phase solution and the lack of essential peripheral circuits for calculating and updating synaptic weights. Therefore, in this paper, a new analog circuit technique is proposed to realize the EP algorithm in practical and implementable hardware. This work has two major contributions in achieving this objective. First, the free-phase and nudge-phase solutions are calculated by the proposed analog circuits simultaneously, not at different times. With this process, analog voltage memories or digital memories with converting circuits between digital and analog domains for storing the free-phase solution temporarily can be eliminated in the proposed EP circuit. 
Second, a simple EP learning rule relying on a fixed amount of conductance change per programming pulse is newly proposed and implemented in peripheral circuits. The modified EP learning rule can make the weight update circuit practical and implementable without requiring the use of a complicated program verification scheme. The proposed memristor conductance update circuit is simulated and verified for training synaptic weights on memristor crossbars. The simulation results showed that the proposed EP circuit could be used for realizing on-device learning in edge intelligence hardware.","affiliation":"School of Electrical Engineering, Kookmin University, Seoul 02707, Republic of Korea.","publicationStatus":"epublish","language":"eng","pubModel":"Electronic","pubTypeList":{"pubType":["research-article","Journal Article"]},"grantsList":{"grant":[{"grantId":"2021R1A2C1011631","agency":"NRF","orderIn":0},{"grantId":"2021M3F3A2A01037972","agency":"NRF","orderIn":0},{"grantId":"NRF-2022R1A5A7000765","agency":"National Research Foundation of Korea","orderIn":0},{"grantId":"NRF-2021M3F3A2A01037972","agency":"National Research Foundation of Korea","orderIn":0},{"grantId":"NRF-2021R1A2C1011631","agency":"National Research Foundation of Korea","orderIn":0},{"grantId":"2022R1A5A7000765","agency":"NRF","orderIn":0}]},"keywordList":{"keyword":["Local Learning","On-device Learning","Equilibrium Propagation","Memristor Crossbar Circuits"]},"fullTextUrlList":{"fullTextUrl":[{"availability":"Open access","availabilityCode":"OA","documentStyle":"pdf","site":"Unpaywall","url":"https://www.mdpi.com/2072-666X/14/7/1367/pdf?version=1688377497"},{"availability":"Subscription required","availabilityCode":"S","documentStyle":"doi","site":"DOI","url":"https://doi.org/10.3390/mi14071367"},{"availability":"Open access","availabilityCode":"OA","documentStyle":"html","site":"Europe_PMC","url":"https://europepmc.org/articles/PMC10384638"},{"availability":"Open 
access","availabilityCode":"OA","documentStyle":"pdf","site":"Europe_PMC","url":"https://europepmc.org/articles/PMC10384638?pdf=render"}]},"isOpenAccess":"Y","inEPMC":"Y","inPMC":"Y","hasPDF":"Y","hasBook":"N","hasSuppl":"N","citedByCount":2,"hasData":"N","hasReferences":"Y","hasTextMinedTerms":"Y","hasDbCrossReferences":"N","hasLabsLinks":"N","license":"cc by","hasEvaluations":"N","authMan":"N","epmcAuthMan":"N","nihAuthMan":"N","hasTMAccessionNumbers":"N","dateOfCreation":"2023-07-29","firstIndexDate":"2023-07-30","fullTextReceivedDate":"2023-08-20","dateOfRevision":"2023-08-01","electronicPublicationDate":"2023-07-03","firstPublicationDate":"2023-07-03"}]}}
\ No newline at end of file diff --git a/ep_run/eps_sweep_s3200.py b/ep_run/eps_sweep_s3200.py new file mode 100644 index 0000000..3c26d73 --- /dev/null +++ b/ep_run/eps_sweep_s3200.py @@ -0,0 +1,29 @@ +import torch, pickle, math +from pathlib import Path +import lt_ep_train as L +from lt_ep_train import EQBlock +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; B=8; T=256; Ttot=400.0 # fixed "time" budget eps*N=Ttot so all eps cover same settling +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +ck=torch.load('runs/redx_traj/s3200.pt',map_location=dev) +def relax_eps(eps): + N=min(int(Ttot/eps), 24000) + blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True + with torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + xin=blk.embed(idx).detach(); z=xin.clone(); ress=[] + for t in range(N): + z2=z+eps*blk.force(z,xin).detach() + r=(z2-z).norm().item()/(z.norm().item()+1e-9); ress.append(r); z=z2 + if not math.isfinite(r) or r>1e3: return ('DIVERGED',t,r,r) + tail=ress[-min(800,N//4):] + return (N, ress[-1], min(tail), max(tail)) +print("=== eps-sweep on redx s3200 (FULL attention, the cycling operator): Euler-artifact vs continuous instability ===") +print("eps*N=400 held fixed (same settling time). 
Cycle dies as eps shrinks => DISCRETE-EULER ARTIFACT (continuous ODE / analog HW is fine).") +for eps in [0.1, 0.05, 0.03, 0.02, 0.01]: + r=relax_eps(eps) + if r[0]=='DIVERGED': print(f" eps={eps}: DIVERGED at t={r[1]} r={r[2]:.2e}") + else: + N,last,tmin,tmax=r; osc=tmax-tmin + print(f" eps={eps}: N={N:5d} res(last)={last:.3e} tail[min={tmin:.2e},max={tmax:.2e}] osc={osc:.2e} {'CYCLE' if (osc>5e-4 and last>2e-3) else 'CONVERGED' if last<2e-3 else 'floored'}") +print("=== DONE ===") diff --git a/ep_run/eval_relax_s3200.py b/ep_run/eval_relax_s3200.py new file mode 100644 index 0000000..49a8316 --- /dev/null +++ b/ep_run/eval_relax_s3200.py @@ -0,0 +1,23 @@ +import torch, pickle, math +from pathlib import Path +import lt_ep_train as L +from lt_ep_train import EQBlock +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; eps=0.1; B=8; T=256; N=6000 +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True +ck=torch.load('runs/redx_traj/s3200.pt',map_location=dev) +with torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + xin=blk.embed(idx).detach(); z=xin.clone(); ress=[] + for t in range(N): + z2=z+eps*blk.force(z,xin).detach() + r=(z2-z).norm().item()/(z.norm().item()+1e-9); ress.append(r); z=z2 + if not math.isfinite(r) or r>1e3: print(f"DIVERGED at t={t} r={r:.2e}"); break +print("=== eval_relax redx s3200 (marginal, val 2.74) : CONVERGE-slow (rho<1) or LIMIT-CYCLE (floor/oscillate)? 
===") +for t in [50,150,500,1000,2000,4000,5999]: + if t<len(ress): print(f" res(t={t:4d}) = {ress[t]:.3e}") +tail=ress[-1000:] if len(ress)>=1000 else ress +mono=all(tail[i]>=tail[i+1]-1e-12 for i in range(len(tail)-1)) +print(f" tail(last1000): min={min(tail):.2e} max={max(tail):.2e} last={ress[-1]:.2e} monotone_decreasing={mono}") +print(" VERDICT: res->~1e-5 monotone => SLOW CONVERGENCE (rho<1, finite-horizon budget); floored ~1e-2 + non-monotone => LIMIT CYCLE (forward non-convergence)") diff --git a/ep_run/extracted_paper.txt b/ep_run/extracted_paper.txt new file mode 100644 index 0000000..4f521d8 --- /dev/null +++ b/ep_run/extracted_paper.txt @@ -0,0 +1,2039 @@ + Photonic Exponential Approximation via Cascaded TFLN Microring Resonators + toward Softmax + Hyoseok Park1 and Yeonsang Park1, ∗ + 1 + Department of Physics, Chungnam National University, Daejeon 34134, Republic of Korea + (Dated: March 26, 2026) + The rapid growth of large-scale AI models has intensified energy consumption and data-movement + challenges in modern datacenters. Photonic accelerators offer a promising path by executing the linear + matrix multiplications of transformer inference at high throughput and low energy. However, the + softmax attention layer—which requires element-wise exponentiation followed by normalization—still + relies on electronic post-processing, creating an electro-optic conversion bottleneck that negates much + of the potential photonic advantage. +arXiv:2603.12934v3 [physics.optics] 25 Mar 2026 + + + + + We present a cascaded micro-ring resonator (MRR) architecture that synthesizes the per-channel + exponential function required by softmax, exn −max(x) , over a finite interval with tunable worst-case + relative error. 
A control signal detunes each ring via an electro-optic mechanism; a weak probe + at fixed frequency experiences Lorentzian transmission, and cascading N identical stages yields a + multiplicative transfer function whose logarithm is approximately linear. + We derive mapping rules, depth-scaling estimates, and a minimax fitting formulation, and validate + the framework with three-dimensional FDTD simulations of X-cut thin-film lithium niobate (TFLN) + add-drop micro-ring resonators. Direct multi-ring FDTD validation extends to a five-ring cascade + and confirms agreement with theory primarily over the upper operating range; deeper cascades and + higher quality factors are assessed analytically. The cascade implements the per-channel exponential + block—the key missing nonlinearity for photonic softmax. We further present a WDM-parallel + chip architecture with closed-loop PI feedback that completes the full softmax—exponentiation, + summation, and normalization—on a single photonic chip without per-channel normalization circuitry. + + + I. INTRODUCTION is approximately linear over a finite interval, enabling + exponential-function synthesis with sub-2% worst-case + Transformer inference is often limited by power and error—an order of magnitude more accurate than SOFT- + memory traffic, motivating optical accelerators that ex- ONIC’s polynomial approach—while remaining compati- + ploit parallel propagation and multiplexing [1, 2, 4, 5, 7, 9]. ble with integrated microring platforms [20–24]. We term + Recent perspective articles also discuss data-center power this cascade block an approximate exponential function + consumption as one motivation for optical comput- (AEF) unit. We further propose a WDM-parallel archi- + ing [3, 8]. 
While linear operators are comparatively tecture with a single PI feedback loop that realizes the + amenable to photonic implementation [4–6], the softmax complete softmax function—including summation and + function used in attention layers requires an exponen- normalization—without per-channel electronic process- + tial mapping together with global normalization—both ing. + difficult to realize in passive photonic circuits, where We extend the theoretical framework with three- + transmission is fundamentally bounded by unity. Parallel dimensional FDTD simulations of a single X-cut TFLN + digital-hardware studies treat the exponential/softmax add-drop micro-ring resonator. The simulated device + stage as a bottleneck and propose dedicated approxima- parameters—quality factor, free spectral range, and + tions [11–19]. Many integrated-photonic classifier demon- electro-optic sensitivity—calibrate the cascade design pa- + strations still rely on electronic post-processing for the rameters, bridging analytical fitting and physically realiz- + final nonlinear readout [10]; the resulting electro-optic able hardware. Two operating regimes emerge from this + conversion overhead can negate the throughput and en- calibration: an FDTD-characterized regime with moder- + ergy benefits of the photonic front-end. Notably, the ate drop-port depth (Dmax ≈ 0.36), where the analytic + SOFTONIC architecture [11] explicitly argues that “the error stays below ∼5% for N ≤ 7 but the power bud- + inability of MRRs and MZMs to handle SMA’s expo- get limits practical cascades to N ≤ 5; and a projected + nential and division functions” necessitates alternative high-Q regime (Dmax ≥ 0.95), enabling deeper cascades + approaches based on microdisk modulators and polyno- (N ≤ 30) with sub-percent error. Cascade performance is + mial approximation, achieving 89.7% accuracy with a predicted analytically and validated by a five-ring cascade + third-degree Chebyshev polynomial. 
Here we challenge 3D FDTD simulation (Sec. IV). + this premise: we show that a passive Lorentzian cascade The paper is organized as follows: Section II presents + of microring resonators can be tuned so that its logarithm the mapping, transfer model, and depth-design rules; Sec- + tion III provides numerical fits and validation; Section IV + describes the single-ring TFLN device design and FDTD + validation; Section V assesses physical feasibility including + ∗ yeonsang.park@cnu.ac.kr; Corresponding author + voltage requirements, insertion loss, and energy efficiency; + 2 + +Section VI discusses implementation scope, platform com- +parisons, and limits; and Section VII concludes. 1 + Tk (∆ωk ) = . (9) + ∆ωk 2 + 1+ Γ + II. MODEL AND DESIGN FRAMEWORK + In a control–probe architecture, a nonnegative control- + signal amplitude I ≥ 0 shifts the ring resonance. Here I +Target mapping. Let x = (x1 , . . . , xK ) ∈ RK be an denotes a generic control amplitude: for optical-pump op- +arbitrary real-valued sequence (or vector). Directly gener- eration it maps to optical intensity, while for EO operation +ating exp(xn ) as a passive optical transmission is impos- it maps to electrical control level (e.g., voltage). Across +sible in general because exp(x) grows beyond unity while many physical mechanisms (optical pump via Kerr/XPM, +a passive transmission satisfies 0 < T ≤ 1 [25]. However, EO drive via Pockels effect, thermal, carrier tuning), the +for softmax, shift can be linearized on a working range [20, 26–30]: + + exn (0) + softmax(x)n = P xj , (1) ω0,k (I) = ω0,k + ηI, (10) + je + (0) + where ω0,k is the cold-cavity resonance and η is the control- +a common shift cancels: to-resonance sensitivity. In practice, the control channel + can be optical or electrical (optical pump, EO/Pockels + exn +c exn drive, thermal, or carrier tuning); a quantitative EO + P x +c = P x (∀c ∈ R). (2) feasibility example is given in the Discussion. 
With + je je + j j + (0) + ∆ω0,k ≡ ωL − ω0,k , the control-dependent detuning be- +Thus it suffices to generate comes + + + exn −m , m ≡ max xj , (3) ∆ωk (I) = ∆ω0,k − ηI. (11) + j + Define dimensionless parameters +since the global factor em cancels. + To ensure a nonnegative control-signal amplitude, de- +fine ∆ω0,k η + ak ≡ , b≡− . (12) + Γ Γ + Then Eq. (9) yields the control-to-probe transfer of a +un ≡ xn − m ≤ 0, L ≡ − min un = m − min xn ≥ 0, single ring, + n n + (4) +and map each scalar to a nonnegative control-signal am- 1 +plitude Tk (I) = . (13) + 1 + (ak + bI)2 + Physical meaning: ak is a static detuning in linewidth + In ≡ un + L ∈ [0, L]. (5) units (set by heater/carrier tuning/fabrication), and |b| + is the normalized sensitivity magnitude (linewidths of +Then + resonance shift per unit control-signal amplitude); the sign + convention is absorbed into the detuning expression. For + exn −m = eun = eIn −L . (6) “same-material/same-geometry” rings, b is often common, + while ak can be tuned per ring. +Hence the optical design task is to realize, for I ∈ [0, L], Sign convention. Simultaneously flipping (ak , b) 7→ + (−ak , −b) leaves Tk (I) unchanged, so we may take b > 0 + without loss of generality. + f (I) = eI−L ∈ [e−L , 1]. (7) Let N rings be cascaded in a serial add-drop topology: + Tk (I) denotes the add-to-drop transmission of ring k, and +Control–probe transfer. Consider a weak probe at the drop output of ring k feeds the add (input bus) port +fixed angular frequency ωL . For the kth ring, let ω0,k of ring k+1. Assuming the probe is sufficiently weak so +denote its resonance frequency and Γ > 0 its loaded half- the control channel dominates the resonance shift, the +width at half maximum (HWHM). Define the detuning normalized probe output is the product + + ∆ωk ≡ ωL − ω0,k . (8) (probe) + Pout (I) + N + Y N + Y 1 + y(I) ≡ = Tk (I) = . 
+Near resonance, the normalized Lorentzian transmission + (probe) + Pin 1 + (ak + bI)2 + k=1 k=1 +is modeled as [20, 21] (14) + 3 + + + (a) Electronic Preprocessing + Control In + Find max: Shift: Bias: + {xn } m = max(xn ) un = xn −m In = un +L + + + EO tuning + (b) N -MRR Cascade + + N stages + Probe + (fixed ωL ) + + + MRR MRR MRR MRR MRR + #1 #2 #3 #4 #5 + + + + + (c) Output + + ỹ(In ) ≈ exp(In − L) → exp(xn − m) PD + + + FIG. 1: Overview of the control–probe add-drop cascade N -MRR exponential block. (a) Electronic preprocessing + maps an arbitrary input sequence {xn } to a nonnegative control signal via m = maxn xn , un = xn − m, and +In = un + L with L = m − minn xn . (b) The control signal In induces resonance shifts in a cascade of N rings, while a + weak fixed-frequency probe propagates through the serial add-drop cascade (the drop output of each ring feeds the + next stage), experiencing multiplicative transmission. (c) After photodetection, the block implements + y(In ) ≈ exp(In − L) ≈ exp(xn − m), i.e., the normalized exponential used in softmax. + + +To focus on the shape of the approximation, we allow a +global scale factor C > 0: + E∞ ≡ sup ln ỹ(I) − (I − L) . (18) + I∈[0,L] + + ỹ(I) ≡ C y(I). (15) If E∞ ≤ εlog , then for all I ∈ [0, L], +In softmax, pn = CeIn −L / j CeIj −L , so C cancels + P +between numerator and denominator and is physically ỹ(I) ỹ(I) + e−εlog ≤ ≤ eεlog ⇒ − 1 ≤ eεlog − 1. (19) +inessential; nevertheless it is convenient for error analysis. f (I) f (I) +For a fixed (N, b, {ak }), the optimal C for the minimax + Thus achieving a prescribed worst-case relative error ε is +log-error in Eq. (18) can be written in closed form. Let + guaranteed by +g(I) ≡ ln y(I) − (I − L) on [0, L]. Then the minimax- +optimal shift is ln C ⋆ = −(maxI g(I)+minI g(I))/2, yield- +ing E∞ = (maxI g(I) − minI g(I))/2. E∞ ≤ εlog ≡ ln(1 + ε) ≈ ε. (20) + Taking logarithms, + Depth scaling. 
We derive depth-related constraints and + design rules for a prescribed approximation tolerance. + N + X Necessary slope condition. Differentiate Eq. (16): + ln 1 + (ak + bI)2 . + + ln ỹ(I) = ln C − (16) + k=1 + N + d X 2b(ak + bI) +The target ln f (I) = I − L is linear; hence exponential ln y(I) = − . (21) + dI 1 + (ak + bI)2 +approximation is equivalent to the log-linearization goal k=1 + + Since |2u/(1 + u2 )| ≤ 1 for all real u, + ln ỹ(I) ≈ I − L uniformly on I ∈ [0, L]. (17) + d + ln y(I) ≤ N |b|. (22) +Error metric. Define the worst-case log-error on [0, L]: dI + 4 + +The target ln f (I) = I − L has constant slope +1, so a with a minimax refinement. After choosing N , set +necessary condition to track it is b = min(bmax , 1/N ) and a = −1 − bL/2 as initializa- + tion, then refine (a, b) by a two-parameter minimax fit on + [0, L]. + N |b| ≳ 1. (23) A heuristic conservative screening bound N ≥ ⌈(L2 /4 + +Near-optimal parameterization. The full design prob- 1/(2b2 ))/ ln(1 + ε)⌉ (derived via the same local-expansion +lem can be written as a minimax fit in the log domain [31]: argument; see Supplementary Sec. S1) provides a quick + upper estimate but is not a rigorous guarantee. + + min sup |r(I)|, + a1 ,...,aN , ln C I∈[0,L] + III. NUMERICAL FITS AND VALIDATION + N + X (24) + ln 1 + (ak + bI)2 − (I − L). + + r(I) ≡ ln C − We validate the analytical framework with minimax + k=1 numerical fits and sampled robustness checks. Figure 2 +This objective is permutation-invariant in the ak ’s (ring shows the fitted approximation quality at L = 8: the +index k). In practice (and in numerical experiments top (linear) panel plots N = 1, 3, 5, 7 over I ∈ [0, 20], the +reported below), the optimizer frequently collapses to a middle (log) panel compares N = 5, 10, 20, 30 on I ∈ [0, 8], +permutation-symmetric solution and the bottom panel shows the pointwise relative error + with the characteristic Chebyshev equioscillation pattern. + We fit identical-detuning cascades (Eq. 
25) on I ∈ [0, L] + a1 = · · · = aN ≡ a, (25) and compare several depths using a minimax criterion. + Table I makes the accuracy–depth trade-off explicit +reducing the design to two parameters (a, b) (plus C). at L = 8. A worked input-to-output example demon- +With Eq. (25), strating the mapping from an arbitrary input sequence + x = [−3.2, 1.2, 4.8, −0.9] through the cascade is provided + + 1 + N in Supplementary Sec. S2. The example shows that the + ỹ(I) = C y(I) = C . (26) N = 10 cascade keeps the worst-case relative error below + 1 + (a + bI)2 2.7% across all channels. +A robust initialization is obtained by placing the midpoint Empirical calibration. We calibrate the effective +of the interval on the Lorentzian half-maximum flank and logit range Leff from autoregressive Transformers (dis- +matching the slope: tilgpt2/gpt2) [1, 32–35] at context length 128, finding + Leff,0.999 ≈ 7–9 at the 50th–90th percentiles (Supplemen- + tary Sec. S2). A clipping threshold t∗ = −12 preserves + L p99 softmax accuracy below 0.1%. Full protocol details, + a+b ≈ −1, N b ≈ 1. (27) + 2 clipping-sweep tables/plots, and per-run statistics are +These two equations already yield a good design; a small provided in Supplementary Sec. S3. +(two-parameter) refinement can then enforce the desired A synthetic design-space map (Supplementary Table S3) +worst-case tolerance. shows that near L ≈ 8, moderate depth (N ≥ 10) reaches + Local expansion and depth scaling. A Taylor few-percent error, whereas L ≳ 12 requires deeper cas- +expansion of the log-domain residual around the flank- cades. All fits follow the same pipeline: minimize the +centered point I0 = L/2 (with a + bI0 = −1 and N b = 1) worst-case log-error on a uniform grid, initialize from the +shows that the quadratic term vanishes identically, leaving flank rules in Eq. (27), perform multi-start global search, +a leading cubic residual r(δ) ∼ δ 3 /(6N 2 ). 
Over I ∈ [0, L], and apply bounded local refinement; implementation de- +this implies E∞ ∼ L3 /N 2 , so that achieving a prescribed tails and scripts are provided in a public repository [36] + √ (commit: 585e695). +tolerance εlog requires N ∝ L3/2 / εlog , which explains +the scaling used in Eq. (28). The full derivation is provided +in Supplementary Sec. S0; an intuitive local-expansion +summary appears in Sec. S1. + Practical engineering estimate. Given L and a TABLE I: Depth comparison for L = 8 using fitted +target worst-case relative error ε, define εlog = ln(1 + ε). ỹ(I) = C[1 + (a + bI)2 ]−N (same fitting pipeline for all +A heuristic engineering estimate (not a rigorous bound) N ). +that matched our percent-level numerical designs is + N a b max rel. err. mean rel. err. + L3/2 + + 1 + N ≈ max , κ√ , (28) 5 −2.0789 0.21658 10.9% 6.43% + bmax εlog 10 −1.4588 0.10202 2.68% 1.65% + 20 −1.2135 0.05025 0.67% 0.42% +where bmax is the physically achievable sensitivity bound 30 −1.1392 0.03341 0.30% 0.19% +and κ ≃ 0.07 for the identical-detuning flank design + 5 + + TABLE II: Waveguide and ring parameters of the X-cut + TFLN micro-ring resonator. Electro-optic electrode + parameters are listed separately in Table III. + + Parameter Symbol Value Unit + Total TFLN thickness tTFLN 600 nm + Etch depth tetch 500 nm + Slab thickness tslab 100 nm + Waveguide width w 1.4 µm + Bend radius R 20 µm + Coupling gap g 100 nm + Circumference Lring 125.7 µm + Free spectral range FSR 8.29 nm + Effective index (TE0 ) neff 1.903 — + Group index (TE0 ) ng 2.24 — + Extraordinary index ne 2.138 — + + + + IV. TFLN SINGLE-RING DEVICE DESIGN AND + FDTD VALIDATION + + A. Waveguide and ring geometry + + + The device is based on an X-cut thin-film lithium nio- + bate (LiNbO3 ) on insulator wafer with a 600 nm-thick + LiNbO3 film on SiO2 . A 500 nm-deep rib etch defines + a 1.4 µm-wide single-mode waveguide with a 100 nm un- + etched slab (Fig. 3). 
Lumerical MODE simulations yield + neff = 1.903 and ng = 2.24 at λ = 1550 nm for the funda- + mental TE0 mode. + The ring resonator (R = 20 µm, Lring = 125.7 µm) is + configured as an add-drop resonator with 100 nm coupling + gaps (Fig. 4). The FDTD-measured free spectral range + is FSR = 8.29 nm (ng ≈ 2.30), slightly above the MODE + value due to bend-induced dispersion. + + + + +FIG. 2: Minimax cascade fits at L = 8. (a) Linear scale: + shallow cascades (N = 1, 3, 5, 7) over I ∈ [0, 20]. The +target eI−L (black) is progressively better matched as N + increases. (b) Log scale: depth comparison + (N = 5, 10, 20, 30) on I ∈ [0, 8]. Inset zooms into + I ∈ [6, 8] showing convergence. (c) Pointwise relative + error showing the Chebyshev equioscillation pattern + characteristic of minimax optimality. + FIG. 3: Cross-section of the X-cut TFLN rib waveguide + on a SiO2 substrate. The 600 nm LiNbO3 film is etched + 500 nm to form a 1.4 µm-wide single-mode rib waveguide. + Lateral signal (S) and ground (G) electrode positions are + indicated; electrode design details are discussed in + Sec. IV D. + 6 + + Table II summarizes the waveguide and ring parame- +ters. + + + B. 3D FDTD Methodology + + The ring resonator response is simulated using Lumeri- +cal 3D FDTD with conformal variant 1 meshing. A broad- +band TE0 mode source (1530 nm to 1570 nm) is injected +into the input bus waveguide, and through- and drop-port +spectra are recorded. A “z-refined 3-fix” meshing strat- +egy ensures convergence in the thin-film geometry [37]; +detailed simulation setup is provided in Supplementary +Sec. S4 (Table S6). + + + FIG. 5: Simulated through-port (blue) and drop-port + (red) transmission spectra of the single add-drop + micro-ring resonator from 3D FDTD. Top: logarithmic + scale; bottom: linear scale. Five resonances are visible + with FSR ≈ 8.29 nm. + + + + 15,500, Dmax = 0.360); using the five-resonance mean + would increase required voltages by ∼24% (see Table IV + caption). 
+The simulation time of 50 ps exceeds the loaded photon lifetime $\tau_L = Q_L \lambda_0/(2\pi c) \approx 12.7$ ps by ∼4×, but the intrinsic lifetime τi ≈ 32 ps is comparable, so the extracted Qi may be slightly conservative. An independent eigenmode (FDE) analysis of the same cross-section at R = 20 µm, using a 300 × 300 mesh (∆y ≈ 10 nm, 5× finer than the FDTD vertical grid), yields $Q_{\mathrm{rad+leak}} = 2.4 \times 10^{7}$; including bulk LiNbO3 absorption (Γ = 0.89) gives a theoretical $Q_i > 10^{7}$ [37–42], confirming that the gap between the numerical Qi and published values ($> 10^{6}$) originates from mesh discretization (Supplementary S4.5, Table S8). In the CMT framework, $D_{\max} = [2\kappa/(2\kappa+\gamma)]^2$ increases as Qi rises; at the present coupling gap, increasing Qi to $10^{6}$ would raise Dmax from 0.36 to ∼0.95 and QL from 15,500 to ∼25,200.
+
+Figure 6(a) shows a Lorentzian fit to the best drop-port resonance at λ = 1566 nm, validating the cascade model (Eq. 9). Figure 6(b) demonstrates that cascading N copies of this FDTD-extracted Lorentzian reproduces the target exponential $e^{I-L}$ with increasing fidelity as N grows.
+
+FIG. 4: Top view of the single add-drop micro-ring resonator used in the 3D FDTD simulation. The ring waveguide (R = 20 µm, w = 1.4 µm) is evanescently coupled to input and drop bus waveguides through 100 nm gaps at coupling points CP1 and CP2.
+
+C. Single-Ring Add-Drop Results
+
+Figure 5 shows the through- and drop-port spectra from 3D FDTD. Five resonances are resolved across 1530 nm to 1570 nm with FSR = 8.29 nm (ng ≈ 2.30).
+
+Lorentzian fitting of the drop-port peaks yields QL = 10,300–15,500, with the best resonance at λ = 1566 nm reaching QL = 15,500 (FWHM = 101 pm, Dmax = 0.360, −4.4 dB).
+
+To validate the cascade prediction directly, a five-
+ring cascade 3D FDTD simulation was performed using Tidy3D [43]; the full simulation notebook is publicly available [43]. The $|E|^2$ field at λ = 1549 nm [Fig. 6(d)] confirms resonant excitation across all five rings. Mapping the drop-port spectrum onto the control variable I yields 11 data points within the AEF operating range [Fig. 6(e, f)], with the FDTD transmission closely tracking the N = 5 theoretical curve near I ≈ L = 8.
+
+The through-port extinction ratio is 1.6 dB to 2.6 dB, and the five-resonance mean is QL = 12,500 ± 1,800 (Dmax = 0.29–0.36). CMT analysis of the best resonance gives $Q_i = Q_L/(1 - \sqrt{D_{\max}}) = 15{,}500/0.400 \approx 38{,}800$, confirming that the 500 nm etch provides sufficient confinement and that the 100 nm gap places the ring in the coupling-limited regime. The cascade analysis below adopts the best-case FDTD calibration (QL =
+
+FIG. 6: FDTD-based AEF validation. (a) Lorentzian fit to the drop-port resonance at λ = 1566 nm from 3D FDTD (Lumerical) (QL = 15,500, Dmax = 0.360, bV = 0.180 V−1). (b) Five-ring cascade drop-port spectrum near λ0 ≈ 1550 nm with Lorentzian^5 fit (red curve), confirming the expected $T^5$ line shape. (c) Five-ring cascade MRR layout with diagonal zigzag bus waveguides. (d) $|E|^2$ field profile at λ = 1549 nm from a five-ring cascade 3D FDTD simulation (Tidy3D [43]). (e, f) AEF validation of the five-ring cascade on log (e) and linear (f) scales with 11 spectral FDTD data points.
+
+D. X-cut electrode design and EO parameters
+
+We employ lateral signal–ground (S–G) arc electrodes on the slab surface alongside the ring waveguide (Fig. 7).
+
+TABLE III: Electro-optic electrode parameters for the X-cut TFLN micro-ring with lateral S–G arc electrodes.
+  Parameter                      Symbol   Value          Unit
+  Crystal orientation            —        X-cut          —
+  EO coefficient                 r33      30.9           pm V−1
+  EO fill factor                 fEO      1/π ≈ 0.318    —
+  EO overlap factor              ΓEO      0.7            —
+  Electrode gap                  gel      5              µm
+  Effective electrode distance   deff     2.5            µm
+
+In the X-cut orientation, the crystal Z-axis is at 45° from the horizontal in the substrate plane, giving a lateral-field projection proportional to cos(θ − 45°) at azimuthal angle θ. The cos(θ − 45°) = 0 boundaries at θ = 135° and 315° naturally separate the coupling regions from the electrode regions. Each ring carries a full semicircular arc electrode on the side opposite to its coupling points, engaging the large r33 = 30.9 pm V−1 Pockels coefficient [37, 38]. The effective EO fill factor follows from integrating |cos(θ − 45°)| over the semicircle:
+
+$$f_{\mathrm{EO}} = \frac{1}{\pi} \approx 0.318 \qquad (29)$$
+
+(see Supplementary Sec. S4 for derivation). The electrode gap is gel = 5 µm (deff ≈ 2.5 µm), and the electro-optic overlap integral is ΓEO = 0.7. Table III lists the electrode parameters.
+
+ized voltage sensitivity is (Supplementary Sec. S4; here dλ/dV = 28.5 pm/V is the straight-section value and fEO accounts for partial electrode coverage of the ring circumference):
+
+$$b_V = f_{\mathrm{EO}}\,\frac{2Q\,(d\lambda/dV)}{\lambda_0} \approx 0.182\ \mathrm{V^{-1}} \qquad (30)$$
+
+at QL = 15,500. This estimate relies on a first-order electrostatic model (deff ≈ 2.5 µm, ΓEO = 0.7); a ±30% variation in bV would shift the cascade depth by one to two rings at constant εmax (Table IV), leaving the qualitative design conclusions unchanged. With the cascade framework of Sec. II (Eqs. 14–18), the N-ring drop-port transmission $\tilde{y}(I) = C[1 + (a + bI)^2]^{-N}$ approximates $e^{I-L}$ over I ∈ [0, L], with (a, b) optimized by minimax fitting for each N.
+
+Table IV presents the optimization results for the standard dynamic range L = 8 ($e^{8} \approx 2981$, 34.7 dB).
+
+TABLE IV: Cascade optimization results for L = 8. The bias voltage Vbias = |a|/bV sets the DC offset, and Vctrl = bL/bV is the maximum control voltage at I = L.
+Voltages computed with bV = 0.182 V−1 (X-cut arc electrode, FDTD-calibrated best resonance QL = 15,500, ng = 2.30). The mean FDTD quality factor across five resonances is QL = 12,500 ± 1,800; using the mean would increase voltages by ∼24%.
+
+  N    a          b          E∞       εmax (%)   Vbias (V)   Vctrl (V)
+  5    −2.0789    0.21658    0.1035   10.91      11.4        9.5
+  10   −1.4588    0.10202    0.0265   2.68       8.0         4.5
+  12   −1.3731    0.08450    0.0184   1.86       7.5         3.7
+  20   −1.2136    0.05025    0.0067   0.67       6.7         2.2
+  25   −1.1685    0.04013    0.0043   0.43       6.4         1.8
+  30   −1.141     0.03340    0.0030   0.30       6.3         1.5
+  32   −1.1301    0.03131    0.0026   0.26       6.2         1.4
+
+  a The complete cascade optimization results for all N values are listed in Supplementary Table S7.
+
+The approximation quality across different cascade depths is shown in Fig. 2 (Sec. III). Key thresholds (e.g., ε < 2% at N ≥ 12, ε < 1% at N ≥ 17) and the complete optimization results are listed in Supplementary Sec. S4.
+
+FIG. 7: Illustrative two-ring cascade layout showing the lateral S–G arc electrode placement on X-cut TFLN (the cascade design extends to N rings; this two-ring example clarifies the electrode geometry). The crystal Z-axis is oriented at 45° from the horizontal in the substrate plane. The cos(θ − 45°) = 0 boundaries at θ = 135° and 315° naturally separate the bus-waveguide coupling regions from the electrode semicircles: each ring carries a full semicircular arc electrode on the side opposite to its coupling points. The resulting effective EO fill factor is fEO = 1/π ≈ 0.318.
+
+E. FDTD-Calibrated bV and Cascade Optimization
+
+From the device parameters in Tables II and III and the FDTD-calibrated ng ≈ 2.30, the effective normal-
+
+V. PHYSICAL FEASIBILITY
+
+Having established the cascade approximation theory (Sec. II) and the FDTD-calibrated device parameters (Sec.
+
+TABLE V: Two-regime power budget for the MRR cascade. Pout assumes per-channel input Pin,ch = 100 µW (from a shared Pin,tot = 1 mW CW laser split across M = 10 parallel channels via a 1×M
+splitter, or equivalently multiplexed as d WDM channels sharing a single cascade) and accounts only for the ideal on-resonance cascade transmission $D_{\max}^{N}$ (upper bound); additional inter-ring coupling loss (ηcoupling ≈ 0.9 per stage, ∼0.46 dB/stage) and off-resonance propagation loss (0.08–0.25 dB/stage) are analyzed separately in Sec. V C.
+
+  Regime         Dmax   N    Dmax^N         (dB)    Pout       εmax
+  I (FDTD)       0.36   3    0.0467         −13.3   4.67 µW    ∼15%
+  I (FDTD)       0.36   5    0.00605        −22.2   0.61 µW    10.9%
+  I (FDTD)       0.36   7    7.84 × 10^−4   −31.1   78 nW      ∼5%
+  II (high-Q)    0.95   10   0.599          −2.2    59.9 µW    2.68%
+  II (high-Q)    0.95   20   0.358          −4.5    35.8 µW    0.67%
+  II (high-Q)    0.95   30   0.215          −6.7    21.5 µW    ∼0.30%
+
+  Regime I: FDTD-characterized (Qi = 38,800). Regime II: fabricated high-Q ($Q_i > 10^{6}$). Pout scales linearly with Pin,ch.
+
+IV), we now assess the physical feasibility of the proposed architecture in terms of voltage requirements, insertion loss, and energy efficiency.
+
+A. Electro-optic voltage requirements
+
+For the primary target of ε < 2% (N = 12), minimax optimization gives a = −1.373, b = 0.0845. With the FDTD-calibrated QL = 15,500 (bV = 0.182 V−1), the required voltages are
+
+$$V_{\mathrm{bias}} = \frac{|a|}{b_V} = \frac{1.373}{0.182} = 7.5\ \mathrm{V}, \qquad (31)$$
+
+$$V_{\mathrm{ctrl,max}} = \frac{bL}{b_V} = \frac{0.0845 \times 8}{0.182} = 3.7\ \mathrm{V}. \qquad (32)$$
+
+Since bV ∝ Q, voltage scales inversely with quality factor:
+
+$$V_{\mathrm{ctrl}} = \frac{bL}{b_V} = \frac{bL\,\lambda_0}{2Q\,|d\lambda_0/dV|}. \qquad (33)$$
+
+CMOS-compatible control voltages (Vctrl < 3.3 V) are achievable at N ≥ 14 with QL = 15,500; at the design point N = 30 (εmax = 0.30%), Vctrl = 1.47 V.
+
+independent evidence that intrinsic quality factors in the projected range are physically achievable in TFLN, albeit with wider waveguides and larger ring radii than the present design. Transferring comparable sidewall quality to our geometry (R = 20 µm, W = 1.4 µm) is an open fabrication challenge; the projections should be read as design targets contingent on achieving it.
+
+The total insertion loss comprises on-resonance
+
+B.
+Power budget: two-regime analysis
+
+The on-resonance cascade transmission $D_{\max}^{N}$ is the dominant contribution to total insertion loss. Table V presents two regimes: the FDTD-characterized regime (Dmax = 0.36) and the fabricated high-Q regime (Dmax = 0.95, achievable with $Q_i > 10^{6}$ and gap-optimized coupling).
+
+In the FDTD-characterized regime, Dmax = 0.36 limits practical cascades to N ≤ 5: at N = 5 the output is 0.61 µW (−22.2 dB) with ε = 10.9%, suited for proof-of-concept validation. In the fabricated high-Q regime (Dmax ≥ 0.95), deep cascades become practical: N = 30 yields Pout = 21.5 µW (−6.7 dB) with εmax ≈ 0.30%. The transition to fabricated high-Q devices is therefore critical for achieving both high accuracy and sufficient output power.
+
+C. Feasibility outlook
+
+Published TFLN micro-ring resonators achieve $Q_i \geq 10^{6}$–$10^{8}$ using optimized fabrication [39–42].
+
+cascade transmission $D_{\max}^{N}$, inter-ring coupling loss (∼0.46 dB/stage for the present diagonal-bus layout), off-resonance propagation loss (0.08–0.25 dB/stage), and fiber-to-chip coupling (1.5–3.0 dB). For the fabricated high-Q regime (N = 30), the total ranges from ∼13 dB (optimized layout) to ∼24 dB (current geometry); see Supplementary Sec. S6 for detailed scenarios.
+
+D. Energy comparison
+
+For N = 30 X-cut TFLN micro-ring resonators in the fabricated high-Q regime (QL ≈ 25,200 at $Q_i = 10^{6}$; Supplementary Sec. S5), the three energy components are EO tuning (EEO = 0.22 pJ), amortized laser (Elaser = 0.07 pJ, shared across M = 10 channels), and photodetector (EPD = 0.50 pJ), yielding Ephotonic = 0.79 pJ (derivations in Supplementary Sec. S7). Including thermal stabilization for N = 30 rings (0.15–0.60 pJ; Supplementary Sec. S7), the total rises to 0.94–1.39 pJ.
+
+Table S12 compares the photonic cascade with digital implementations. Including thermal stabilization (0.94–
+1.39 pJ), the advantage over INT8 (2.3 pJ) is 1.7–2.4×, while operating at 10 GHz bandwidth and 58× lower than digital FP32 (46 pJ). At fabricated Q ≥ 30,000, EEO drops to 0.16 pJ and Etotal ≈ 0.73 pJ (excluding thermal; Supplementary Table S11), recovering a 3.2× advantage over INT8. Since $E_{\mathrm{EO}} \propto 1/Q^{2}$, improving Q beyond ∼30,000 yields diminishing energy returns but continues to relax CMOS driver voltage requirements.
+
+At $Q_i = 10^{6}$ with the present coupling geometry, CMT predicts Dmax ≈ 0.95 and QL ≈ 25,200 (Supplementary Sec. S5, Tables S4–S7), enabling deep cascades (N ≤ 30) with sub-percent error. The literature values provide strong
+
+TABLE VI: Energy per exponential operation: single-channel comparison.
+
+  Implementation           E/op (pJ)    Bandwidth   Notes
+  Digital FP32 (Taylor)    ∼46          1 GHz       10 FP MACs
+  Digital INT8 (Taylor)    ∼2.3         1 GHz       10 INT MACs
+  Photonic MRR (N = 30)    0.94–1.39    10 GHz      Analog†
+
+  † 0.79 pJ excluding thermal; 0.94–1.39 pJ including thermal. Self-consistent with fabricated high-Q regime (QL = 25,200); see Supplementary Sec. S7.
+
+with a distinct FSR order of the same ring set, traverse a single N-ring cascade simultaneously (Fig. 8). Because each channel λj sees its own Lorentzian detuning set by an independent control voltage Vj, the cascade output per channel is $\tilde{y}_j = C \prod_{k=1}^{N} T_{\mathrm{drop},k}(\lambda_j, V_j) \approx e^{V_j}$, and all d exponentials are computed in parallel on the same physical waveguide. Compared with a 1×M power-splitter architecture that replicates the cascade for each channel, the WDM approach reduces the total ring count from N × d to N (a factor-d saving) and eliminates the splitter insertion loss ($10 \log_{10} d$ dB). At the output, a WDM demultiplexer or wavelength-selective photodetector array separates the channels for electrical readout. Figure 8 shows a representative chip layout for N = 5 cascade
+stages and d = 8 WDM channels, where alternating U-turn bus connections route the drop-port output of each stage into the input bus of the next.
+
+Why cascade helps. A single Lorentzian in I is too rigid to mimic the log-linear target over a wide interval. Cascading turns the transfer into a product; taking a logarithm gives a sum of smooth terms, and the approximation improves as N increases. The slope constraint N|b| ≳ 1 is an immediate feasibility check.
+
+Global softmax normalization via WDM feedback. The WDM-parallel architecture (Fig. 8) integrates naturally with a closed-loop normalization scheme to complete the full softmax function. After the N-stage cascade, a WDM demultiplexer (e.g., arrayed-waveguide grating or ring-filter bank) routes each channel λj to a dedicated photodetector, producing photocurrents $I_{\lambda_j} \propto \tilde{y}_j \approx C P_{\mathrm{in}} e^{V_j}$. The d photocurrents are summed electrically:
+
+$$S = \sum_{j=1}^{d} I_{\lambda_j} \propto C P_{\mathrm{in}} \sum_{j=1}^{d} e^{V_j}. \qquad (35)$$
+
+A proportional–integral (PI) controller compares S with a fixed reference Sref and adjusts the shared WDM laser power Pin so that S → Sref [44, 45]. Because all d channels share the same probe source, scaling Pin multiplies every $\tilde{y}_j$ by the same factor; upon convergence
+
+$$p_j = \frac{I_{\lambda_j}}{S_{\mathrm{ref}}} = \frac{e^{V_j}}{\sum_{k=1}^{d} e^{V_k}} = \mathrm{softmax}(V)_j, \qquad (36)$$
+
+VI. DISCUSSION
+
+Practical design procedure. For a given input sequence x = (x1, ..., xK), the design proceeds as follows:
+
+  1. Compute $m = \max_n x_n$, $u_n = x_n - m$, and $L = -\min_n u_n$.
+  2. Map to nonnegative control-signal amplitudes: $I_n = u_n + L \in [0, L]$.
+  3. Choose tolerance ε and set εlog = ln(1 + ε).
+  4. Select a physically feasible bmax and estimate N using Eq. (28).
+  5. Initialize b = min(bmax, 1/N) and a = −1 − bL/2, then refine (a, b) by a two-parameter minimax fit if required.
+  6. The optical block yields $\tilde{y}(I_n) \approx e^{x_n - m}$, and softmax weights follow as
+
+$$p_n = \frac{\tilde{y}(I_n)}{\sum_j \tilde{y}(I_j)}. \qquad (34)$$
+
+Scope and limits.
+The approximation is for a finite interval I ∈ [0, L], where L is determined by the input batch via Eq. (4). In practice, one designs for a worst-case L expected in operation (or retunes a and rescales the control signal to adapt L). Noise, insertion loss, and control-induced parasitics limit accuracy and dynamic range; we treat these effects as platform-specific margins. Detailed non-ideality assumptions, parameter distributions, and robustness statistics are reported in Supplementary Sec. S8. With K channels in parallel, one can form softmax by summing channel powers and applying a shared reciprocal scale factor, depending on the chosen mixed-signal normalization scheme.
+
+WDM parallelism. A particularly hardware-efficient realization exploits wavelength-division multiplexing (WDM): d probe wavelengths λ1, ..., λd, each resonant
+
+realizing the complete softmax with a single feedback loop and no per-channel normalization circuitry. Compared with the replicated-cascade approach (one AEF block per channel), WDM feedback offers two additional benefits: (i) the splitter-induced power imbalance that would bias the Iλj ratios is absent, since all channels traverse the same optical path; and (ii) a single laser control point replaces d independent probe adjustments. Design details and stability analysis of the PI loop are provided in Supplementary Sec. S9.
+
+Beyond ring-resonator AEF implementations, the same cascade principle can be extended to other cavity-based photonic platforms, such as serial 1D photonic-crystal cavities and other cascaded resonant architectures [21, 46]. What these platforms share is transfer-function shaping through cascaded resonances; loss, tuning range, fabrication tolerance, and calibration overhead remain platform-dependent.
+
+The insertion loss budget (Sec.
+
+TABLE VII: Summary of evidence levels for key claims.
+
+  Claim                           Evidence              Sec.
+  Cascade → exp. approx.          Analytic              II
+  Depth scaling                   Analytic + num.       II, III
+  QL, Dmax, bV                    3-D FDTD              IV
+  5-ring line shape               3-D FDTD              IV
+  N ≤ 30 deep cascade             CMT proj.∗            V
+  Energy < 1 pJ                   Estimate              V
+  Full softmax (WDM + feedback)   Conceptual + layout   VI
+
+  ∗ Based on published $Q_i \geq 10^{6}$ values [39, 42] and CMT coupling model.
+
+V C) and electro-optic voltage requirements (Sec. V A) suggest that the cascade architecture is feasible under optimized coupling and layout conditions. Using monolithic TFLN microring data from Bahadori et al. [47] (Q ≈ 5432, dλ0/dV ≈ 9–20 pm/V), the normalized sensitivity is bV ≃ 0.063–0.14 V−1, within the range required by the cascade design.
+
+Crystal orientation and electrode design. The X-cut TFLN platform was chosen for several reasons. First, X-cut is the prevailing industry standard for integrated TFLN modulators, with well-established fabrication processes and commercial wafer availability [37, 38]. Second, the TE0 mode, which is strongly confined in the rib waveguide geometry, can engage the large r33 coefficient via lateral electric fields aligned with the crystal Z-axis. In contrast, Z-cut geometry with TE polarization can only access the smaller r13 coefficient (∼10 pm/V), resulting in significantly lower electro-optic efficiency. The arc electrode design (Sec. IV D) addresses the phase-cancellation problem inherent to X-cut circular rings [47] by orienting
+
+tified in the Monte Carlo robustness analysis (Supplementary Sec. S8). Monte Carlo simulations (Supplementary Sec. S8) show that under nominal non-ideality levels (σa = 0.020, σb,rel = 0.020), a single-point calibration of C per chip keeps the median softmax KL divergence below $2.2 \times 10^{-4}$, with 95th-percentile max probability error under 0.32%. Even under stress conditions (σa = 0.032), 95th-percentile errors remain below 0.42%, demonstrating that the identical-detuning design is robust to realistic fabrication variations provided a per-chip calibration step is performed.
+Conversely, if coupling gaps are intentionally varied across rings, the per-ring parameters (ak, bk) become independent degrees of freedom. A Taylor-expansion analysis shows that K non-identical rings can cancel curvature terms up to order 2K in the Taylor series of $g(I) = \sum_k \ln T_k$, one order higher than K identical rings, so that fewer rings suffice for a given error target.
+
+the crystal Z-axis at 45° from the horizontal in the substrate plane. This rotation places the cos(θ − 45°) = 0 boundaries at θ = 135° and 315°, naturally separating the bus-waveguide coupling regions from the electrode regions. Each ring carries a full semicircular arc electrode on the side opposite to its coupling points, yielding an effective fill factor fEO = 1/π ≈ 0.318. While this reduces the round-trip EO efficiency compared to a hypothetical full-circumference design, it preserves the compact footprint of a circular ring resonator. The cascade performance can be further improved beyond the R = 20 µm circular-ring design presented here. Increasing the ring radius reduces bending loss and raises the intrinsic quality factor Qi, which directly increases bV (∝ Q) and lowers the required control voltage. Alternatively, adopting a racetrack geometry with extended straight coupling sections strengthens the bus–ring coupling, pushing the drop-port maximum Dmax closer to critical coupling and improving the per-stage transfer efficiency. Either approach, or their combination, would yield higher bV and Dmax, enabling lower N or tighter approximation accuracy at reduced operating voltages.
+
+Fabrication considerations. The X-cut TFLN rib waveguide (600 nm total thickness, 500 nm etch, w = 1.4 µm) follows established fabrication processes for commercial TFLN wafers on SiO2 [37, 38]. The lateral signal–ground (SG) electrode configuration is fabricated in a single metal layer, which is standard in TFLN foundry processes.
+The primary fabrication challenge for the cascade architecture is maintaining uniform coupling gaps (g = 100 nm) across N rings to ensure identical Lorentzian transfer functions. Post-fabrication trimming via UV exposure or localized thermal oxidation can compensate residual detuning variations [30], as quan-
+
+[Figure 8 graphic: "Softmax Full Chip Layout – N = 5 × d = 8 (TFLN)". Labels: WDM input λ1–λ8 with shared power Pin; per-channel control voltages Vλ1–Vλ8 (d = 8 WDM channels); N = 5 cascade stages (n = 1 to 5); WDM demux (AWG / ring filter); photodetectors PD1–PD8 producing Iλj; summation Σ giving S, compared with Sref, error e fed to a PI controller; feedback: adjust Pin; outputs p1–p8 with pj = Iλj/Sref = softmax(V)j.]
+
+FIG. 8: WDM-parallel MRR-AEF system with closed-loop softmax normalization (N = 5 cascade stages, d = 8 WDM channels) on X-cut TFLN. A single WDM source (λ1–λ8) enters the top input bus waveguide; each stage applies a Lorentzian drop-port transfer, and alternating U-turn connections route the drop-port output into the next stage’s input bus. Per-channel EO bias voltages (Vλ1–Vλ8) independently tune each column of rings. The final drop output passes through a WDM demultiplexer (AWG / ring filter) and is detected by a PD array, producing per-channel photocurrents $I_{\lambda_j} \propto e^{V_j}$. The photocurrents are summed (Σ) and compared with a reference Sref; a PI controller adjusts the shared laser power Pin until S = Sref, at which point each PD output directly yields pj = Iλj/Sref = softmax(V)j (Eq. 36).
+
+VII. CONCLUSION
+
+Dmax ≥ 0.95) are realized in the cascade geometry, deeper cascades (N ≈ 20–30) would reach sub-percent approximation error with an estimated per-operation energy of 0.79–1.39 pJ, which is 1.7–2.4× lower than an INT8 MAC at the 7 nm node.
+
+We have presented a cascaded micro-ring resonator architecture that approximates the exponential function $e^{x_n - m}$ on a finite interval [0, L] using multiplicative
+Lorentzian transfer functions. Increasing the cascade depth N systematically reduces the worst-case relative error, and an identical-detuning design initialized by flank and slope matching provides a practical two-parameter design.
+
+Three-dimensional FDTD simulations of a single X-cut TFLN add-drop ring (R = 20 µm, g = 100 nm) yield QL = 10,300–15,500 and Dmax ≈ 0.36, calibrating the cascade transfer model. A five-ring cascade 3D FDTD simulation directly validates the multi-ring framework: all five rings exhibit resonant excitation, and mapping the drop-port spectrum onto the dimensionless control variable reproduces the theoretical N = 5 curve with ∼11% integrated relative-area error over the upper operating range (I ≥ 5.8), providing the first multi-ring confirmation of the cascade exponential approximation. At the present FDTD-characterized quality factor, practical cascades are limited to N = 5–7 (ε ≲ 11%). If high-Q
+
+Monte Carlo analysis shows that the identical-detuning design tolerates realistic fabrication variations (σa = 0.020, σb,rel = 0.020) with a single per-chip calibration, keeping the 95th-percentile softmax probability error below 0.32%.
+
+The formulation is not restricted to electro-optic tuning: it requires only a controllable detuning coordinate with local linearization, so both Pockels and optical (Kerr/XPM) mechanisms are compatible [37, 38, 47, 48]. We demonstrate a photonic exponential block and present a WDM-parallel chip architecture (Fig. 8) in which d wavelength channels share a single N-ring cascade, reducing the total ring count by a factor of d and eliminating power-splitter loss. Combined with a single-loop PI feedback that adjusts the shared WDM laser power, the architecture realizes the complete softmax function (exponentiation, summation, and normalization) without per-channel normalization circuitry.
Max-finding and digital interfacing remain open +TFLN resonators reported in the literature (Qi ≥ 106 , for future experimental validation. + + + + + [1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Shengyuan Lu, Qihang Zhang, Lingyan He, C. A. A. + Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Franken, Keith Powell, Hana Warner, Daniel Assumpcao, + and Illia Polosukhin. Attention is all you need. In Dylan Renaud, Ying Wang, et al. Integrated lithium + Advances in Neural Information Processing Systems 30 niobate photonic computing circuit based on efficient and + (NeurIPS 2017), pages 5998–6008, 2017. high-speed electro-optic conversion. Nature Communica- + [2] Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, tions, 16:8178, 2025. + and Christopher Ré. FlashAttention: Fast and memory- [11] Priyabrata Dash, Anxiao Jiang, and Dharanidhar Dang. + efficient exact attention with IO-awareness. In Advances SOFTONIC: A photonic design approach to softmax + in Neural Information Processing Systems 35 (NeurIPS activation for high-speed fully analog AI acceleration. + 2022), pages 16344–16359, 2022. In Proceedings of the Great Lakes Symposium on VLSI + [3] Neil Savage. Light could lower AI’s appetite for power. (GLSVLSI ’25), pages 118–125, 2025. + Nature Nanotechnology, 21:6–8, 2026. [12] Ziyu Zhan, Hao Wang, Qiang Liu, and Xing Fu. Opto- + [4] Yichen Shen et al. Deep learning with coherent nanopho- electronic nonlinear softmax operator based on diffractive + tonic circuits. Nature Photonics, 11(7):441–446, 2017. neural networks. Optics Express, 32(15):26458–26469, + [5] Johannes Feldmann et al. Parallel convolutional process- 2024. + ing using an integrated photonic tensor core. Nature, [13] Ye Tian, Shuiying Xiang, Xingxing Guo, Yahui Zhang, + 589(7840):52–58, 2021. Jiashang Xu, Shangxuan Shi, Haowen Zhao, Yizhi Wang, + [6] Nicholas C. Harris et al. Linear programmable nanopho- Xinran Niu, Wenzhuo Liu, and Yue Hao. Photonic trans- + tonic processors. 
Optica, 5(12):1623–1631, 2018. former chip: interference is all you need. PhotoniX, 6:45, + [7] Bowei Dong, Samarth Aggarwal, Wen Zhou, Utku Emre 2025. + Ali, Nikolaos Farmakidis, June Sang Lee, Yuhan He, Xuan [14] Jacob R. Stevens, Rangharajan Venkatesan, Steve Dai, + Li, Dim-Lee Kwong, C. D. Wright, Wolfram H. P. Pernice, Brucek Khailany, and Anand Raghunathan. Softermax: + and H. Bhaskaran. Higher-dimensional processing using Hardware/software co-design of an efficient softmax for + a photonic tensor core with continuous-time data. Nature transformers. In Proceedings of the 58th ACM/IEEE + Photonics, 17(12):1080–1088, 2023. Design Automation Conference (DAC), pages 469–474, + [8] Sudip Shekhar, Wim Bogaerts, Lukas Chrostowski, 2021. + John E. Bowers, Michael Hochberg, Richard Soref, and [15] Nazim Altar Koca, Anh Tuan Do, and Chip-Hong + Bhavin J. Shastri. Roadmapping the next generation of Chang. Hardware-efficient softmax approximation for + silicon photonics. Nature Communications, 15:751, 2024. self-attention networks. In Proceedings of the IEEE Inter- + [9] Mario Miscuglio and Volker J. Sorger. Photonic tensor national Symposium on Circuits and Systems (ISCAS), + cores for machine learning. Applied Physics Reviews, pages 1–5, 2023. + 7(3):031404, 2020. [16] Wenxun Wang, Shuchang Zhou, Wenyu Sun, Peiqin Sun, +[10] Yaowen Hu, Yunxiang Song, Xinrui Zhu, Xiangwen Guo, and Yongpan Liu. SOLE: Hardware-software co-design + 14 + + of softmax and layernorm for efficient transformer infer- 2025. accessed 2026-02-21. + ence. In Proceedings of the IEEE/ACM International [35] Jane Austen. Pride and prejudice. Project Gutenberg + Conference on Computer-Aided Design (ICCAD), pages eBook No. 1342, 2025. accessed 2026-02-21. + 1–9, 2023. [36] Hyoseok Park. MRR-AEF: reproducible MRR depth- +[17] Yuan Zhang, Yonggang Zhang, Lele Peng, Lianghua Quan, sweep fitting and supplementary validation scripts. + Shubin Zheng, Zhonghai Lu, and Hui Chen. 
Base-2 softmax function: Suitability for training and efficient hardware implementation. IEEE Transactions on Circuits and Systems I: Regular Papers, 69(9):3605–3618, 2022.
[18] Zhengyu Mei, Hongxi Dong, Yuxuan Wang, and Hongbing Pan. TEA-S: A tiny and efficient architecture for PLAC-based softmax in transformers. IEEE Transactions on Circuits and Systems II: Express Briefs, 70:3594–3598, 2023.
[19] Ke Chen, Yue Gao, Haroon Waris, Weiqiang Liu, and Fabrizio Lombardi. Approximate softmax functions for energy-efficient deep neural networks. IEEE Transactions on Very Large Scale Integration (VLSI) Systems, 31:4–16, 2023.
[20] Wim Bogaerts et al. Silicon microring resonators. Laser & Photonics Reviews, 6(1):47–73, 2012.
[21] John E. Heebner, Robert W. Boyd, and Q.-Han Park. Scissor solitons and other propagation effects in microresonator-modified waveguides. Journal of the Optical Society of America B, 19(4):722–731, 2002.
[22] Jiahui Wang, Sean P. Rodrigues, Ercan M. Dede, and Shanhui Fan. Microring-based programmable coherent optical neural networks. Optics Express, 31(12):18871, 2023.
[23] Pengxing Guo, Niujie Zhou, Weigang Hou, and Lei Guo. StarLight: a photonic neural network accelerator featuring a hybrid mode-wavelength division multiplexing and photonic nonvolatile memory. Optics Express, 30:37051, 2022.
[24] Weizhen Yu, Shuang Zheng, Zhenyu Zhao, Bin Wang, and Weifeng Zhang. Reconfigurable low-threshold all-optical nonlinear activation functions based on an add-drop silicon microring resonator. IEEE Photonics Journal, 14(6):1–7, 2022.
[25] Bahaa E. A. Saleh and Malvin C. Teich. Fundamentals of Photonics. Wiley, Hoboken, NJ, 2nd edition, 2007.
[26] Vítor R. Almeida, Carlos A. Barrios, Roberto R. Panepucci, and Michal Lipson. All-optical control of light on a silicon chip. Nature, 431(7012):1081–1084, 2004.
[27] Qianfan Xu, Bradley Schmidt, Sameer Pradhan, and Michal Lipson. Micrometre-scale silicon electro-optic modulator. Nature, 435(7040):325–327, 2005.
[28] Kishore Padmaraju and Keren Bergman. Resolving the thermal challenges for silicon microring resonator devices. Nanophotonics, 3:269–281, 2014.
[29] Erwen Li, Behzad Ashrafi Nia, Bokun Zhou, and Alan X. Wang. Transparent conductive oxide-gated silicon microring with extreme resonance wavelength tunability. Photonics Research, 7(4):473, 2019.
[30] Lahiru Jayatilleka et al. Post-fabrication trimming of silicon photonic ring resonators at wafer-scale. Journal of Lightwave Technology, 39:5083–5088, 2021.
[31] Elliott W. Cheney. Introduction to Approximation Theory. McGraw–Hill, New York, 1966.
[32] Alec Radford et al. Language models are unsupervised multitask learners. Technical report, OpenAI, 2019.
[33] Hugging Face. distilgpt2 model card, 2025. accessed 2026-02-21.
[34] Andrej Karpathy. Tiny shakespeare dataset (char-rnn),
GitHub repository, 2025. commit 585e695, accessed 2026-02-21.
[37] Di Zhu et al. Integrated photonics on thin-film lithium niobate. Advances in Optics and Photonics, 13(2):242–352, 2021.
[38] Yaowen Hu, Di Zhu, Shengyuan Lu, Xinrui Zhu, Yunxiang Song, Dylan Renaud, Daniel Assumpcao, Rebecca Cheng, CJ Xin, Matthew Yeh, Hana Warner, Xiangwen Guo, Amirhassan Shams-Ansari, David Barton, Neil Sinclair, and Marko Lončar. Integrated electro-optics on thin-film lithium niobate. Nature Reviews Physics, 2025.
[39] Mian Zhang, Cheng Wang, Rebecca Cheng, Amirhassan Shams-Ansari, and Marko Lončar. Monolithic ultra-high-Q lithium niobate microring resonator. Optica, 4(12):1536–1537, 2017.
[40] Rongjin Zhuang, Jinze He, Yifan Qi, and Yang Li. High-Q thin-film lithium niobate microrings fabricated with wet etching. Adv. Mater., 35(3):2208113, 2023.
[41] Xinrui Zhu, Yaowen Hu, Shengyuan Lu, Hana K. Warner, Xudong Li, Yunxiang Song, Letícia S. Magalhães, Amirhassan Shams-Ansari, Neil Sinclair, and Marko Lončar. Twenty-nine million intrinsic Q-factor monolithic microresonators on thin-film lithium niobate. Photon. Res., 12(8):A63–A68, 2024.
[42] Renhong Gao, Ni Yao, Jianglin Guan, Li Deng, Jintian Lin, Min Wang, Lingling Qiao, Wei Fang, and Ya Cheng. Lithium niobate microring with ultra-high Q factor above $10^8$. Chin. Opt. Lett., 20(1):011902, 2022.
[43] Flexcompute Inc. Tidy3D: electromagnetic simulation software. https://www.flexcompute.com/tidy3d/, 2024. v2.10; cloud GPU FDTD. Accompanying notebook: https://www.flexcompute.com/tidy3d/community/notebooks/CascadedMRRTFLN/.
[44] John K. Doylend, Paul E. Jessop, and Andrew P. Knights. Silicon photonic dynamic optical channel leveler with external feedback loop. Optics Express, 18(13):13805–13812, 2010.
[45] Karl J. Åström and Richard M. Murray. Feedback Systems: An Introduction for Scientists and Engineers. Princeton University Press, Princeton, NJ, 2008.
[46] Amnon Yariv, Yong Xu, Reginald K. Lee, and Axel Scherer. Coupled-resonator optical waveguide: a proposal and analysis. Optics Letters, 24(11):711–713, 1999.
[47] Meisam Bahadori, Yansong Yang, Ahmed E. Hassanien, Lynford L. Goddard, and Songbin Gong. Ultra-efficient and fully isotropic monolithic microring modulators in a thin-film lithium niobate photonics platform. Optics Express, 28(20):29644–29661, 2020.
[48] Abu Naim R. Ahmed, Shouyuan Shi, Mathew Zablocki, Peng Yao, and Dennis W. Prather. Tunable hybrid silicon nitride and thin-film lithium niobate electro-optic microresonator. Optics Letters, 44(3):618, 2019.

SUPPLEMENTARY INFORMATION

Supplementary material accompanying “Photonic Exponential Approximation via Cascaded TFLN Microring Resonators toward Softmax.”

S0. RIGOROUS DERIVATION AND VALIDITY SCOPE

This section derives the depth-scaling relations and screening bounds used in the main text, and states the assumptions under which they apply, together with their validity scope and failure cases.
We separate proved statements (Lemma, Proposition, Theorem) from heuristic engineering estimates that rely on empirical calibration.

S0.1 Assumptions

Assumption 1 (Lorentzian single-ring transfer). Each ring $k$ has a normalized add-to-drop transmission of the form $T_k(I) = [1 + (a_k + bI)^2]^{-1}$, where $a_k \in \mathbb{R}$ is the dimensionless static detuning, $b > 0$ is the common normalized sensitivity, and $I \ge 0$ is a nonnegative control-signal amplitude.

Assumption 2 (Multiplicative cascade). The $N$ rings are cascaded in a serial add-drop topology (the drop output of ring $k$ feeds the add input of ring $k+1$), and the probe is sufficiently weak that cross-ring and nonlinear probe-induced effects are negligible; thus the total normalized transmission is $y(I) = \prod_{k=1}^{N} T_k(I)$.

Assumption 3 (Identical-detuning family). All rings share the same static detuning: $a_1 = \cdots = a_N \equiv a$. This reduces the design space to $(a, b)$ and a global scale $C > 0$; the scaled output is $\tilde{y}(I) = C\,[1 + (a + bI)^2]^{-N}$.

Assumption 4 (Linear control-to-resonance mapping). Within the operating range $I \in [0, L]$, the resonance shift is a linear function of the control-signal amplitude (Eq. (10) of the main text), i.e., higher-order detuning nonlinearity is negligible.

Assumption 5 (Finite interval and bounded L). The approximation target is $f(I) = e^{I-L}$ on a finite interval $I \in [0, L]$ with $L > 0$ determined by the input batch ($L = \max(x) - \min(x)$). The depth-scaling results are derived for fixed, finite $L$.

Assumption 6 (Flank-centered operating regime). The design uses the “flank-centered” initialization: $a + b(L/2) = -1$ (midpoint on the Lorentzian half-maximum) and $Nb = 1$ (slope matching). This places the operating point in the steepest-slope region of the Lorentzian, where the log-transfer is most nearly linear.
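As a rough numerical illustration of Assumptions 1–6, the sketch below builds the identical-detuning cascade at the flank-centered initialization and measures its worst-case log-error on a uniform grid. The function names and the grid size are illustrative choices, not taken from the paper, and the result describes the initialization only, not the re-fitted minimax parameters reported elsewhere in this supplement.

```python
import math

def flank_centered_params(L, N):
    """Assumption 6: slope matching N*b = 1 and midpoint a + b*(L/2) = -1."""
    b = 1.0 / N
    a = -1.0 - b * L / 2.0
    return a, b

def log_cascade(I, a, b, N):
    """ln y(I) for N identical Lorentzian rings (Assumptions 1-3)."""
    return -N * math.log(1.0 + (a + b * I) ** 2)

def worst_log_error(L, N, samples=2001):
    """Worst-case |ln y~(I) - (I - L)| on a grid over [0, L].

    The offset ln C is chosen to centre the residual between its
    maximum and minimum (hypothetical helper, grid-based estimate).
    """
    a, b = flank_centered_params(L, N)
    grid = [L * i / (samples - 1) for i in range(samples)]
    g = [log_cascade(I, a, b, N) - (I - L) for I in grid]
    ln_c = -(max(g) + min(g)) / 2.0
    return max(abs(v + ln_c) for v in g)

if __name__ == "__main__":
    for n_rings in (10, 20):
        print(n_rings, round(worst_log_error(8.0, n_rings), 4))
```

Doubling the depth at fixed L reduces the error by roughly a factor of four in this sketch, in line with the $N^{-2}$ dependence derived in Sec. S0.2.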
S0.2 Rigorous results

Throughout, define the log-domain residual

$$ r(I) \equiv \ln\tilde{y}(I) - (I - L) = \ln C - N \ln\!\left[1 + (a + bI)^2\right] - (I - L), \tag{S0.1} $$

and the worst-case log-error $E_\infty = \sup_{I\in[0,L]} |r(I)|$. We set $C$ to the minimax-optimal value $\ln C^\star = -\left[\max_I g(I) + \min_I g(I)\right]/2$, where $g(I) = \ln y(I) - (I - L)$, throughout.

Lemma 1 (Slope bound, rigorous). Under Assumptions 1–3, for every $I \ge 0$,

$$ \left|\frac{d}{dI}\ln y(I)\right| \le N|b|. $$

Proof. From Assumption 3, $\ln y(I) = -N\ln\!\left[1 + (a + bI)^2\right]$. Differentiating:

$$ \frac{d}{dI}\ln y(I) = -N\,\frac{2b(a + bI)}{1 + (a + bI)^2}. $$

Let $u \equiv a + bI \in \mathbb{R}$. The function $h(u) = 2u/(1 + u^2)$ satisfies $|h(u)| \le 1$ for all $u \in \mathbb{R}$ (since $1 + u^2 \ge 2|u|$ by AM–GM). Therefore $|d(\ln y)/dI| = N|b|\,|h(u)| \le N|b|$.

Remark 1 (Necessary condition for approximation). Since the target $\ln f(I) = I - L$ has constant slope $+1$, a necessary condition for the cascade log-transfer to track this slope at any point is $N|b| \ge 1$. This is Eq. (23) of the main text and is a rigorous (not heuristic) necessary condition.

Proposition 1 (Log-domain Taylor expansion at flank center). Under Assumptions 1–6, define $I_0 = L/2$ and $\delta = I - I_0$. Then

$$ \ln\tilde{y}(I) = \text{const} + \delta - \frac{\delta^3}{6N^2} + R_4(\delta), \tag{S0.2} $$

where $|R_4(\delta)| \le M_4\,\delta^4$ with $M_4 = N|b|^4 \cdot \sup_{|\delta|\le L/2} |\phi^{(4)}(u(\delta))|$ and $\phi(u) = -\ln(1 + u^2)$. In particular, the quadratic term vanishes identically at the flank point $u_0 = a + bI_0 = -1$.

Proof. Set $u(\delta) = a + b(I_0 + \delta) = -1 + b\delta$ (using $a + bI_0 = -1$). Then $\ln y(I) = N\phi(u(\delta))$ and $u'(\delta) = b$. Compute derivatives of $\phi$ at $u_0 = -1$:

$$ \phi'(u) = -\frac{2u}{1 + u^2}, \qquad \phi'(-1) = 1, $$
$$ \phi''(u) = \frac{2(u^2 - 1)}{(1 + u^2)^2}, \qquad \phi''(-1) = 0, $$
$$ \phi'''(u) = \frac{4u(3 - u^2)}{(1 + u^2)^3}, \qquad \phi'''(-1) = \frac{4(-1)(3 - 1)}{(1 + 1)^3} = -1. $$

By the chain rule, writing $F(\delta) = N\phi(u(\delta))$:

$$ F'(0) = Nb\,\phi'(-1) = Nb = 1, $$
$$ F''(0) = Nb^2\,\phi''(-1) = 0, $$
$$ F'''(0) = Nb^3\,\phi'''(-1) = -Nb^3 = -\frac{1}{N^2}, $$

where we used $b = 1/N$ from Assumption 6 in the last step. Hence the Taylor expansion with the minimax-optimal $C$ is

$$ \ln\tilde{y}(I) = \text{const} + \delta + 0\cdot\frac{\delta^2}{2} - \frac{1}{N^2}\cdot\frac{\delta^3}{6} + R_4(\delta). $$

Subtracting the target $\delta$ (the linear part of $I - L$ around $I_0$) gives a leading residual $-\delta^3/(6N^2)$, of magnitude $|\delta|^3/(6N^2)$. The remainder is bounded by the standard Taylor remainder estimate.

Theorem 1 (Heuristic depth-scaling law). Under Assumptions 1–6 and ignoring the fourth-order remainder $R_4$, the leading-order worst-case log-error on $I \in [0, L]$ satisfies

$$ E_\infty^{(\text{leading})} \sim \frac{1}{6N^2}\left(\frac{L}{2}\right)^3 = \frac{L^3}{48\,N^2}. \tag{S0.3} $$

Setting $E_\infty^{(\text{leading})} \le \varepsilon_{\log} = \ln(1 + \varepsilon)$ and solving for $N$ gives

$$ N \ge \frac{L^{3/2}}{\sqrt{48\,\varepsilon_{\log}}}. \tag{S0.4} $$

Derivation (heuristic). From Proposition 1, the residual with respect to the target is dominated in magnitude by $|\delta|^3/(6N^2)$ for $|\delta| \le L/2$. The maximum of $|\delta|^3$ on $[-L/2, L/2]$ is $(L/2)^3$. Setting the bound equal to $\varepsilon_{\log}$ and solving:

$$ \frac{L^3}{48\,N^2} \le \varepsilon_{\log} \quad\Longrightarrow\quad N \ge \frac{L^{3/2}}{\sqrt{48\,\varepsilon_{\log}}}. $$

With $1/\sqrt{48} \approx 0.144$, and accounting for the fact that the minimax-optimal (equi-oscillating) fit typically needs about half the depth implied by the one-sided Taylor bound, the effective prefactor becomes $\kappa \approx 0.07$, yielding the main-text engineering estimate Eq. (28): $N \approx \lceil \max(1/b_{\max},\ \kappa L^{3/2}/\sqrt{\varepsilon_{\log}}) \rceil$. For example, at $L = 8$ and $N = 10$ the fitted $E_\infty = 0.0265$ (Table S7) is about 4× smaller than $L^3/(48N^2) \approx 0.107$, which corresponds to a factor of ∼2 in $N$ because $E_\infty \propto N^{-2}$.

Remark 2 (Status of Theorem 1). This is a heuristic scaling law, not a rigorous minimax guarantee. The derivation truncates the Taylor series at third order and approximates the equi-oscillation factor empirically ($\kappa \approx 0.07$). For a rigorous bound one would need explicit control of $R_4$ over the full interval $[0, L]$, which depends on $L$, $N$, and higher derivatives of the Lorentzian; we do not claim such a bound here. The scaling $N \sim L^{3/2}/\sqrt{\varepsilon_{\log}}$ is supported by numerical evidence (Table I) but should be treated as an engineering design rule.

S0.3 Derivation of the conservative screening bound

We now derive the conservative screening bound (Eqs. S0.7–S0.8 below), which is stated inline in Sec. II of the main text.

Proposition 2 (Conservative log-error bound).
Under Assumptions 1–5 (identical detuning, but not restricted to the flank-centered choice), fix $b > 0$ and choose the normalization $\tilde{y}(L) = 1$. Define $\phi(u) = -\ln(1 + u^2)$ and write

$$ \ln\tilde{y}(I) = N\left[\phi(a + bI) - \phi(a + bL)\right]. $$

The target in this normalization is $(I - L)$. Denoting the residual $r(I) = \ln\tilde{y}(I) - (I - L)$, we have $r(L) = 0$ and $r(0) = N[\phi(a) - \phi(a + bL)] + L$.

For any choice of $a$ such that the operating range $\{a + bI : I \in [0, L]\}$ lies in the region where $\phi$ is concave (i.e., $\phi''(u) \le 0$ throughout), the worst-case log-error satisfies

$$ E_\infty \le \frac{N\,\|\phi''\|_\infty\, b^2 L^2}{8} + \frac{\left|N\phi'(a + bL)\cdot b - 1\right|}{2}\cdot L, \tag{S0.5} $$

where $\|\phi''\|_\infty = \sup_{u\in[a,\,a+bL]} |\phi''(u)|$.

Derivation sketch. Write $h(I) \equiv N\phi(a + bI)$. The slope is $h'(I) = Nb\,\phi'(a + bI)$. At $I = L$, we want the slope to match the target slope 1; define the slope mismatch $\Delta s \equiv h'(L) - 1 = Nb\,\phi'(a + bL) - 1$. By the fundamental theorem of calculus on $[I, L]$:

$$ r(I) - r(L) = r(I) = h(I) - h(L) - (I - L) = \int_I^L \left[1 - h'(t)\right] dt. $$

Write $1 - h'(t) = (1 - h'(L)) + (h'(L) - h'(t)) = -\Delta s + \int_t^L h''(s)\,ds$. Since $h''(s) = Nb^2\phi''(a + bs)$, we bound $|h''(s)| \le Nb^2\|\phi''\|_\infty$. Integrating twice and applying the triangle inequality gives (S0.5).

Corollary 1 (Main-text conservative bound). Under slope matching at $I = L$ (i.e., $Nb\,\phi'(a + bL) = 1$, so $\Delta s = 0$), and using $\|\phi''\|_\infty \le 2$ (which holds since $|\phi''(u)| = |2(u^2 - 1)/(1 + u^2)^2| \le 2$ for all $u \in \mathbb{R}$), the bound simplifies to

$$ E_\infty \le \frac{N b^2 L^2}{4}. \tag{S0.6} $$

Using $b = 1/N$ (the slope-matching choice from $Nb = 1$) gives $E_\infty \le L^2/(4N)$. If instead we retain a general $b$ but add the penalty from imperfect slope matching (e.g., from the constraint $b \le b_{\max}$), a combined conservative bound is

$$ E_\infty \le \frac{L^2}{4N} + \frac{1}{2b^2 N}, \tag{S0.7} $$

which provides a conservative heuristic bound on the log-error. Setting this $\le \ln(1 + \varepsilon)$ and solving for $N$ yields the conservative screening depth:

$$ N_{\text{safe}} \ge \frac{L^2/4 + 1/(2b^2)}{\ln(1 + \varepsilon)}. \tag{S0.8} $$

Remark 3 (Status of the conservative bound). Equation (S0.7) is a conservative heuristic design rule. It is conservative because: (i) we use a global upper bound $\|\phi''\|_\infty \le 2$ instead of the actual curvature, and (ii) we do not exploit the minimax-optimal $C$ shift. It is heuristic (not a formal guarantee) because: (i) the derivation assumes the operating range lies in the concavity region of $\phi$, which may not hold for all detuning choices; (ii) the second term $1/(2b^2N)$ arises from a simplified penalty model for flank-curvature mismatch that has not been proved to be a rigorous upper bound in all parameter regimes. $N_{\text{safe}}$ from Eq. (S0.8) is therefore a screening estimate, suitable for preliminary design-space exploration but not a certified minimax guarantee.

S0.4 Validity scope and failure cases

The derivations above hold under the stated assumptions. We now identify the regimes where each assumption may break down.

(V1) Lorentzian model (A1). The single-ring Lorentzian form $T = [1 + (a + bI)^2]^{-1}$ is a near-resonance approximation valid when the probe frequency is within a few linewidths of the resonance. Far from resonance, higher-order dispersion, Fano interference, or multi-mode effects introduce deviations. Failure case: operation with very large detuning ($|a + bI| \gg 1$ across the full interval), where the Lorentzian tails may not be accurate for high-Q rings.

(V2) Multiplicative cascade (A2). Requires that inter-ring reflections and back-scattering are negligible (forward-propagating coupling only, i.e. negligible back-reflection at each inter-ring junction). Failure case: very high ring count $N$ with non-negligible back-reflection per stage, which can produce Fabry–Pérot-like ripples in the cascade transfer function.

(V3) Identical-detuning family (A3). The Taylor expansion and conservative bound both assume $a_1 = \cdots = a_N$. In practice, fabrication variations introduce per-ring detuning spread $\sigma_a$.
The Monte Carlo analysis in Sec. S8 quantifies robustness, but the analytical bounds (S0.2)–(S0.8) are strictly valid only for identical detuning.

(V4) Linear control-to-resonance mapping (A4). The linearized model $\omega_0(I) = \omega_0^{(0)} + \eta I$ introduces systematic error at large control amplitudes. For carrier-injection (free-carrier plasma effect) or thermal tuning over wide ranges, second-order nonlinearity in the control-to-detuning mapping can exceed 1%. Failure case: large $L$ requiring a control swing exceeding the linearity range of the tuning mechanism.

(V5) Finite interval (A5). All bounds scale with $L$ (typically as $L^2$ or $L^{3/2}$). As $L \to \infty$, $N$ grows without bound and insertion loss accumulates ($\sim N \cdot \mathrm{IL}_{\text{stage}}$), eventually degrading the probe SNR below the useful regime. There is no finite $N$ that works for all $L$ simultaneously. Practical regime: $L \lesssim 10$–$12$ (consistent with $L_{\text{eff}}$ at p90–p95 from Sec. S3) is the primary target; $L \gtrsim 16$ requires $N \gtrsim 30$ even for moderate tolerance, pushing loss budgets.

(V6) Flank-centered initialization (A6). The Taylor-based scaling (Theorem 1) relies on the cancellation $\phi''(-1) = 0$ at the half-maximum point. If the operating point deviates (e.g., due to fabrication offset pushing $a + bI_0$ away from $-1$), a nonzero quadratic residual appears and the effective scaling worsens to $E_\infty \sim L^2/N$ rather than $L^3/N^2$. Mitigation: heater/bias trimming to restore the flank condition.

S0.5 Mapping to main-text equations

For reference, the results derived here correspond to the following main-text equations:

 • Slope bound (Lemma 1): rigorous; corresponds to main-text Eqs. (22)–(23). This is a guaranteed necessary condition.

 • Engineering N-estimate (Theorem 1): heuristic scaling with empirical prefactor $\kappa \approx 0.07$; corresponds to main-text Eq. (28). This is a heuristic design rule calibrated against numerical fits.

 • Conservative bound (Corollary 1): conservative but not rigorously certified as a minimax upper bound; derived as Eq. (S0.7) in this supplement, stated inline in Sec. II. This is a conservative heuristic screening condition.

 • $N_{\text{safe}}$ (Corollary 1, Eq. S0.8): the safe screening depth that follows from the conservative bound; given as Eq. (S0.8) in this supplement, stated inline in Sec. II. This is a conservative backstop estimate for preliminary design.

Summary of guarantee status:

| Result | Status | Main-text Eq. |
|---|---|---|
| Slope bound $N\lvert b\rvert \ge 1$ | Rigorous (proved) | (23) |
| Scaling $N \sim \kappa L^{3/2}/\sqrt{\varepsilon_{\log}}$ | Heuristic (Taylor truncation + empirical $\kappa$) | (28) |
| Bound $E_\infty \le L^2/(4N) + 1/(2b^2N)$ | Conservative heuristic | (S0.7) |
| $N_{\text{safe}}$ screening depth | Conservative backstop | (S0.8) |

S1. DEPTH-SCALING DERIVATION AND CONSERVATIVE SCREENING BOUND

This section provides the detailed derivations underlying the depth-scaling relations and conservative screening bounds summarized in the main text (Sec. II). These results complement the rigorous treatment in Sec. S0.

S1.1 Local expansion and exponential-like behavior

To provide immediate local intuition (without changing the global minimax objective), let $\delta = I - I_0$ around the flank-centered point $I_0 = L/2$ and impose $a + bI_0 = -1$. With the local normalization $C = 2^N$ (so that $\tilde{y}(I_0) = 1$), a third-order expansion of $\tilde{y}(I) = C[1 + (a + bI)^2]^{-N}$ gives

$$ \tilde{y}(I) \approx 1 + Nb\,\delta + \frac{N^2 b^2}{2}\,\delta^2 + \frac{N(N^2 - 1)}{6}\,b^3\delta^3 + O(\delta^4), \tag{S1.1} $$

so with $b \sim 1/N$, the linear and quadratic coefficients align with those of $e^\delta = 1 + \delta + \delta^2/2 + \delta^3/6 + \cdots$, explaining why the initialization is already close before refinement.

S1.2 Log-domain analysis and scaling derivation

For depth scaling, the logarithmic domain is more transparent. Under the same flank centering ($a + bI_0 = -1$), expand around $I_0 = L/2$ with $\delta = I - I_0$ to obtain

$$ \ln\tilde{y}(I) = \text{const} + Nb\,\delta - \frac{N b^3}{6}\,\delta^3 + O(\delta^4). \tag{S1.2} $$

At $a + bI_0 = -1$, the quadratic term cancels identically in the log expansion; imposing slope matching ($Nb = 1$) gives

$$ \ln\tilde{y}(I) = \text{const} + \delta - \frac{\delta^3}{6N^2} + O(\delta^4). \tag{S1.3} $$

Hence the leading log-domain residual scales as $|r(\delta)| \sim |\delta|^3/N^2$. Over $I \in [0, L]$ with $|\delta| \le L/2$, this implies $E_\infty \sim L^3/N^2$. Requiring $E_\infty \le \varepsilon_{\log}$ leads to

$$ N \propto \frac{L^{3/2}}{\sqrt{\varepsilon_{\log}}}, \tag{S1.4} $$

which explains the scaling used in the main-text engineering estimate (Eq. (28)). This derivation is heuristic (not a formal guarantee), and the prefactor remains platform- and fitting-criterion dependent.

S1.3 Conservative upper bound and screening depth

For fixed $b$ and the identical-detuning family ($a_1 = \cdots = a_N \equiv a$), one can write a conservative heuristic condition for achieving a prescribed log-tolerance. A simple normalization is to enforce $\tilde{y}(L) = 1$ (matching the target $f(L) = 1$). For a particular constructive choice of $a$ that keeps $(a + bI)$ large and negative across $[0, L]$, one can bound the worst-case log-error as

$$ E_\infty \le \frac{L^2}{4N} + \frac{1}{2b^2 N}. \tag{S1.5} $$

(This is a conservative rule of thumb; obtaining a formal guarantee would require a separate proof.) As a screening estimate (not a formal guarantee), one may use

$$ N \ge \frac{L^2/4 + 1/(2b^2)}{\ln(1 + \varepsilon)}. \tag{S1.6} $$

While this bound is typically pessimistic, it provides a conservative backstop-style estimate for preliminary design screening. The derivation of these bounds, including the concavity conditions and slope-matching assumptions, is given in Sec. S0.3.

S2. WORKED EXAMPLE AND EMPIRICAL LOGIT-RANGE CALIBRATION

This section provides the detailed worked example for the input-to-output mapping and the empirical logit-range calibration tables referenced in the main text (Sec. III).

S2.1 Worked input-to-output mapping example

As a worked example, consider

$$ x = [-3.2,\ 1.2,\ 4.8,\ -0.9]. \tag{S2.1} $$

Compute $m = \max_n x_n = 4.8$. Then $u = x - m = [-8.0,\ -3.6,\ 0,\ -5.7]$ and $L = -\min_n u_n = 8.0$.
The mapped control-signal levels are

$$ I = u + L = [0,\ 4.4,\ 8.0,\ 2.3], \tag{S2.2} $$

and the required normalized exponentials are $e^{x_n - m} = e^{u_n} = e^{I_n - L}$. Using the fitted model directly,

$$ T_k(I_n) = \frac{1}{1 + (a_k + bI_n)^2}, \qquad y(I_n) = \prod_{k=1}^{N} T_k(I_n). $$

Under the identical-detuning fit ($a_1 = \cdots = a_N \equiv a$), this becomes

$$ \tilde{y}(I_n) = C\,y(I_n) = C\left[\frac{1}{1 + (a + bI_n)^2}\right]^{N}. $$

For the re-fitted parameters used in this example,

$$ a = -1.4588, \quad b = 0.10202, \quad N = 10, \quad C = 3.0896 \times 10^{1}, \tag{S2.3} $$

which gives

$$ \tilde{y}(I_n) = C\left[\frac{1}{1 + (a + bI_n)^2}\right]^{N} \approx [3.44 \times 10^{-4},\ 2.73 \times 10^{-2},\ 9.74 \times 10^{-1},\ 3.26 \times 10^{-3}]. \tag{S2.4} $$

For reference, the corresponding target terms are

$$ I_n - L = [-8.0,\ -3.6,\ 0,\ -5.7], \tag{S2.5} $$

and

$$ e^{I_n - L} \approx [3.35 \times 10^{-4},\ 2.73 \times 10^{-2},\ 1.00,\ 3.35 \times 10^{-3}]. \tag{S2.6} $$

TABLE S1: Example ($N = 10$): approximating $e^{x_n - m} = e^{I_n - L}$ using $\tilde{y}(I) = C[1 + (a + bI)^2]^{-N}$ with parameters re-fitted on $I \in [0, 8.0]$ using the same minimax pipeline.

| $x_n$ | $I_n$ | target $e^{x_n - m}$ | approx $\tilde{y}(I_n)$ | rel. err. |
|---|---|---|---|---|
| −3.2 | 0.0 | $3.3546 \times 10^{-4}$ | $3.4443 \times 10^{-4}$ | 2.673% |
| 1.2 | 4.4 | $2.7324 \times 10^{-2}$ | $2.7325 \times 10^{-2}$ | 0.004% |
| 4.8 | 8.0 | 1.0000 | 0.9739 | 2.608% |
| −0.9 | 2.3 | $3.3460 \times 10^{-3}$ | $3.2585 \times 10^{-3}$ | 2.614% |

S2.2 Effective-range percentiles and clipping calibration

We first estimate the logit range observed in data and then choose clipping accordingly. From two autoregressive Transformers (distilgpt2 and gpt2) and two public corpora (Tiny Shakespeare and Pride and Prejudice) [1–5] at context length 128, the effective range

$$ L_{\text{eff},\alpha} = \max(\log p_{\text{kept}}) - \min(\log p_{\text{kept}}), \qquad \alpha = 0.999, \tag{S2.7} $$

fell in a relatively narrow band, summarized in Table S2.

TABLE S2: Effective-range percentiles ($L_{\text{eff},0.999}$) at context length 128.

| Percentile | All runs (4 runs) | GPT-2 |
|---|---|---|
| p50 | 6.92–7.23 | 7.09–7.23 |
| p90 | 8.60–8.75 | 8.73–8.75 |
| p95 | 8.97–9.12 | 9.06–9.12 |
| p99 | 9.50–9.69 | 9.58–9.69 |

We then test clipping on the same rows with

$$ E_{\text{cum}}(t) = \tfrac{1}{2}\left\|\mathrm{softmax}(u^{(t)}) - \mathrm{softmax}(u)\right\|_1, \qquad u^{(t)} = \max(u, t), \quad u = s - \max(s), \tag{S2.8} $$

and require $\mathrm{p99}\{E_{\text{cum}}\} \le 10^{-3}$ (0.1% budget). This criterion is satisfied at $t = -12$ (p99 $\approx 4.27 \times 10^{-4}$) and violated at $t = -11$ (p99 $\approx 1.24 \times 10^{-3}$), so we set $t^* = -12$ ($N_{\text{clip}} = 12$).

In practice, we (i) estimate an effective $L$ from data, (ii) verify that fixed clipping keeps softmax error small, and (iii) choose representative design points (e.g., $L \approx 8$ or $L \approx 12$) while treating the clipped tail as negligible. Full protocol details, clipping-sweep tables/plots, and per-run statistics are provided in Sec. S3.

S2.3 Illustrative synthetic range map

As a design-space reference, we consider synthetic logit-range regimes using $L = \max(x) - \min(x)$ after $QK^\top/\sqrt{d_k}$ scaling. These regimes are illustrative rather than corpus-level percentiles; using the same fitting pipeline, Table S3 summarizes achievable approximation error versus depth.

TABLE S3: Synthetic softmax logit-range regimes ($L = \max(x) - \min(x)$) and fitted worst-case relative error (design-space illustration; not intended as corpus-level statistics).

| L regime | N = 5 | N = 10 | N = 20 | N = 30 |
|---|---|---|---|---|
| L = 8 | 10.9% | 2.68% | 0.67% | 0.30% |
| L = 12 | 40.0% | 9.25% | 2.27% | 1.01% |
| L = 16 | 113% | 23.0% | 5.44% | 2.41% |

Table S3 suggests a simple rule of thumb: the required depth depends mainly on the target $L$ regime. Near $L \approx 8$, moderate depth reaches a few-percent error, whereas $L \gtrsim 12$ typically requires deeper cascades to approach < 1% error. We include Table S3 as a synthetic design map rather than an empirical benchmark.

S3. EMPIRICAL LOGIT-RANGE EXTRACTION FROM REAL TRANSFORMER RUNS

We extracted empirical attention-logit ranges from real model runs to complement the synthetic L-regime map in the main text.
We used two open-source autoregressive Transformers (distilgpt2 and gpt2) and two public corpora (Tiny Shakespeare and Pride and Prejudice), with context length 128 and causal masking. For each valid attention row, if $p = \mathrm{softmax}(s)$ then the raw range is

$$ L_{\text{raw}} = \max(s) - \min(s) = \max(\log p) - \min(\log p), \tag{37} $$

where max/min are taken over valid causal keys only. Because very small tail probabilities can dominate $\min(\log p)$, we additionally report an effective range:

$$ L_{\text{eff},\alpha} = \max(\log p_{\text{kept}}) - \min(\log p_{\text{kept}}), \tag{38} $$

where keys are sorted by attention weight and retained until cumulative mass reaches $\alpha = 0.999$.

To stay within a 16 GB RAM budget, we processed one model at a time, with batch size 1, fixed windowing (stride 128), and streaming histogram quantiles. Observed process RSS stayed below 1.24 GB in these runs.

TABLE S4: Empirical global logit-range percentiles from real model–dataset runs (context length 128): raw vs effective ($\alpha = 0.999$).

| Model | Dataset | raw p95 | raw p99 | $L_{\text{eff}}$ p50 | $L_{\text{eff}}$ p90 | $L_{\text{eff}}$ p95 | $L_{\text{eff}}$ p99 |
|---|---|---|---|---|---|---|---|
| distilgpt2 | tiny shakespeare | 22.82 | 69.00 | 7.10 | 8.60 | 8.97 | 9.50 |
| distilgpt2 | pride prejudice | 21.76 | 68.60 | 6.92 | 8.60 | 9.03 | 9.57 |
| gpt2 | tiny shakespeare | 25.48 | 43.34 | 7.23 | 8.73 | 9.06 | 9.58 |
| gpt2 | pride prejudice | 24.13 | 40.92 | 7.09 | 8.75 | 9.12 | 9.69 |

For quick linkage to the main manuscript: the effective-range summary quoted in the main text corresponds to this table (all runs: p50 = 6.92–7.23, p90 = 8.60–8.75, p95 = 8.97–9.12, p99 = 9.50–9.69), and the GPT-2 subset is p50 = 7.09–7.23, p90 = 8.73–8.75, p95 = 9.06–9.12, p99 = 9.58–9.69.

Clipping-validity sweep (additional justification). To test whether practical clipping magnitudes can be used without materially changing softmax outputs, we evaluated a thresholded-logit approximation. For each row, define $u = s - \max(s)$ and, for threshold $t \le 0$,

$$ u^{(t)} = \max(u, t), \qquad p^{(t)} = \mathrm{softmax}(u^{(t)}). \tag{39} $$

We report the cumulative softmax error

$$ E_{\text{cum}}(t) = \tfrac{1}{2}\left\|p^{(t)} - p\right\|_1, \tag{40} $$

then sweep $t \in \{-14, -13, \ldots, -6\}$ and compute p50/p90/p95/p99 of $E_{\text{cum}}$ over all extracted rows.

TABLE S5: Global clipping-validity sweep: percentile statistics of $E_{\text{cum}}(t)$ versus clipping threshold $t$.

| t | p50 | p90 | p95 | p99 |
|---|---|---|---|---|
| −14 | $2.53 \times 10^{-5}$ | $4.55 \times 10^{-5}$ | $4.80 \times 10^{-5}$ | $5.18 \times 10^{-5}$ |
| −13 | $2.69 \times 10^{-5}$ | $4.85 \times 10^{-5}$ | $7.38 \times 10^{-5}$ | $1.48 \times 10^{-4}$ |
| −12 | $2.99 \times 10^{-5}$ | $1.21 \times 10^{-4}$ | $2.13 \times 10^{-4}$ | $4.27 \times 10^{-4}$ |
| −11 | $3.31 \times 10^{-5}$ | $3.95 \times 10^{-4}$ | $6.55 \times 10^{-4}$ | $1.24 \times 10^{-3}$ |
| −10 | $3.72 \times 10^{-5}$ | $1.28 \times 10^{-3}$ | $2.01 \times 10^{-3}$ | $3.58 \times 10^{-3}$ |
| −9 | $4.41 \times 10^{-5}$ | $4.04 \times 10^{-3}$ | $6.11 \times 10^{-3}$ | $1.03 \times 10^{-2}$ |
| −8 | $2.25 \times 10^{-4}$ | $1.26 \times 10^{-2}$ | $1.83 \times 10^{-2}$ | $2.91 \times 10^{-2}$ |
| −7 | $2.76 \times 10^{-3}$ | $3.85 \times 10^{-2}$ | $5.30 \times 10^{-2}$ | $7.89 \times 10^{-2}$ |
| −6 | $1.88 \times 10^{-2}$ | $1.11 \times 10^{-1}$ | $1.43 \times 10^{-1}$ | $1.95 \times 10^{-1}$ |

Under a conservative budget criterion $\mathrm{p99}\{E_{\text{cum}}\} \le 10^{-3}$, the least negative admissible threshold in this sweep is $t^* = -12$ (p99 $\approx 4.27 \times 10^{-4}$). Equivalently, the operational clipping magnitude is $N_{\text{clip}} \equiv -t^* = 12$. This is of the same order as the empirical effective-range scale (Table S4: p99 of $L_{\text{eff},0.999}$ up to ≈ 9.69), so the clipping constraint and the effective-range statistics point to a similar range budget. This supports using a practical clipping magnitude comparable to the design range scale ($L \approx N_{\text{clip}}$) while keeping aggregate softmax distortion below 0.1%.

FIG. S1: Global CDFs of raw $L_{\text{raw}}$ (dashed) and effective $L_{\text{eff},0.999}$ (solid) for the four model–dataset runs.

FIG. S2: Percentile curves of cumulative softmax error $E_{\text{cum}}(t)$ versus clipping threshold $t$. The dashed line marks the 0.1% budget ($10^{-3}$).

S4. FDTD METHODOLOGY DETAILS AND X-CUT $b_V$ DERIVATION

This section provides the detailed FDTD simulation methodology, the step-by-step X-cut arc electrode voltage sensitivity derivation, and the full cascade optimization table referenced in the main text (Sec. IV–V).

S4.1 z-refined 3-fix simulation strategy

For thin-film LiNbO3 structures, special care is required in the vertical (z) direction due to the high index contrast between LiNbO3 ($n_o \approx 2.21$) and SiO2 ($n \approx 1.44$) and the sub-micron film thickness. We apply a “z-refined 3-fix” strategy:

 1. Ordinary index correction: the material model uses the corrected ordinary refractive index $n_o$ appropriate for the TE mode in X-cut geometry, rather than the extraordinary index $n_e$ that governs TM propagation;
 2. z-span expansion: the simulation z-span is extended beyond the minimal waveguide region to include sufficient substrate and superstrate so that evanescent field tails are captured without PML truncation artifacts;
 3. Auto-mesh: accuracy level 3; conformal variant 1 meshing is enabled, and no manual mesh override is applied. The resulting vertical grid spacing in the slab region is approximately 55 nm, providing ∼2 cells across the 100 nm slab.

This refinement strategy is critical for obtaining converged results in TFLN ring resonators, where the high-Q spectral features are sensitive to numerical dispersion in under-resolved thin films [6]. Table S6 lists the full simulation parameters.

TABLE S6: 3D FDTD simulation parameters (Lumerical).

| Parameter | Value |
|---|---|
| Solver | Lumerical 3D FDTD |
| Mesh type | Conformal variant 1 |
| Mesh accuracy | 3 (auto-mesh) |
| z-mesh override | None (auto-mesh) |
| Simulation time | 50 ps |
| Auto shutoff | $1 \times 10^{-6}$ |
| Wavelength range | 1530 nm to 1570 nm |
| Grid size | 532 × 816 × 44 |
| Source | Broadband mode source (TE0) |

S4.2 X-cut arc electrode $b_V$ step-by-step derivation

For the X-cut circular ring with lateral S–G arc electrodes (Table II), the crystal Z-axis (c-axis) is oriented at 45° from the horizontal axis in the substrate plane. At azimuthal angle $\theta$ around the ring, the projection of the lateral electric field onto the Z-axis is proportional to $\cos(\theta - 45°)$.
The $\cos(\theta - 45°) = 0$ boundaries fall at $\theta = 135°$ and $\theta = 315°$, naturally separating the bus-waveguide coupling regions from the electrode regions. Each ring carries a full semicircular arc electrode on the side opposite to its coupling points. By the substitution $\varphi = \theta - 45°$, the effective EO fill factor is

$$ f_{\text{EO}} = \frac{1}{2\pi}\int_{\text{semicircle}} |\cos(\theta - 45°)|\,d\theta = \frac{1}{2\pi}\int_{-\pi/2}^{+\pi/2} \cos\varphi\,d\varphi = \frac{1}{2\pi}\Big[\sin\varphi\Big]_{-\pi/2}^{+\pi/2} = \frac{1}{\pi} \approx 0.318. \tag{S4.1} $$

The 45° rotation ensures that the electrode semicircle does not overlap with the coupling points, while the fill factor integral is identical to the standard $\cos\theta$ case by the change of variable.

The lateral S–G electrodes have gap $g_{\text{el}} = 5$ µm, giving an effective electrode–waveguide distance $d_{\text{eff}} \approx g_{\text{el}}/2 = 2.5$ µm. The lateral field geometry yields an EO overlap factor $\Gamma_{\text{EO}} = 0.7$, compared to 0.5 for a vertical electrode configuration.

The refractive index change per volt in the electrode-covered section is

$$ \frac{\Delta n_{\text{eff}}}{V} = -\frac{1}{2}\,n_e^3\,r_{33}\,\frac{\Gamma_{\text{EO}}}{d_{\text{eff}}} = -\frac{1}{2} \times 2.138^3 \times 30.9 \times 10^{-12} \times \frac{0.7}{2.5 \times 10^{-6}} = -4.226 \times 10^{-5}\ \text{V}^{-1}. \tag{S4.2} $$

The corresponding resonance wavelength shift is

$$ \left.\frac{d\lambda_0}{dV}\right|_{\text{straight}} = \frac{1550\ \text{nm} \times 4.226 \times 10^{-5}\ \text{V}^{-1}}{2.30} = 28.48\ \text{pm V}^{-1}, \tag{S4.3} $$

giving an intrinsic (straight-section) voltage sensitivity of

$$ b_V^{\text{straight}} = \frac{2Q_L}{\lambda_0}\left.\frac{d\lambda_0}{dV}\right|_{\text{straight}} = \frac{2 \times 15{,}500}{1550\ \text{nm}} \times 0.02848\ \text{nm V}^{-1} = 0.570\ \text{V}^{-1}. \tag{S4.4} $$

However, only the arc-electrode portion of the ring circumference contributes to the round-trip phase shift. The effective voltage sensitivity is therefore

$$ b_V = b_V^{\text{straight}} \times f_{\text{EO}} = 0.570 \times \frac{1}{\pi} \approx 0.182\ \text{V}^{-1}. \tag{S4.5} $$

A 1 V applied voltage shifts the normalized detuning by $\Delta a \approx 0.182$. Despite the fill-factor penalty ($f_{\text{EO}} = 1/\pi \approx 0.318$), the X-cut arc design benefits from a smaller effective electrode distance (2.5 µm vs. 4 µm for vertical configurations) and a higher overlap factor (0.7 vs. 0.5), which partially compensate for the reduced active length.
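The chain of Eqs. (S4.1)–(S4.5) can be re-evaluated numerically as a sanity check. The sketch below uses only the values quoted in this subsection; the function name and dictionary keys are illustrative, and the group index 2.30 is the FDTD FSR-derived value quoted in Sec. S4.5. Evaluated without intermediate rounding, the product comes out near 0.181 V⁻¹, which agrees with the quoted 0.182 V⁻¹ to within rounding of the intermediate values.

```python
import math

def xcut_arc_sensitivity(n_e=2.138, r33=30.9e-12, gamma_eo=0.7,
                         d_eff=2.5e-6, lambda0=1550e-9,
                         n_g=2.30, q_loaded=15_500):
    """Re-evaluate Eqs. (S4.1)-(S4.5); SI units (m, V) throughout."""
    # Eq. (S4.2): index change per volt in the electrode-covered section
    dn_dv = -0.5 * n_e**3 * r33 * gamma_eo / d_eff
    # Eq. (S4.3): magnitude of the resonance shift per volt
    dlambda_dv = lambda0 * abs(dn_dv) / n_g
    # Eq. (S4.4): normalized sensitivity if the whole ring were covered
    b_straight = 2.0 * q_loaded / lambda0 * dlambda_dv
    # Eq. (S4.1): fill factor of a semicircular arc electrode
    f_eo = 1.0 / math.pi
    # Eq. (S4.5): effective voltage sensitivity
    return {
        "dn_dv": dn_dv,
        "dlambda_dv_pm": dlambda_dv * 1e12,
        "b_straight": b_straight,
        "f_eo": f_eo,
        "b_v": b_straight * f_eo,
    }

if __name__ == "__main__":
    for name, value in xcut_arc_sensitivity().items():
        print(f"{name}: {value:.4g}")
```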

S4.3 Full cascade optimization table

Table S7 presents the complete optimization results for the standard dynamic range $L = 8$ (corresponding to $e^8 \approx 2981$, i.e., 34.7 dB), covering all cascade depths from $N = 5$ to $N = 30$.

TABLE S7: Cascade optimization results for $L = 8$. The bias voltage $V_{\text{bias}} = |a|/b_V$ sets the DC offset, and $V_{\text{ctrl}} = bL/b_V$ is the maximum control voltage at $I = L$. Voltages computed with $b_V = 0.182$ V⁻¹ (FDTD-calibrated best resonance $Q_L = 15{,}500$).

| N | a | b | $E_\infty$ | $\varepsilon_{\max}$ (%) | $V_{\text{bias}}$ (V) | $V_{\text{ctrl}}$ (V) |
|---|---|---|---|---|---|---|
| 5 | −2.0789 | 0.21658 | 0.1035 | 10.91 | 11.4 | 9.5 |
| 8 | −1.5959 | 0.12896 | 0.0412 | 4.20 | 8.8 | 5.7 |
| 10 | −1.4588 | 0.10202 | 0.0265 | 2.68 | 8.0 | 4.5 |
| 12 | −1.3731 | 0.08450 | 0.0184 | 1.86 | 7.5 | 3.7 |
| 15 | −1.2914 | 0.06726 | 0.0118 | 1.19 | 7.1 | 3.0 |
| 17 | −1.2543 | 0.05923 | 0.0092 | 0.92 | 6.9 | 2.6 |
| 20 | −1.2136 | 0.05025 | 0.0067 | 0.67 | 6.7 | 2.2 |
| 25 | −1.1685 | 0.04013 | 0.0043 | 0.43 | 6.4 | 1.8 |
| 30 | −1.1392 | 0.03341 | 0.0030 | 0.30 | 6.3 | 1.5 |

Key thresholds for the minimum number of rings at various error targets are:

 • ε < 10%: N ≥ 6,
 • ε < 5%: N ≥ 8,
 • ε < 2%: N ≥ 12,
 • ε < 1%: N ≥ 17,
 • ε < 0.5%: N ≥ 24.

These thresholds are independent of the quality factor $Q$, since the minimax approximation operates entirely in normalized detuning space. The $Q$ factor affects only the physical voltage required to achieve the necessary detuning range, through $b_V$.

S4.4 Lorentzian fit validation

Figure S3 shows the Lorentzian fit to the FDTD drop-port resonance near $\lambda = 1566$ nm. The analytical Lorentzian $T_{\text{drop}}(\Delta\lambda) = A/[1 + (2Q\Delta\lambda/\lambda_0)^2]$ with $Q_L = 15{,}500$ closely tracks the FDTD data, validating the single-ring transfer function model used in the cascade analysis.

FIG. S3: Lorentzian fit to the FDTD drop-port resonance. Markers: FDTD data; solid line: Lorentzian fit. The extracted quality factor is $Q_L = 15{,}500$ with FWHM = 101 pm.
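The Lorentzian used in the fit can be evaluated directly to cross-check the quoted linewidth. The sketch below assumes only the functional form $T_{\text{drop}}(\Delta\lambda) = A/[1 + (2Q\Delta\lambda/\lambda_0)^2]$ given above; the function names are illustrative. The half-maximum points sit at $2Q\Delta\lambda/\lambda_0 = \pm 1$, so the full width is $\lambda_0/Q$.

```python
def lorentzian_drop(delta_lambda, lambda0, q_loaded, amplitude=1.0):
    """Drop-port Lorentzian T = A / [1 + (2 Q dlambda / lambda0)^2]."""
    return amplitude / (1.0 + (2.0 * q_loaded * delta_lambda / lambda0) ** 2)

def lorentzian_fwhm(lambda0, q_loaded):
    """Full width at half maximum implied by the Lorentzian: lambda0 / Q."""
    return lambda0 / q_loaded

if __name__ == "__main__":
    lam0, q = 1566e-9, 15_500  # resonance wavelength (m) and loaded Q from Sec. S4.4
    width = lorentzian_fwhm(lam0, q)
    print(f"FWHM = {width * 1e12:.1f} pm")
    print(f"T at half-width offset = {lorentzian_drop(width / 2, lam0, q):.3f}")
```

With $\lambda_0 = 1566$ nm and $Q_L = 15{,}500$ this gives a width of about 101 pm, matching the caption of Fig. S3.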

S4.5 Eigenmode (FDE) analysis of theoretical $Q_i$

To quantify how far below the physical limit the FDTD-extracted $Q_i = 38{,}800$ lies, we perform a two-dimensional finite-difference eigenmode (FDE) analysis of the bent rib waveguide cross-section using Lumerical MODE Solutions.

a. Setup. The FDE solver models the cross-section of the rib waveguide at the design bend radius $R = 20$ µm and wavelength $\lambda = 1550$ nm, with perfectly matched layer (PML) boundaries on all four edges. The geometry is identical to the 3D FDTD model: 600 nm total LiNbO3 ($n_o = 2.211$, lossless dielectric), 100 nm slab, 500 nm rib etch, waveguide width $W = 1.4$ µm, on a 2 µm SiO2 substrate ($n = 1.444$) with air cladding. The mesh is set to 300 × 300 cells over a 6 µm × 3 µm cross-section, yielding effective grid spacings $\Delta x \approx 20$ nm and $\Delta y \approx 10$ nm, substantially finer than the 3D FDTD auto-mesh (55 nm vertical).

b. Complex effective index. The FDE solver returns a complex effective index $n_{\text{eff}} = n_r + i\,n_i$ for each guided mode, where the imaginary part $n_i$ encodes propagation loss. For the fundamental TE mode at $R = 20$ µm:

$$ n_{\text{eff}} = 1.9653 + i\,(4.73 \times 10^{-8}), \tag{41} $$
$$ \alpha_{\text{rad+leak}} = \frac{4\pi n_i}{\lambda} = 0.383\ \text{m}^{-1} \quad (0.017\ \text{dB cm}^{-1}). \tag{42} $$

Since the material is set as lossless, this $\alpha$ captures only bending radiation loss and substrate leakage through the 100 nm slab. The corresponding quality factor is

$$ Q_{\text{rad+leak}} = \frac{2\pi n_g}{\alpha_{\text{rad+leak}}\,\lambda} = 2.43 \times 10^{7}, \tag{43} $$

where $n_g = 2.354$ is the group index from the FDE solver (consistent with the FDTD FSR-derived $n_g = 2.30$; the small difference arises from the straight-section approximation inherent to 2D FDE).

c. Decomposition into bending and leakage. A separate FDE run with $R = 1$ mm (effectively straight) yields $Q_{\text{leak}} = 2.93 \times 10^{7}$, isolating the substrate leakage contribution. The pure bending radiation quality factor follows from

$$ \frac{1}{Q_{\text{bend}}} = \frac{1}{Q_{\text{rad+leak}}} - \frac{1}{Q_{\text{leak}}}, \qquad Q_{\text{bend}} = 1.43 \times 10^{8}. \tag{44} $$

This confirms that bending radiation loss at $R = 20$ µm is negligible; substrate leakage through the thin slab is the dominant geometric loss channel.

d. Material absorption. The FDE mode profile yields a confinement factor $\Gamma = 0.887$ (fraction of the optical intensity within the LiNbO3 core and slab regions). The material-absorption-limited quality factor is

$$ Q_{\text{abs}} = \frac{2\pi n_g}{\Gamma\,\alpha_{\text{mat}}\,\lambda}, \tag{45} $$

where $\alpha_{\text{mat}}$ is the bulk material power-attenuation coefficient of LiNbO3 at 1550 nm. Table S8 evaluates Eq. (45) for representative TFLN absorption values from the literature [6, 7].

TABLE S8: Theoretical intrinsic quality factor $Q_i$ of the $R = 20$ µm TFLN ring, decomposed into radiation ($Q_{\text{bend}}$), substrate leakage ($Q_{\text{leak}}$), and material absorption ($Q_{\text{abs}}$). Sidewall scattering (fabrication-dependent) is excluded. The total is $1/Q_i = 1/Q_{\text{rad+leak}} + 1/Q_{\text{abs}}$ with $Q_{\text{rad+leak}} = 2.43 \times 10^{7}$.

| Material condition | $\alpha_{\text{mat}}$ (dB/cm) | $Q_{\text{abs}}$ | $Q_i$ (total) |
|---|---|---|---|
| Bulk LiNbO3 (pristine) | 0.002 | $2.3 \times 10^{8}$ | $2.2 \times 10^{7}$ |
| High-quality TFLN | 0.01 | $4.7 \times 10^{7}$ | $1.6 \times 10^{7}$ |
| Good TFLN | 0.03 | $1.6 \times 10^{7}$ | $9.5 \times 10^{6}$ |
| Typical TFLN | 0.1 | $4.7 \times 10^{6}$ | $3.9 \times 10^{6}$ |

For high-quality TFLN ($\alpha_{\text{mat}} \lesssim 0.01$ dB cm⁻¹), the theoretical $Q_i$ exceeds $10^7$, more than 400× higher than the FDTD-extracted value of 38,800. This indicates that the FDTD result is dominated by numerical mesh artifacts (approximately two cells across the 100 nm slab), not by physical loss mechanisms. Bending radiation loss at $R = 20$ µm is negligible ($Q_{\text{bend}} = 1.43 \times 10^{8}$); the dominant geometric loss channel in the ideal structure is substrate leakage through the thin slab ($Q_{\text{leak}} = 2.93 \times 10^{7}$).

S5. FABRICATED HIGH-Q DESIGN PROJECTIONS

Reproducing $Q_i > 10^5$ in three-dimensional FDTD is computationally impractical: at accuracy level 3 the 100 nm slab requires $\Delta z \lesssim 20$ nm to suppress staircase-induced scattering, inflating wall times beyond 30 days per run.
The +numerically extracted Qi = 38,800 therefore represents a simulation floor, not a physical one. A two-dimensional +MODE-solver bend analysis confirms Qbend > 4.5 × 107 for R = 20 µm, placing bending radiation loss far below any +realistic intrinsic loss. + Table S9 surveys recent high-Q TFLN microring demonstrations. These studies show that Qi ≥ 9 × 106 has been +demonstrated in X-cut TFLN using multiple fabrication routes, including Ar+ milling, wet etching, and ICP-RIE/CMP- +based processes. + + TABLE S9: Demonstrated intrinsic quality factors in TFLN micro-ring resonators. “EO compatible” indicates + whether the fabrication process preserves electrode patterning capability. + +Ref. Qi R (µm) w (µm) Etch +Zhang [8] 107 80 ∼2 Ar+ mill +Gao [9] 108 100 ∼3 CMP∗ +Zhuang [10] 9×106 100 ∼2 Wet etch +Song [11] 2.9×107 200 4.5 ICP-RIE+CMP + All processes except ∗ are EO-electrode compatible. ∗ CMP-only (no dry etch); subsequent electrode patterning may degrade Qi . + + To project cascade performance into the fabricated regime, we fix Qext = 25,800 (the FDTD-extracted coupling +quality factor at gap = 100 nm) and compute Dmax = [Qi /(Qi + Qext )]2 for three representative intrinsic quality +factors (Table S10). + + N + TABLE S10: Projected cascade transmission for fabricated Qi values at fixed Qext = 25,800. Dmax is the ideal +on-resonance cascade transmission in dB. The minimax approximation error εmax depends only on N and L (not on + Qi ); at N = 20, L = 8: εmax = 0.67% (Table I). + +Projection Qi Dmax N =10 N =20 N =30 +FDTD baseline 3.88×104 0.36 −44.3 −88.5 −132.8 +Conservative 5×105 0.90 −4.4 −8.8 −13.2 +Moderate 106 0.95 −2.2 −4.5 −6.7 +Optimistic 5×106 0.99 −0.44 −0.88 −1.3 + + + Even in the conservative scenario (Qi = 5 × 105 ), Dmax = 0.90 and the N = 10 cascade loss is only −4.4 dB—an +order-of-magnitude improvement over the FDTD baseline. The moderate projection (Qi = 106 ) matches the “fabricated +high-Q” column in Table V. 
Because Qbend ≈ 4.5 × 10^7 ≫ Qi for all projections, bending loss is never the bottleneck;
+the dominant loss mechanism is sidewall scattering, which is determined entirely by fabrication quality. The literature
+values in Table S9 support the view that intrinsic quality factors in the projected range are physically achievable
+in TFLN, albeit with wider waveguides (w ≥ 2 µm) and larger ring radii (R ≥ 80 µm) than the present design.
+Transferring comparable sidewall quality to our geometry (R = 20 µm, w = 1.4 µm) is an open fabrication challenge;
+the projections in Table S10 should be read as design targets contingent on achieving it.
+
+    S6. INSERTION LOSS BUDGET DETAILS
+
+  For a cascade of N rings, the total insertion loss is modeled as
+
+    ILtot ≈ N · ILstage + ILcoupling,    (S6.1)
+
+where ILstage is the per-ring insertion loss at off-resonance operation and ILcoupling accounts for fiber-to-chip and
+chip-to-fiber coupling losses. Using typical loss numbers from the literature [12–16], we consider two scenarios:
+
+  • Optimistic: ILstage = 0.08 dB, ILcoupling = 1.5 dB. Then ILtot ≈ 1.90 dB (N = 5), 2.30 dB (N = 10), 3.10 dB
+    (N = 20), and 3.90 dB (N = 30).
+  • Conservative: ILstage = 0.25 dB, ILcoupling = 3.0 dB. Then ILtot ≈ 4.25 dB (N = 5), 5.50 dB (N = 10),
+    8.00 dB (N = 20), and 10.5 dB (N = 30).
+
+  In both scenarios, N = 5–10 is manageable for probe-power budgeting, whereas N = 20 and N = 30 require tighter
+power budgeting and more amplification margin. Higher ILtot raises the required probe SNR and pushes operation
+closer to the detector noise floor, reducing usable dynamic range.
+  e. Four-component loss breakdown. The total insertion loss of the cascade has four components:
+  1. On-resonance cascade transmission Dmax^N (dominant; see Table V);
+  2. Inter-ring coupling loss (N − 1) × (−10 log10 ηcoupling), where ηcoupling is the power transfer efficiency at each
+     inter-ring bus section.
Two-ring FDTD yields ηcoupling ≈ 0.9 for the present diagonal-bus geometry, corresponding
+     to ∼0.46 dB per inter-ring stage;
+  3. Off-resonance propagation loss N × ILstage, where ILstage = 0.08–0.25 dB per ring [12–14, 16];
+  4. Fiber-to-chip coupling loss ILcoupling = 1.5–3.0 dB [15].
+Table V presents the ideal on-resonance budget (Dmax^N only). Including all four components for the present diagonal-bus
+layout: in the FDTD-characterized regime (Dmax = 0.36, N = 5), the total loss is approximately 22.2 + 1.8 + 0.4 + 1.5 ≈
+26 dB; in the fabricated high-Q regime (Dmax = 0.95, N = 30), the total loss is 6.7 + 13.3 + 2.4 + 1.5 ≈ 24 dB. The
+inter-ring coupling loss dominates in the high-Q regime, underscoring that layout optimization (e.g., adiabatic tapers or
+straight-bus coupling) is as important as achieving Dmax ≥ 0.95 through quality-factor improvement. For an optimized
+layout with ηcoupling ≥ 0.98 (≤0.09 dB per stage), the N = 30 total loss would reduce to ∼13 dB.
+
+    S7. ENERGY EFFICIENCY DETAILED DERIVATION
+
+  This section provides the detailed energy-per-operation derivations for both electrical analog exponential circuits
+and the photonic MRR cascade, as summarized in the main text (Sec. V).
+
+
+    S7.1 Electrical analog exponential circuits
+
+  Three main families of electrical circuits realize the exponential function in the analog domain:
+  f. BJT translinear / Gilbert cell circuits. The collector current of a bipolar junction transistor is IC =
+IS exp(VBE /VT), providing an intrinsic exponential map [17, 18]. A Gilbert cell multiplier, the core building
+block of translinear exponential circuits, dissipates 250–325 µW in typical CMOS/BiCMOS implementations [19]. At
+a signal bandwidth of B ≈ 100 MHz, the energy per operation is
+  \[ E_{\mathrm{Gilbert}} = \frac{P}{B} = \frac{300\ \mu\mathrm{W}}{100\ \mathrm{MHz}} = 3\ \mathrm{pJ}. \tag{S7.1} \]
+  g. CMOS subthreshold exponential circuits.
A MOSFET in weak inversion exhibits ID ∝ exp(VGS /nVT), enabling
+direct exponential computation at ultra-low power [18]. A reconfigurable softmax circuit in 180 nm CMOS implements
+a 10-input softmax at VDD = 500 mV with P = 3 µW [20]. Per-channel: Pexp ≈ 0.43 µW. At B ≈ 1 MHz (limited by
+subthreshold fT):
+  \[ E_{\text{sub-}V_T} = \frac{0.43\ \mu\mathrm{W}}{1\ \mathrm{MHz}} = 0.43\ \mathrm{pJ}. \tag{S7.2} \]
+This is the most energy-efficient electrical approach, but at severely limited bandwidth (∼1 MHz).
+  h. Digital CMOS (for reference). A digital exponential via Taylor series requires ∼10 multiply-add operations.
+Using Horowitz’s energy figures [21] for 45 nm at 0.9 V: 32-bit FP multiply costs 3.7 pJ, FP add costs 0.9 pJ, giving
+  \[ E_{\mathrm{digital}} \approx 10 \times (3.7\ \mathrm{pJ} + 0.9\ \mathrm{pJ}) = 46\ \mathrm{pJ}. \tag{S7.3} \]
+At 8-bit precision (sufficient for inference): ∼2.3 pJ.
+
+
+    S7.2 Photonic MRR cascade: single-channel energy derivation
+
+  We evaluate the energy at N = 30 cascaded X-cut TFLN micro-ring resonators with R = 20 µm in the fabricated
+high-Q regime (Qi = 10^6, QL ≈ 25,200; Supplementary Sec. S5), which achieves εmax = 0.30% with Vctrl = 0.91 V
+(fully CMOS-compatible). The energy per exponential operation has three components:
+  (i) Electro-optic tuning energy. Each ring is tuned by charging the arc electrode capacitance to Vctrl. For the lateral
+S–G arc electrodes covering one semicircle (Larc = πR = 62.8 µm), the electrode capacitance is estimated as
+  \[ C_{\mathrm{el}} \approx 18\ \mathrm{fF}, \tag{S7.4} \]
+based on coplanar electrode modeling for TFLN lateral S–G geometries with gel = 5 µm (comparable to values reported
+by Bahadori et al. [22] for similar geometries). The switching energy per ring at Vctrl = 0.91 V (using the projected
+QL = 25,200, which gives bV = 0.295 V^-1) is
+  \[ E_{\mathrm{ring}} = \tfrac{1}{2} C_{\mathrm{el}} V_{\mathrm{ctrl}}^{2} = \tfrac{1}{2} \times 18\ \mathrm{fF} \times (0.91\ \mathrm{V})^{2} = 7.4\ \mathrm{fJ}. \tag{S7.5} \]
+For N = 30 rings: EEO = 30 × 7.4 fJ ≈ 0.22 pJ.
+  Note the important scaling: EEO ∝ 1/N since b ∝ 1/N from minimax optimization, because
+  \[ E_{\mathrm{EO}} \propto N \times V_{\mathrm{ctrl}}^{2} \propto N \times (b/b_V)^{2} \propto 1/N. \tag{S7.6} \]
+The bias voltage (3.9 V) is static and does not contribute per-operation energy.
+  (ii) Laser source energy (amortized). Because every cascade channel uses the same fixed probe wavelength, a single
+CW laser can be shared among M parallel softmax channels via a 1 × M optical power splitter. With wall-plug
+efficiency ηWPE ≈ 15% [23], the per-channel optical power is Popt = Pin /M ≈ 100 µW (for Pin = 1 mW, M = 10),
+requiring Plaser ≈ 667 µW per channel. At fmod = 10 GHz: Elaser = 667 µW / 10 GHz = 67 fJ.
+  (iii) Photodetector energy. Integrated SiGe photodetectors with TIA achieve sub-pJ reception [24]: EPD ≈ 0.5 pJ.
+  The total single-channel energy is
+  \[ E_{\mathrm{photonic}}^{(1\mathrm{ch})} = E_{\mathrm{EO}} + E_{\mathrm{laser}} + E_{\mathrm{PD}} = 0.22 + 0.07 + 0.50 = 0.79\ \mathrm{pJ}. \tag{S7.7} \]
+
+    S7.3 Q-factor scaling of energy efficiency
+
+  Since Vctrl ∝ 1/Q and EEO ∝ Vctrl^2, the EO energy scales as 1/Q^2. Table S11 shows the total energy for N = 30 at
+various quality factors.
+
+TABLE S11: Energy per exponential operation vs. quality factor (N = 30, εmax = 0.30%, X-cut arc electrode with bV
+ scaling linearly with Q; Cel = 18 fF). Elaser + EPD = 0.57 pJ is the Q-independent floor. The dagger (†) marks the
+FDTD-calibrated quality factor; the double dagger (‡) marks the high-Q design point (Qi = 10^6). Excludes thermal
+ stabilization (0.15–0.60 pJ for N = 30).
+
+  Q          Vctrl (V)   Vbias (V)   EEO (pJ)   Etotal (pJ)
+   5,000     4.57        19.5        5.64       6.21
+  10,000     2.28         9.7        1.40       1.97
+  12,500     1.83         7.8        0.90       1.47
+  15,500†    1.47         6.3        0.58       1.15
+  20,000     1.14         4.9        0.35       0.92
+  25,200‡    0.91         3.9        0.22       0.79
+  30,000     0.76         3.2        0.16       0.73
+  50,000     0.46         1.9        0.06       0.63
+
+
+  At QL = 15,500 (FDTD-calibrated), the EO contribution (0.58 pJ) is comparable to the optical floor, placing the
+design in the efficient operating regime. Beyond Q ≈ 30,000, the EO contribution becomes negligible and the total
+energy saturates near the floor; further Q improvement primarily benefits CMOS driver voltage compatibility rather
+than energy.
+  i.
Additional energy contributions. The estimates above exclude two further contributions: (i) DAC energy +for setting the per-ring control voltages, typically 0.1–1 pJ per conversion at 10 GHz bandwidth; and (ii) thermal +stabilization power for maintaining resonance alignment, estimated at ∼50–200 µW per ring for TFLN (lower than +silicon due to the small thermo-optic coefficient of LiNbO3 , dn/dT ≈ 3.9 × 10−6 K−1 ). At 10 GHz modulation rate, +the thermal contribution amounts to ∼0.005–0.02 pJ per ring per operation. For the N = 30 cascade, this sums to +0.15–0.60 pJ, which is comparable to EEO and must be included in the total: Etotal ≈ 0.94–1.39 pJ. The total energy +comparison should therefore be treated as an order-of-magnitude estimate. + + + S7.4 Comparison with electronic implementations + + Here we provide an order-of-magnitude energy comparison between electrical analog exponential circuits and our +photonic MRR cascade, grounding the analysis in published device data and first-principles estimates. We assume +a shared CW laser with total optical output Pin,tot = 1 mW, split across M = 10 parallel softmax channels via a +1 × M power splitter, yielding per-channel input Pin,ch = 100 µW. The output power at the cascade drop port is + N +Pout = Pin,ch × Dmax , which ranges from 0.61 µW (FDTD regime, N = 5) to 21.5 µW (fabricated regime, N = 30) +(Table V). + j. Electrical analog exponential circuits. Three main families of electrical analog exponential circuits are compared: +BJT translinear/Gilbert cell (∼ 3 pJ at 100 MHz [17–19]), CMOS subthreshold (∼ 0.43 pJ at 1 MHz [18, 20]), and +digital FP32 Taylor series (∼ 46 pJ at 1 GHz [21]). + k. Photonic MRR cascade: single-channel energy. 
For N = 30 X-cut TFLN micro-ring resonators in the self- +consistent high-Q regime (QL = 25,200), the three energy components are EO tuning (EEO = 0.22 pJ), amortized +laser (Elaser = 0.07 pJ, shared across M = 10 parallel channels), and photodetector (EPD = 0.50 pJ), yielding +Ephotonic = 0.79 pJ. Including thermal stabilization for N = 30 rings (0.15–0.60 pJ), the total rises to 0.94–1.39 pJ. +Notably, EEO ∝ 1/N since b ∝ 1/N from minimax optimization. + l. Single-channel comparison. Table S12 presents the comparison. The photonic cascade at N = 30 achieves +0.79 pJ baseline—3.8× lower than the BJT Gilbert cell (3 pJ) and 58× lower than digital FP32 (46 pJ). Including +thermal stabilization (0.94–1.39 pJ), the advantage over INT8 (2.3 pJ) is 1.7–2.4×, while operating at 10 GHz +bandwidth. At fabricated Q ≥ 30,000, EEO drops to 0.16 pJ and Etotal ≈ 0.73 pJ (excluding thermal; Table S11), +recovering a 3.2× advantage over INT8. Subthreshold CMOS achieves the lowest energy (0.43 pJ) but at 10,000× +lower bandwidth. + m. Caveats. These values are order-of-magnitude estimates, not device-accurate predictions. The photonic +estimate excludes DAC energy for voltage generation (typically 0.1–1 pJ per conversion at 10 GHz bandwidth, shared +with any analog approach) and thermal tuning power for maintaining resonance alignment (∼50–200 µW per ring for + 32 + + TABLE S12: Energy per exponential operation: single-channel comparison. + +Implementation E/op (pJ) Bandwidth Notes +Digital FP32 (Taylor) ∼46 1 GHz 10 FP MACs +BJT Gilbert cell ∼3 100 MHz Analog +Digital INT8 (Taylor) ∼2.3 1 GHz 10 INT MACs +Photonic MRR (N = 30) 0.94–1.39 10 GHz Analog† +Subthreshold CMOS ∼0.43 1 MHz Analog + † 0.79 pJ excluding thermal; 0.94–1.39 pJ including thermal. Self-consistent with fabricated high-Q regime (Q = 25,200); see + L + Supplementary Sec. S7. + + +TFLN, lower than silicon due to the small thermo-optic coefficient of LiNbO3 , dn/dT ≈ 3.9 × 10−6 K−1 ). 
Effective
+precision at the photodetector is limited to ∼6–8 bits by shot noise and receiver electronics. The energy advantage
+over electrical implementations is strongest in the fabricated high-Q regime (Dmax ≥ 0.95), where N = 30 is practical
+and Vctrl remains CMOS-compatible.
+
+    S8. MONTE CARLO ROBUSTNESS UNDER DEVICE NON-IDEALITIES
+
+  This section describes the robustness model summarized in the main text. For the fitted L = 8, N = 10 design
+(a = −1.4588, b = 0.10202), each Monte Carlo chip sample includes: (i) per-ring static detuning spread, (ii) per-ring
+sensitivity spread, (iii) global thermal drift and crosstalk-like slope drift, (iv) stage insertion-loss variation, (v)
+control-channel noise, and (vi) detector noise with one-point calibration at I = L.
+  For ring k, we use
+  \[ T_k(I) = \frac{1}{1 + \left(a_k + b_k I + d_{\mathrm{th}} + d_{\mathrm{xt}}\, I/L\right)^{2}}, \tag{46} \]
+with
+  \[ y(I) = \prod_{k=1}^{N} T_k(I) \times 10^{-\mathrm{IL}_{\mathrm{tot}}/10}, \tag{47} \]
+and one-point calibration ỹ(I) = Ccal y(I) such that ỹ(L) = 1 for the same chip instance.
+
+  TABLE S13: Non-ideality distributions used in the Monte Carlo sweeps.
+
+  Parameter               Nominal        Stress
+  σa                      0.020          0.032
+  σb,rel                  0.020          0.032
+  σth                     0.015          0.025
+  σxt                     0.012          0.020
+  σI                      0.004          0.007
+  ILstage (dB, µ ± σ)     0.12 ± 0.03    0.18 ± 0.05
+  σdet                    3.0 × 10^-6    6.0 × 10^-6
+
+
+
+  TABLE S14: Monte Carlo summary (same run reported in main text).
+
+  Metric                       Nominal        Stress
+  Median KL(pref ∥ papprox)    2.17 × 10^-4   7.39 × 10^-4
+  p95 KL(pref ∥ papprox)       5.92 × 10^-4   2.21 × 10^-3
+  Median max |∆p|              0.170%         0.193%
+  p95 max |∆p|                 0.319%         0.419%
+
+Conservative-bound sketch used for the main-text screening equation. For the identical-detuning family
+with fixed b, define
+
+  \[ \ln \tilde{y}(I) = N\phi(a + bI) - N\phi(a + bL), \qquad \phi(u) = -\ln(1 + u^{2}), \tag{48} \]
+
+so that ỹ(L) = 1 by construction.
Around a constructive choice with a + bI < 0 on [0, L], a second-order remainder
+argument for the mismatch between the target slope and the fitted slope yields a term scaling as L^2/(4N), while the
+flank-curvature penalty contributes a term scaling as 1/(2b^2 N). Combining the two contributions gives the screening
+inequality
+
+  \[ E_{\infty} \lesssim \frac{L^{2}}{4N} + \frac{1}{2 b^{2} N}, \tag{49} \]
+
+which leads to the conservative screening equation reported in the main manuscript. We emphasize that this is a
+conservative heuristic design rule (not a formal minimax theorem), used only for preliminary depth screening.
+
+
+FIG. S4: CDF of end-to-end softmax probability error under the same non-ideality samples.
+
+    S9. DELAY-AWARE FEEDBACK NORMALIZATION VALIDATION
+
+  We model global normalization as a delayed PI-controlled loop:
+
+  \[ S(t) = G(t)\,P(t) + n(t), \tag{50} \]
+  \[ \tau \frac{dP}{dt} = -P(t) + u(t - T_d), \tag{51} \]
+  \[ u(t) = K_p\, e(t) + K_i \int e(t)\, dt, \qquad e(t) = S_{\mathrm{ref}} - S(t), \tag{52} \]
+
+with actuator saturation 0 ≤ u ≤ Pmax. A piecewise G(t) profile is used to emulate workload changes. For physical
+intuition, Table S15 converts normalized delay/settling metrics into absolute-time examples.
+
+TABLE S15: Example absolute-time interpretation of normalized PI-loop metrics using one representative stable case
+ ((Kp, Ki, Td /τ) = (0.55, 0.8, 0.2)) and a ±2% settling-time definition (Tsettle ∼ 12.4τ).
+
+  Assumed τ   Delay Td = 0.2τ   Settling ∼ 12.4τ   Interpretation
+  100 ns      20 ns             1.24 µs            fast loop
+  1 µs        200 ns            12.4 µs            moderate loop
+  5 µs        1 µs              62 µs              slower loop
+
+Reference-backed latency context for bottleneck screening. To place the delayed-loop times against mixed-signal
+system latencies, Table S16 summarizes representative time scales with explicit path classes (on-chip vs off-chip)
+for memory and interconnect paths, alongside conversion latency ranges. These are intentionally order-of-magnitude
+ranges (not fixed constants), and can shift with architecture, clocking, and protocol stack choices.
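The delayed PI loop of Eqs. (50)–(52) can be stepped numerically with a simple forward-Euler scheme. The sketch below is a minimal illustration, not the repository's validation script: it assumes a constant gain G, zero noise n(t), time measured in units of τ (so Ki is in units of 1/τ), and a saturation limit Pmax = 2 chosen only so that the actuator does not clip in this example. The function name `simulate_delayed_pi` and all default arguments other than (Kp, Ki, Td/τ) = (0.55, 0.8, 0.2) are hypothetical, and the sketch is not expected to reproduce the tabulated overshoot or settling figures, which depend on the G(t) profile used in the study.

```python
from collections import deque


def simulate_delayed_pi(kp=0.55, ki=0.8, delay=0.2, tau=1.0, s_ref=1.0,
                        gain=1.0, p_max=2.0, t_end=40.0, dt=1e-3):
    """Forward-Euler simulation of the delayed PI loop (noise-free, constant G).

    Returns the list of S(t) samples, one per time step of length dt.
    """
    n_delay = max(int(round(delay / dt)), 0)
    # History of actuator commands so the plant sees u(t - Td).
    u_hist = deque([0.0] * (n_delay + 1), maxlen=n_delay + 1)
    p = 0.0        # plant state P(t)
    integ = 0.0    # running integral of the error e(t)
    trace = []
    for _ in range(int(round(t_end / dt))):
        s = gain * p                      # Eq. (50) with n(t) = 0
        e = s_ref - s
        integ += e * dt
        u = min(max(kp * e + ki * integ, 0.0), p_max)  # Eq. (52) with saturation
        u_hist.append(u)                  # the oldest entry is dropped automatically
        u_delayed = u_hist[0]             # command issued n_delay steps earlier
        p += dt * (-p + u_delayed) / tau  # Eq. (51)
        trace.append(s)
    return trace


if __name__ == "__main__":
    trace = simulate_delayed_pi()
    print(f"final S = {trace[-1]:.4f}, peak S = {max(trace):.4f}")
```

The fixed-length deque acts as a shift register for the transport delay, which keeps the delay handling explicit; a time-varying G(t) could be substituted by replacing the `gain` constant with a function of the step index.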
+ + TABLE S16: Representative subsystem latency ranges used for conservative bottleneck screening in Sec. S9. + + Subsystem path Tsys Sources + On-chip memory (L1/L2) 20–200 ns [25] + Off-chip memory (DRAM) 200–700 ns [25, 26] + ADC conversion 10–710 ns [27, 28] + DAC + driver/settling 1–200 ns [29] + On-chip interconnect (NoC) 5–100 ns [30] + Off-chip I/O (PCIe/CXL) 1–10 µs [25, 31] + +Conservative risk-screening heuristic for loop latency. As a screening heuristic, we use the settling time from +one representative stable case ((Kp , Ki , Td /τ ) = (0.55, 0.8, 0.2); Table S18), with settling defined as the first time +entering and remaining within a ±2% band around Sref , as a normalization-loop latency proxy: + + Tnorm ≈ 12.4 τ. (53) + +This value is not a universal bound; different gain settings, delay ratios, or loop architectures will yield different settling +times. It is used only as a reference point for order-of-magnitude risk screening. Define the conservative screening +metric + + Tnorm ≥ β Tsys , (54) + +with β = 1 (high-risk screening line) and β = 0.5 (early-warning line); this is a heuristic risk indicator, not a formal +dominance proof. The corresponding threshold is + β Tsys + τcrit (β) = . (55) + 12.4 +Table S17 gives the resulting numeric ranges. +For the explicit examples in Table S15, τ = 0.1 µs gives Tnorm ≈ 1.24 µs, τ = 1 µs gives Tnorm ≈ 12.4 µs, and τ = 5 µs +gives Tnorm ≈ 62 µs. These numbers indicate a risk trend (not a hard boundary): for this representative case, the +normalization loop is typically non-dominant when τ is well below the relevant τcrit band, and it may become dominant + 36 + + TABLE S17: Computed τcrit ranges from Eq. (55) using Table S16. 
+ + Subsystem Tsys range τcrit (β = 0.5) τcrit (β = 1) + On-chip memory path 20–200 ns 0.81–8.06 ns 1.61–16.13 ns + Off-chip memory path 200–700 ns 8.06–28.23 ns 16.13–56.45 ns + ADC conversion 10–710 ns 0.40–28.63 ns 0.81–57.26 ns + DAC+driver/settling 1–200 ns 0.04–8.06 ns 0.08–16.13 ns + On-chip interconnect (NoC) 5–100 ns 0.20–4.03 ns 0.40–8.06 ns + Off-chip I/O fabric 1–10 µs 0.04–0.40 µs 0.08–0.81 µs + + +as τ approaches or exceeds that band. The transition depends on path class (on-chip vs off-chip) and on architecture- +specific timing closure, including whether the normalization path lies on the end-to-end critical path (Table S16). +Accordingly, this analysis is intended for preliminary risk screening only; concrete implementations +require full timing validation. + +TABLE S18: Representative step-response cases for the delayed PI loop (settling defined by a ±2% band around Sref ). + + Case (Kp , Ki , Td /τ ) Overshoot Settling Stable + Stable (0.55, 0.8, 0.2) 25.6% ∼ 12.4τ Yes + Marginal (0.95, 1.6, 0.45) 25.6% ∼ 12.8τ Yes + Unstable (1.2, 2.2, 0.75) 45.1% not settled No + + + + TABLE S19: Stable-region fraction from gain-map scans at each delay ratio. + + Td /τ Stable fraction + 0.0 88.1% + 0.2 88.0% + 0.5 72.4% + 0.8 47.5% + 37 + + + + +FIG. S5: Step-response examples of the delayed PI normalization loop. + 38 + + + + +FIG. S6: Delay-dependent stability maps over scanned (Kp , Ki ) ranges. + 39 + + S10. REPRODUCIBILITY + + Scripts used for this Supplementary validation: + • scripts/nonideality montecarlo.py + + • scripts/feedback loop validation.py + + • scripts/extract logit range effective.py + + • scripts/analyze softmax clipping validity.py +Public code repository: https://github.com/hyoseokp/MRR-AEF (commit 585e695). Empirical extraction outputs +are stored under: + • paper/empirical L v3/ + + + + + [1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia + Polosukhin. 
Attention is all you need. In Advances in Neural Information Processing Systems 30 (NeurIPS 2017), pages + 5998–6008, 2017. + [2] Alec Radford et al. Language models are unsupervised multitask learners. Technical report, OpenAI, 2019. + [3] Hugging Face. distilgpt2 model card, 2025. accessed 2026-02-21. + [4] Andrej Karpathy. Tiny shakespeare dataset (char-rnn), 2025. accessed 2026-02-21. + [5] Jane Austen. Pride and prejudice. Project Gutenberg eBook No. 1342, 2025. accessed 2026-02-21. + [6] Di Zhu et al. Integrated photonics on thin-film lithium niobate. Advances in Optics and Photonics, 13(2):242–352, 2021. + [7] Yaowen Hu, Di Zhu, Shengyuan Lu, Xinrui Zhu, Yunxiang Song, Dylan Renaud, Daniel Assumpcao, Rebecca Cheng, + CJ Xin, Matthew Yeh, Hana Warner, Xiangwen Guo, Amirhassan Shams-Ansari, David Barton, Neil Sinclair, and Marko + Loncar. Integrated electro-optics on thin-film lithium niobate. Nature Reviews Physics, 2025. + [8] Mian Zhang, Cheng Wang, Rebecca Cheng, Amirhassan Shams-Ansari, and Marko Lončar. Monolithic ultra-high-Q lithium + niobate microring resonator. Optica, 4(12):1536–1537, 2017. + [9] Renhong Gao, Ni Yao, Jianglin Guan, Li Deng, Jintian Lin, Min Wang, Lingling Qiao, Wei Fang, and Ya Cheng. Lithium + niobate microring with ultra-high Q factor above 108 . Chin. Opt. Lett., 20(1):011902, 2022. +[10] Rongjin Zhuang, Jinze He, Yifan Qi, and Yang Li. High-Q thin-film lithium niobate microrings fabricated with wet etching. + Adv. Mater., 35(3):2208113, 2023. +[11] Xinrui Zhu, Yaowen Hu, Shengyuan Lu, Hana K. Warner, Xudong Li, Yunxiang Song, Letı́cia S. Magalhães, Amirhassan + Shams-Ansari, Neil Sinclair, and Marko Lončar. Twenty-nine million intrinsic Q-factor monolithic microresonators on + thin-film lithium niobate. Photon. Res., 12(8):A63–A68, 2024. +[12] Sudip Shekhar, Wim Bogaerts, Lukas Chrostowski, John E. Bowers, Michael Hochberg, Richard Soref, and Bhavin J. + Shastri. Roadmapping the next generation of silicon photonics. 
Nature Communications, 15:751, 2024. +[13] Xaveer Leijtens et al. Multimode silicon photonics. Nanophotonics, 7:1571–1580, 2018. +[14] Haoqian Li et al. In-memory photonic dot-product engine with electrically programmable weight banks. Nature Communi- + cations, 14:2389, 2023. +[15] Daan Vermeulen et al. High-efficiency fiber-to-chip grating couplers realized using an advanced cmos-compatible silicon-on- + insulator platform. Optics Express, 18(17):18278–18283, 2010. +[16] F. S. Tan, D. J. W. Klunder, H. F. Bulthuis, G. Sengo, H. J. W. M. Hoekstra, and A. Driessen. Direct measurement of + the on-chip insertion loss of high finesse microring resonators in si3 n4 -sio2 technology. In Proceedings of the IEEE LEOS + Benelux Chapter, 2001. +[17] B. Gilbert. Translinear circuits: a proposed classification. Electron. Lett., 11(1):14–16, 1975. +[18] C. Mead. Analog VLSI and Neural Systems. Addison-Wesley, 1989. +[19] B. Razavi. Design of Analog CMOS Integrated Circuits. McGraw-Hill, 2 edition, 2017. +[20] Massimo Vatalaro, Tatiana Moposita, Sebastiano Strangio, Lionel Trojman, Andrei Vladimirescu, Marco Lanuzza, and + Felice Crupi. A low-voltage, low-power reconfigurable current-mode softmax circuit for analog neural networks. Electronics, + 10(9):1004, 2021. +[21] M. Horowitz. 1.1 computing’s energy problem (and what we can do about it). In 2014 IEEE International Solid-State + Circuits Conference (ISSCC), pages 10–14, 2014. +[22] Meisam Bahadori, Yansong Yang, Ahmed E. Hassanien, Lynford L. Goddard, and Songbin Gong. Ultra-efficient and fully + isotropic monolithic microring modulators in a thin-film lithium niobate photonics platform. Optics Express, 28(20):29644– + 29661, 2020. +[23] A. Biberman and K. Bergman. Optical interconnection networks for high-performance computing systems. Rep. Prog. + Phys., 75(4):046402, 2012. + 40 + +[24] D. A. B. Miller. Attojoule optoelectronics for low-energy information processing and communications. J. 
Lightwave Technol., + 35(3):346–396, 2017. +[25] Zhe Jia, Marco Maggioni, Benjamin Staiger, and Daniele P. Scarpazza. Dissecting the NVIDIA volta GPU architecture via + microbenchmarking. arXiv preprint arXiv:1804.06826, 2018. +[26] Yoongu Kim, Vivek Seshadri, Donghyuk Lee, Jamie Liu, and Onur Mutlu. A case for exploiting subarray-level parallelism + (SALP) in DRAM. In Proceedings of the 39th Annual International Symposium on Computer Architecture (ISCA), pages + 368–379, 2012. +[27] Texas Instruments. ADC12DJ3200: 6.4-GSPS single-channel or 3.2-GSPS dual-channel, 12-bit, RF-sampling analog-to-digital + converter. Datasheet (SLVSD97A, revised April 2020), 2020. Accessed 2026-02-22. +[28] Texas Instruments. ADS8881: 18-bit, 1-MSPS, low-power, true-differential SAR ADC. Datasheet (SBAS547D, revised + August 2015), 2015. Accessed 2026-02-22. +[29] Texas Instruments. DAC38RF82/DAC38RF89: Dual-channel, 14-bit, 9-GSPS and 6-GSPS RF DACs. Datasheet + (SLASEA6D, revised June 2020), 2020. Accessed 2026-02-22. +[30] W. J. Dally and B. Towles. Route packets, not wires: on-chip interconnection networks. In Proceedings of the 38th Design + Automation Conference (DAC), pages 684–689, 2001. +[31] Shintaro Sano, Yosuke Bando, Kazuhiro Hiwada, Hirotsugu Kajihara, Tomoya Suzuki, Yu Nakanishi, Daisuke Taki, and + Akiyuki Kaneko. Gpu graph processing on CXL-based microsecond-latency external memory. In Proceedings of the SC ’23 + Workshops of the International Conference for High Performance Computing, Networking, Storage and Analysis, 2023. +
\ No newline at end of file diff --git a/ep_run/factorized_exit.py b/ep_run/factorized_exit.py new file mode 100644 index 0000000..fbf66a8 --- /dev/null +++ b/ep_run/factorized_exit.py @@ -0,0 +1,330 @@ +"""Factorized BP-free exit feedback for local CE training. + +Replaces W_U^T(p-y) with α · C(p-y) @ U^T where: + C: fixed compressor (dense random or hybrid gold+topk+tail-sketch) + U: fixed orthonormal expander (d, r) + α: scalar gain + +Forward logits = h @ W_U^T (exact, unchanged) +grad_W = exact local CE gradient (no weight transport) +grad_h = factorized BP-free signal (no W_U^T) + +Two compressor modes: + dense: g @ C where C is (V, r) fixed random + hybrid: [gold + top-k exact codes, CountSketch(tail)] +""" +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def orthonormal_columns(d_out, rank, *, device=None, dtype=torch.float32, seed=None): + if rank <= 0 or rank > d_out: + raise ValueError(f"rank must satisfy 1 <= rank <= d_out, got rank={rank}, d_out={d_out}") + gen = torch.Generator(device="cpu") + if seed is not None: + gen.manual_seed(seed) + q, _ = torch.linalg.qr( + torch.randn(d_out, rank, dtype=dtype, generator=gen), mode="reduced" + ) + return q.contiguous().to(device=device) + + +class DenseRandomCompressor(nn.Module): + def __init__(self, vocab_size, rank, *, seed=None): + super().__init__() + self.vocab_size = vocab_size + self._rank = rank + gen = torch.Generator() + if seed is not None: + gen.manual_seed(seed) + codebook = torch.randn(vocab_size, rank, generator=gen) / math.sqrt(rank) + self.register_buffer("codebook", codebook) + + @property + def rank(self): + return self._rank + + @torch.no_grad() + def compress(self, grad_logits, targets): + return grad_logits.float() @ self.codebook.float() + + +class HybridTopKTailSketchCompressor(nn.Module): + def __init__(self, vocab_size, *, rank_exact=32, rank_tail=96, topk=8, seed=None): + super().__init__() + self.vocab_size = 
vocab_size + self.rank_exact = rank_exact + self.rank_tail = rank_tail + self.topk = topk + + gen = torch.Generator() + if seed is not None: + gen.manual_seed(seed) + + if rank_exact > 0: + codes = torch.randn(vocab_size, rank_exact, generator=gen) + codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-6) + else: + codes = torch.empty(vocab_size, 0) + self.register_buffer("exact_codes", codes) + + if rank_tail > 0: + bucket = torch.randint(0, rank_tail, (vocab_size,), generator=gen) + sign = torch.randint(0, 2, (vocab_size,), generator=gen).float() * 2 - 1 + else: + bucket = torch.empty(vocab_size, dtype=torch.long) + sign = torch.empty(vocab_size) + self.register_buffer("bucket", bucket) + self.register_buffer("sign", sign) + + @property + def rank(self): + return self.rank_exact + self.rank_tail + + @torch.no_grad() + def compress(self, grad_logits, targets): + g = grad_logits.float() + V = g.size(-1) + orig_shape = g.shape[:-1] + g_flat = g.reshape(-1, V) + t_flat = targets.reshape(-1).long() + N = g_flat.size(0) + device = g_flat.device + + safe_t = t_flat.clamp(min=0) + gold_grad = g_flat.gather(1, safe_t.unsqueeze(1)) # (N, 1) + + k_eff = min(self.topk, max(V - 1, 0)) + parts = [] + + # Exact head: gold + top-k + if self.rank_exact > 0: + codes = self.exact_codes.float() + gold_codes = codes[safe_t] # (N, r_exact) + c_exact = gold_grad * gold_codes + if k_eff > 0: + topv, topi = g_flat.topk(k_eff, dim=1) + top_codes = codes[topi] # (N, k, r_exact) + c_exact = c_exact + (topv.unsqueeze(-1) * top_codes).sum(dim=1) + parts.append(c_exact) + + # Tail CountSketch + if self.rank_tail > 0: + signed_full = g_flat * self.sign.unsqueeze(0) + c_tail = g_flat.new_zeros(N, self.rank_tail) + c_tail.scatter_add_(1, self.bucket.unsqueeze(0).expand(N, V), signed_full) + + # Remove gold contribution from tail + gold_bucket = self.bucket[safe_t] + gold_sign = self.sign[safe_t] + rows = torch.arange(N, device=device) + c_tail[rows, gold_bucket] -= gold_grad.squeeze(1) * 
gold_sign + + # Remove top-k from tail + if k_eff > 0: + top_bucket = self.bucket[topi] + top_sign_vals = self.sign[topi] + r_idx = rows.unsqueeze(1).expand(-1, k_eff).reshape(-1) + c_tail[r_idx, top_bucket.reshape(-1)] -= (topv * top_sign_vals).reshape(-1) + + parts.append(c_tail) + + c = torch.cat(parts, dim=1) + return c.reshape(*orig_shape, self.rank).to(dtype=grad_logits.dtype) + + +class _FactorizedExitFn(torch.autograd.Function): + @staticmethod + def forward(ctx, h, weight, targets, U, alpha, compressor): + logits = h @ weight.t() + ctx.compressor = compressor + ctx.save_for_backward(h.detach(), weight.detach(), targets, U.detach(), alpha.detach()) + ctx.logits_detached = logits.detach() + return logits + + @staticmethod + def backward(ctx, grad_logits): + h, weight, targets, U, alpha = ctx.saved_tensors + compressor = ctx.compressor + logits = ctx.logits_detached + + # Exact W gradient (no transport) + g_flat = grad_logits.reshape(-1, grad_logits.size(-1)).float() + h_flat = h.reshape(-1, h.size(-1)).float() + grad_weight = g_flat.t() @ h_flat + + # BP-free hidden signal via compressor + c = compressor.compress(grad_logits, targets).float() + grad_h = (alpha * c) @ U.float().t() + + return grad_h.to(h.dtype), grad_weight.to(weight.dtype), None, None, None, None + + +class _ExactParallelExitFn(torch.autograd.Function): + """Exit backward using only the exact recoverable parallel component. + + g_h_parallel = ((p-y)^T z / (||h||^2 + eps)) * h + + This is the ONLY component of W_U^T(p-y) that is identifiable from + forward quantities alone. The h-perp component is informationally + invisible without W_U. 
+ """ + @staticmethod + def forward(ctx, h, weight, targets, residual_fn): + logits = h @ weight.t() + ctx.save_for_backward(h.detach(), weight.detach(), targets) + ctx.logits_detached = logits.detach() + ctx.residual_fn = residual_fn # optional h-perp residual + return logits + + @staticmethod + def backward(ctx, grad_logits): + h, weight, targets = ctx.saved_tensors + logits = ctx.logits_detached + residual_fn = ctx.residual_fn + + # Exact W gradient (no transport) + g_flat = grad_logits.reshape(-1, grad_logits.size(-1)).float() + h_flat = h.reshape(-1, h.size(-1)).float() + grad_weight = g_flat.t() @ h_flat + + # Exact parallel component: (p^T z - z_y) / (||h||^2 + eps) * h + # Memory-efficient: avoid materializing y_onehot (B,T,V) tensor. + # e = p - y_onehot computed in-place by subtracting 1 at target indices. + p = F.softmax(logits, dim=-1) # (..., V) + V = p.size(-1) + e = p # in-place: e will be (p - y_onehot) + target_idx = targets.clamp(min=0).unsqueeze(-1) + e.scatter_add_(-1, target_idx, torch.full_like(target_idx, -1.0, dtype=e.dtype)) + + # p^T z - z_y = (p-y)^T z (since p^T z - z_y = sum_j p_j z_j - z_y) + e_dot_z = (e * logits).sum(dim=-1, keepdim=True) # (..., 1) + h_norm_sq = (h.float() * h.float()).sum(dim=-1, keepdim=True) + 1e-8 # (..., 1) + + grad_h = (e_dot_z / h_norm_sq) * h.float() # (..., d) + + # Optional orthogonal residual + if residual_fn is not None: + residual = residual_fn(h.float(), e, logits, targets) + grad_h = grad_h + residual + + return grad_h.to(h.dtype), grad_weight.to(weight.dtype), None, None + + +class FactorizedExitHead(nn.Module): + """Drop-in BP-free local CE exit head.""" + + def __init__(self, d_model, vocab_size, *, mode="hybrid", rank=128, + rank_exact=32, topk=8, alpha_init=1.0, seed=None): + super().__init__() + if mode == "dense": + self.compressor = DenseRandomCompressor(vocab_size, rank, seed=seed) + elif mode == "hybrid": + self.compressor = HybridTopKTailSketchCompressor( + vocab_size, 
rank_exact=rank_exact, rank_tail=rank - rank_exact, topk=topk, seed=seed + ) + else: + raise ValueError(f"Unknown mode: {mode}") + + U = orthonormal_columns(d_model, self.compressor.rank, seed=seed) + self.register_buffer("U", U) + self.register_buffer("alpha", torch.tensor(alpha_init)) + self.vocab_size = vocab_size + + def forward(self, h, shared_weight, targets): + """h: (B,T,d), shared_weight: (V,d), targets: (B,T) → logits: (B,T,V)""" + return _FactorizedExitFn.apply(h, shared_weight, targets, self.U, self.alpha, self.compressor) + + +class ExactParallelExitHead(nn.Module): + """BP-free exit using exact parallel component + optional h-perp residual. + + Modes: + parallel_only: g̃_h = (e^T z / ||h||²) h (exact parallel only) + parallel_gold: + λ R(h) (e_y q_y) (+ gold token code in h⊥) + parallel_topmass: + λ R(h) (e_y q_y + Σ_{j∈S} e_j q_j) (+ top-mass codes in h⊥) + """ + + def __init__(self, d_model, vocab_size, *, mode="parallel_only", + residual_rank=32, residual_lambda=0.1, mass_threshold=0.95, seed=None): + super().__init__() + self.vocab_size = vocab_size + self.mode = mode + self.residual_lambda = residual_lambda + self.mass_threshold = mass_threshold + + if mode in ("parallel_gold", "parallel_topmass"): + # Fixed random token codes for h-perp residual + gen = torch.Generator() + if seed is not None: + gen.manual_seed(seed) + codes = torch.randn(vocab_size, residual_rank, generator=gen) + codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-6) + self.register_buffer("token_codes", codes) + + # Fixed base Q for constructing R(h) ∈ h⊥ + Q = orthonormal_columns(d_model, residual_rank, seed=seed) + self.register_buffer("Q_base", Q) + else: + self.token_codes = None + self.Q_base = None + + def _residual_fn(self, h, e, logits, targets): + """Compute h-perp residual: λ R(h) C_head(e).""" + if self.mode == "parallel_only" or self.token_codes is None: + return None + + B_T = h.shape[:-1] + d = h.size(-1) + device = h.device + + # R(h): project Q_base 
into h⊥ (memory-efficient: avoid materializing (N, d, r) tensor) + h_hat = h / (h.norm(dim=-1, keepdim=True) + 1e-8) # (..., d) + Q = self.Q_base.float() # (d, r) + hQ = (h_hat.unsqueeze(-2) @ Q).squeeze(-2) # (..., r) = h_hat^T Q per token + # Column norms of Q_bar = sqrt(1 - hQ_j^2) (since Q cols are unit-norm, h_hat unit-norm) + col_norm_sq = (1.0 - hQ ** 2).clamp_min(1e-8) # (..., r) + col_norm_inv = col_norm_sq.rsqrt() # (..., r) + + # C_head(e): gold + (optionally) top-mass codes + t_flat = targets.reshape(-1).clamp(min=0) + e_flat = e.reshape(-1, self.vocab_size) + N = e_flat.size(0) + + codes = self.token_codes.float() + gold_grad = e_flat.gather(1, t_flat.unsqueeze(1)) # (N, 1) + c = gold_grad * codes[t_flat] # (N, r) + + if self.mode == "parallel_topmass": + # Adaptive top-mass via topk(200) + cumulative mass (avoids full sort OOM) + p_flat = F.softmax(logits.reshape(-1, self.vocab_size).float(), dim=-1) + k_pre = min(200, self.vocab_size - 1) + top_p, top_idx = p_flat.topk(k_pre, dim=1) # (N, k_pre) + top_p_cumsum = top_p.cumsum(dim=-1) + keep_mask = top_p_cumsum <= self.mass_threshold + keep_mask[:, 0] = True + # Get corresponding error values and codes (chunked to avoid OOM on codes[top_idx]) + top_e = e_flat.gather(1, top_idx) # (N, k_pre) + top_e_masked = top_e * keep_mask.float() + chunk_size = min(1024, N) + for cs in range(0, N, chunk_size): + ce = min(cs + chunk_size, N) + chunk_codes = codes[top_idx[cs:ce]] # (chunk, k_pre, r) + c[cs:ce] += (top_e_masked[cs:ce].unsqueeze(-1) * chunk_codes).sum(dim=1) + + c = c.reshape(*B_T, -1) # (..., r) + + # R(h) @ c: project into h⊥ (memory-efficient, O(N*(d+r)) instead of O(N*d*r)) + # residual = Σ_j (c_j / ||Q_bar_j||) * (Q_j - h_hat * hQ_j) + # = Q @ c_adj - h_hat * (c_adj · hQ) + c_adj = c * col_norm_inv # (..., r) + residual = c_adj @ Q.t() - h_hat * (c_adj * hQ).sum(dim=-1, keepdim=True) + return self.residual_lambda * residual + + def forward(self, h, shared_weight, targets): + residual_fn = 
self._residual_fn if self.mode != "parallel_only" else None + return _ExactParallelExitFn.apply(h, shared_weight, targets, residual_fn) diff --git a/ep_run/fast_probe.py b/ep_run/fast_probe.py new file mode 100644 index 0000000..6fa72a2 --- /dev/null +++ b/ep_run/fast_probe.py @@ -0,0 +1,41 @@ +"""Validate the speed knobs: corr_every (stale AEP corr) x tf32, vs fp32 exact reference. +Watch BOTH gradient cosine AND the achievable free-phase residual (tf32 may raise the res floor +above res_est=1e-4 -> validity issue).""" +import time, torch +from lt_ep_train import EQBlock, get_batch, bptt_step, relax +from holo_ep import holo_a_select +torch.manual_seed(0) +B, T = 16, 64 +blk = EQBlock(128, 4, 256, T, attn_mode='thick') +for p, w in zip(blk.allp, torch.load('/tmp/lt_ep/probe_w.pt')): + with torch.no_grad(): + p.copy_(w.to('cuda')) + +def cos(ga, gb, ps): + keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None] + va = torch.cat([ga[id(p)].reshape(-1) for p in keep]); vb = torch.cat([gb[id(p)].reshape(-1) for p in keep]) + return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item() + +def gfrom(idx, zs, a): + with torch.enable_grad(): + xin = blk.embed(idx) + f = blk.force(zs.detach(), xin, cg=True) + g = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True) + return {id(p): gv for p, gv in zip(blk.block, g)} + +print(f"{'tf32':>5} {'corr_ev':>8} {'res@400':>9} {'t_best':>7} {'cos':>6} {'sec':>6}") +for bi in range(2): + idx, y = get_batch('train', B, T) + torch.backends.cuda.matmul.allow_tf32 = False + ref = bptt_step(blk, idx, y, 400, 0.1) + for tf32 in (False, True): + torch.backends.cuda.matmul.allow_tf32 = tf32 + torch.backends.cudnn.allow_tf32 = tf32 + xin = blk.embed(idx).detach() + zs = relax(blk, xin.clone(), xin, 400, 0.1) + res = (relax(blk, zs, xin, 1, 0.1) - zs).norm().item() / zs.norm().item() + for ck in (1, 2, 3): + t0 = time.time() + a, tb = holo_a_select(blk, zs, xin, y, 2, 0.02, 120, 0.1, corr_every=ck) + dt = 
time.time() - t0 + print(f"{str(tf32):>5} {ck:>8} {res:>9.1e} {tb:>7} {cos(gfrom(idx, zs, a), ref, blk.block):>6.3f} {dt:>6.1f}", flush=True) diff --git a/ep_run/gcalib.py b/ep_run/gcalib.py new file mode 100644 index 0000000..85495c0 --- /dev/null +++ b/ep_run/gcalib.py @@ -0,0 +1,38 @@ +"""EP lr theory, step 1: measure k = |g_EP|/|g_BPTT| per param group at a realistic operating point. +Native reference is BPTT (Ernoult: EP=BPTT as beta->0, converged) — NOT BP. lr_EP = lr_BPTT / k. +Report magnitude ratio AND cosine (direction) per group so we separate scale (k) from alignment.""" +import torch +import lt_ep_train as M +from pathlib import Path +import pickle +M.DD = Path('/tmp/lt_ep/data/tinystories_bpe'); M.vocab = pickle.load(open(M.DD/'meta.pkl','rb'))['vocab_size'] +from lt_ep_train import EQBlock, get_batch, bptt_step, ep_step +torch.manual_seed(0) +C,H,T,B = 512, 16, 256, 16 +blk = EQBlock(C,H,256,T,attn_mode='thick'); blk.qknorm=True; blk.track=False; blk.li_avg=0; blk.navg=1; blk.fnoise=0; blk.nbrake=0; blk._cstep=None +with torch.no_grad(): blk.WO.mul_(0.1); blk.pj.mul_(0.1) +opt = torch.optim.AdamW(blk.allp, lr=5e-4, weight_decay=1e-4) +for _ in range(300): # pretrain to a realistic operating point (BPTT) + idx,y = get_batch('train',B,T); g = bptt_step(blk,idx,y,150,0.1) + opt.zero_grad(set_to_none=True) + for p in blk.allp: p.grad = g.get(id(p)) + torch.nn.utils.clip_grad_norm_(blk.allp,5.0); opt.step() +print("pretrained 300 BPTT steps (C=512). 
k=|g_EP|/|g_BPTT|, cos=direction:", flush=True) +groups = {'all':blk.block,'attn':[blk.WQ,blk.WK,blk.WV,blk.WO],'ffn':[blk.fc,blk.fcb,blk.pj,blk.pjb], + 'ln':[blk.ln1g,blk.ln1b,blk.ln2g,blk.ln2b],'emb':[blk.tok,blk.pos]} +def cat(g,ps): + v=[g[id(p)].reshape(-1) for p in ps if g.get(id(p)) is not None]; return torch.cat(v) if v else None +import numpy as np +acc={k:[] for k in groups}; accc={k:[] for k in groups} +for _ in range(6): + idx,y = get_batch('train',B,T) + gE,_ = ep_step(blk,idx,y,150,20,0.1,0.02,0.0,holo=2,hr=0.02,t1max=500,res_est=1e-4,t2sel=120) + gB = bptt_step(blk,idx,y,400,0.1) + for k,ps in groups.items(): + a,b = cat(gE,ps),cat(gB,ps) + if a is not None and b is not None: + acc[k].append((a.norm()/(b.norm()+1e-12)).item()) + accc[k].append((a@b/(a.norm()*b.norm()+1e-12)).item()) +print(f"{'group':>5} {'k=|gEP|/|gBPTT|':>16} {'cos':>6} -> lr_EP = lr_BPTT / k") +for k in groups: + print(f"{k:>5} {np.mean(acc[k]):>16.3f} {np.mean(accc[k]):>6.3f}", flush=True) diff --git a/ep_run/gen_ept.py b/ep_run/gen_ept.py new file mode 100644 index 0000000..2a59ff5 --- /dev/null +++ b/ep_run/gen_ept.py @@ -0,0 +1,32 @@ +import torch, math, torch.nn.functional as F +from pathlib import Path +from tokenizers import Tokenizer +import lt_ep_train as LT +dev='cuda' +DD=Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe') +tok=Tokenizer.from_file(str(DD/'tokenizer.json')) +C,H,Mm,T,c,T1,eps = 512,16,256,256,1.0,150,0.1 +blk=LT.EQBlock(C,H,Mm,T,c=c,attn_mode='thick') +ck=torch.load('/home/yurenh2/ept/ep_run/runs/ep_resreg_warm.pt',map_location=dev) +params = ck['pema'] if ck.get('pema') is not None else ck['allp'] +with torch.no_grad(): + for p,s in zip(blk.allp, params): p.copy_(s.to(dev)) +print(f"loaded resreg_warm pema | best val CE {ck['best']:.4f} | step {ck['step']}", flush=True) +@torch.no_grad() +def gen(prompt, n_new=150, temp=0.8, topk=40, seed=0): + torch.manual_seed(seed) + ids=tok.encode(prompt).ids + idx=torch.zeros(1,T,dtype=torch.long,device=dev); 
L=len(ids) + idx[0,:L]=torch.tensor(ids,device=dev) + for _ in range(n_new): + if L>=T: break + xin=blk.embed(idx) + z=LT.relax(blk, xin.clone(), xin, T1, eps) + lg=(z@blk.Wh)[0,L-1]/temp + v,_=torch.topk(lg,topk); lg[lg<v[-1]]=-float('inf') + p=F.softmax(lg,-1); nt=torch.multinomial(p,1).item() + idx[0,L]=nt; L+=1 + return tok.decode(idx[0,:L].tolist()) +for seed in range(3, 11): + print(f"\n===== seed={seed} temp=0.7 =====", flush=True) + print(gen("Once upon a time,", 135, temp=0.7, topk=40, seed=seed), flush=True) diff --git a/ep_run/grad_quality.py b/ep_run/grad_quality.py new file mode 100644 index 0000000..095d764 --- /dev/null +++ b/ep_run/grad_quality.py @@ -0,0 +1,64 @@ +"""Gradient-quality probe at a realistic operating point. Pretrain the thick block with BPTT +for 300 steps (lands near where good optima live, res ~1e-2-1e-3, NO contraction penalty), +then measure cosine(EP gradient, long-horizon-BPTT reference) per parameter group as a +function of free-phase length T1 (= residual level) and nudge length T2. 
+If cosine is high at res ~1e-2 -> there is NO estimator wall at the operating points that +matter; the EP-vs-BPTT gap is speed + regularization tax, not a convergence mechanism.""" +import torch +from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +B, T, C, H = 16, 64, 128, 4 +blk = EQBlock(C, H, 256, T, attn_mode='thick') # c=1.0 default = thick-BPTT baseline setting +import os +if os.path.exists('/tmp/lt_ep/probe_w.pt'): + for p, w in zip(blk.allp, torch.load('/tmp/lt_ep/probe_w.pt')): + with torch.no_grad(): + p.copy_(w.to(dev)) + print("loaded cached pretrained weights", flush=True) +else: + opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4) + for step in range(300): + idx, y = get_batch('train', B, T) + g = bptt_step(blk, idx, y, 150, 0.1) + opt.zero_grad(set_to_none=True) + for p in blk.allp: + p.grad = g.get(id(p)) + torch.nn.utils.clip_grad_norm_(blk.allp, 5.0) + opt.step() + torch.save([p.detach().cpu() for p in blk.allp], '/tmp/lt_ep/probe_w.pt') +print("pretrained 300 BPTT steps (thick, c=1) -- measuring at this operating point", flush=True) + +groups = {'all': blk.block, + 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO], + 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb], + 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b], + 'emb': [blk.tok, blk.pos], + 'mem': [blk.Wm]} + + +def cos(ga, gb, ps): + keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None] + if not keep: + na = sum(1 for p in ps if ga.get(id(p)) is None) + nb = sum(1 for p in ps if gb.get(id(p)) is None) + print(f"\n[debug] empty keep: |ps|={len(ps)} ga_None={na} gb_None={nb} " + f"ga_keys={len(ga)} gb_keys={len(gb)}", flush=True) + return float('nan') + va = torch.cat([ga[id(p)].reshape(-1) for p in keep]) + vb = torch.cat([gb[id(p)].reshape(-1) for p in keep]) + return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item() + + +hdr = f"{'config':>16} {'res':>9} " + " ".join(f"{k:>6}" 
for k in groups) +for bi in range(3): + idx, y = get_batch('train', B, T) + ref = bptt_step(blk, idx, y, 400, 0.1) # exact reference: long-horizon BPTT + print(("\n" if bi else "") + hdr, flush=True) + g150 = bptt_step(blk, idx, y, 150, 0.1) + print(f"{'bptt T1=150':>16} {'--':>9} " + " ".join(f"{cos(g150, ref, ps):>6.3f}" for ps in groups.values()), flush=True) + for T1 in (50, 150, 400): + gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0) + print(f"{f'ep T1={T1:<3} T2=20':>16} {res:>9.1e} " + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True) + gep, res = ep_step(blk, idx, y, 150, 60, 0.1, 0.02, 0.0) + print(f"{'ep T1=150 T2=60':>16} {res:>9.1e} " + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True) diff --git a/ep_run/holo_ep.py b/ep_run/holo_ep.py new file mode 100644 index 0000000..31054e4 --- /dev/null +++ b/ep_run/holo_ep.py @@ -0,0 +1,332 @@ +"""Holomorphic EP (Laborieux & Zenke 2022) for the non-conservative thick block. + +Plain EP estimates a = -dz*/dbeta with a 2-point real centered difference: bias O(beta^2) forces +beta small (0.02), and the estimator noise scales like (equilibration error)/beta. Holomorphic EP +evaluates the nudged equilibrium at N points on a CIRCLE |beta|=r in the complex plane and reads +-dz*/dbeta off the discrete Cauchy/Fourier formula a = -Re[(1/(N r)) sum_k e^{-i phi_k} z*(r e^{i phi_k})]: +bias O(r^N) instead of O(r^2) -> r can be 5-10x larger at equal bias -> the 1/beta noise +amplification drops by the same factor. Requires the force holomorphically extended to complex +state: manual LN (non-conjugate variance), softmax (exp ratio), GELU (tanh form, entire). +The AEP correction carries over unchanged: it is linear in (z - z*) with REAL coefficients, so it +preserves holomorphy in beta; apply it to real and imaginary parts separately. 
+NOTE: no g-clamp and no corr-clip inside the holomorphic nudge (clamps are non-analytic and would +destroy the O(r^N) bias property); we monitor max|z-z*| instead.""" +import math, torch, torch.nn.functional as F +from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step, relax + +CDT = torch.complex64 + + +def cln(z, g, b, eps=1e-5): # holomorphic LayerNorm: NON-conjugate variance + mu = z.mean(-1, keepdim=True) + v = ((z - mu) ** 2).mean(-1, keepdim=True) # analytic continuation of the real LN + return (z - mu) / torch.sqrt(v + eps) * g + b + + +def csoftmax_masked(a, mask): # holomorphic causal softmax via exp ratio + c = a.real.amax(-1, keepdim=True) # constant row shift cancels exactly in the ratio + w = torch.exp(a - c) * mask # masked entries -> exact 0 + return w / w.sum(-1, keepdim=True) + + +def cgelu(z): # tanh-form GELU: entire function + return 0.5 * z * (1.0 + torch.tanh(0.7978845608028654 * (z + 0.044715 * z ** 3))) + + +def cforce(blk, z, xin): # holomorphic extension of the thick force + C, H, dh, T = blk.C, blk.H, blk.dh, blk.T + B = z.size(0) + h1 = cln(z, blk.ln1g.to(CDT), blk.ln1b.to(CDT)) + h2 = cln(z, blk.ln2g.to(CDT), blk.ln2b.to(CDT)) + q = (h1 @ blk.WQ.to(CDT)).view(B, T, H, dh).transpose(1, 2) + k = (h1 @ blk.WK.to(CDT)).view(B, T, H, dh).transpose(1, 2) + v = (h1 @ blk.WV.to(CDT)).view(B, T, H, dh).transpose(1, 2) + if getattr(blk, 'qknorm', False): # match attn()'s q/k RMSNorm (holomorphic: non-conjugate q^2) + q = q * (q.pow(2).mean(-1, keepdim=True) + 1e-6).pow(-0.5) + k = k * (k.pow(2).mean(-1, keepdim=True) + 1e-6).pow(-0.5) + a = (q @ k.transpose(-2, -1)) / math.sqrt(dh) + p = csoftmax_masked(a, blk.cmask.to(CDT)) + att = (p @ v).transpose(1, 2).reshape(B, T, C) @ blk.WO.to(CDT) + ff = cgelu(h2 @ blk.fc.to(CDT) + blk.fcb.to(CDT)) @ blk.pj.to(CDT) + blk.pjb.to(CDT) + return -(z - xin) + att + ff - blk.c * z + + +def cgrad_ce(blk, z, y): # holomorphic dCE/dz = (softmax(z Wh) - Y) Wh^T / NT + logits = z @ blk.Wh.to(CDT) + c = 
logits.real.amax(-1, keepdim=True) + w = torch.exp(logits - c) + p = w / w.sum(-1, keepdim=True) + Y = F.one_hot(y, p.size(-1)).to(CDT) + return (p - Y) @ blk.Wh.t().to(CDT) / y.numel() + + +def holo_a(blk, zs, xin, y, N, r, T2, eps, corr_on=True): + """Nudged phases at beta_k = r e^{2 pi i k / N}; returns (a, max|z - z*|) with + a = -Re[(1/(N r)) sum_k e^{-i phi_k} (z_k - z*)] ~ -dz*/dbeta + O(r^N).""" + zsc, xc = zs.to(CDT), xin.to(CDT) + acc = torch.zeros_like(zsc) + mg = 0.0 + for kk in range(N): + ph = complex(math.cos(2 * math.pi * kk / N), math.sin(2 * math.pi * kk / N)) + beta = r * ph + z = zsc.clone() + for _ in range(T2): + with torch.no_grad(): + f = cforce(blk, z, xc) - beta * cgrad_ce(blk, z, y) + if corr_on: # AEP: J -> J^T, linear & real -> holomorphy kept + v = z - zsc + Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v.real.contiguous())[1] + 0j + JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v.real.contiguous())[1] + 0j + if v.imag.abs().max() > 1e-9: # real-axis phases skip the imag solves + Jv = Jv + 1j * torch.autograd.functional.jvp(blk.nc_force, zs, v.imag.contiguous())[1] + JTv = JTv + 1j * torch.autograd.functional.vjp(blk.nc_force, zs, v.imag.contiguous())[1] + f = f - (Jv - JTv) + z = z + eps * f + acc = acc + torch.conj(torch.tensor(ph, device=z.device)) * (z - zsc) + mg = max(mg, (z - zsc).abs().max().item()) + a = -(acc / (N * r)).real + return a.detach(), mg + + +def holo_a_select(blk, zs, xin, y, N, r, T2max, eps, K=10, exit_mult=5.0, corr_every=1): + """Adaptive-T2 by hindsight selection: run nudged phases in lockstep to T2max, snapshot the + contrast a_t every K steps, return the snapshot with the smallest increment (most settled). 
+ Never worse than short fixed T2 (the settled snapshot exists early too); captures the long-T2 + win (cos up to ~0.99) when the nudged dynamics are stable; early-exits only on clear blowup — + judging by increments of the QUANTITY OF INTEREST, not step sizes, so non-normal transient + growth cannot trigger a premature stop.""" + zsc, xc = zs.to(CDT), xin.to(CDT) + ph = [complex(math.cos(2 * math.pi * k / N), math.sin(2 * math.pi * k / N)) for k in range(N)] + Z = [zsc.clone() for _ in range(N)] + corr = [None] * N + a_prev = a_best = None + inc_min, t_best = float('inf'), 0 + for t in range(1, T2max + 1): + for k in range(N): + with torch.no_grad(): + f = cforce(blk, Z[k], xc) - (r * ph[k]) * cgrad_ce(blk, Z[k], y) + if corr[k] is None or (t - 1) % corr_every == 0: # v moves ~eps/step: stale corr is cheap + v = Z[k] - zsc + Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v.real.contiguous())[1] + 0j + JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v.real.contiguous())[1] + 0j + if v.imag.abs().max() > 1e-9: + Jv = Jv + 1j * torch.autograd.functional.jvp(blk.nc_force, zs, v.imag.contiguous())[1] + JTv = JTv + 1j * torch.autograd.functional.vjp(blk.nc_force, zs, v.imag.contiguous())[1] + corr[k] = Jv - JTv + Z[k] = Z[k] + eps * (f - corr[k]) + if t % K == 0 or t == T2max: + acc = sum(torch.conj(torch.tensor(p, device=zs.device)) * (zk - zsc) for p, zk in zip(ph, Z)) + a_t = -(acc / (N * r)).real + if not torch.isfinite(a_t).all(): + break + if a_prev is not None: + inc = (a_t - a_prev).norm().item() + if inc < inc_min: + inc_min, a_best, t_best = inc, a_t, t + elif inc > exit_mult * inc_min and t >= 3 * K: + break + a_prev = a_t + if a_best is None: + a_best, t_best = a_prev, T2max + return a_best.detach(), t_best + + +def rforce(blk, z, xin): # real-axis twin of cforce (tanh-gelu, clamp-free) + C, H, dh, T = blk.C, blk.H, blk.dh, blk.T + B = z.size(0) + h1 = F.layer_norm(z, (C,), blk.ln1g, blk.ln1b) + h2 = F.layer_norm(z, (C,), blk.ln2g, blk.ln2b) + q = 
(h1 @ blk.WQ).view(B, T, H, dh).transpose(1, 2) + k = (h1 @ blk.WK).view(B, T, H, dh).transpose(1, 2) + v = (h1 @ blk.WV).view(B, T, H, dh).transpose(1, 2) + if getattr(blk, 'qknorm', False): # match attn()'s q/k RMSNorm in the nudge force + q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6) + k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + 1e-6) + a = (q @ k.transpose(-2, -1)) / math.sqrt(dh) + p = torch.softmax(a.masked_fill(~blk.cmask, float('-inf')), -1) + att = (p @ v).transpose(1, 2).reshape(B, T, C) @ blk.WO + ff = cgelu(h2 @ blk.fc + blk.fcb) @ blk.pj + blk.pjb + nc = att + ff + if getattr(blk, 'fnoise', 0.0) > 0: + nc = nc * (1 + blk.fnoise * torch.randn_like(nc)) + return -(z - xin) + nc - blk.c * z + + +def rgrad_ce(blk, z, y, denom=None): + p = torch.softmax(z @ blk.Wh, -1) + return (p - F.one_hot(y, p.size(-1)).to(z.dtype)) @ blk.Wh.t() / (denom or y.numel()) + + +def holo_a_select2(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0, li=0): + """li>0 enables LOCK-IN INTEGRATION mode for noisy (hardware) physics: run the full T2max, + EMA the contrast a_t every step with time-constant li — the homodyne integrator that divides + persistent per-pass noise by sqrt(window). The hindsight-selection mode (li=0) is for clean + physics, where a single most-settled snapshot is optimal.""" + """N=2 production fast path — mathematically identical to holo_a_select(N=2): both phases are + real, so run them PHASE-BATCHED (stack +r/-r along batch) with real tensors and torch.func + forward-mode jvp. 
Halves autograd calls and skips complex arithmetic.""" + import torch.func as tf + B = zs.size(0) + Z = torch.cat([zs, zs], 0) # [+r phase | -r phase] + X2 = torch.cat([xin, xin], 0) + y2 = torch.cat([y, y], 0) + sg = torch.cat([torch.full((B, 1, 1), r, device=zs.device), + torch.full((B, 1, 1), -r, device=zs.device)], 0) + zs2 = torch.cat([zs, zs], 0) + fnc = lambda zz: blk.nc_force(zz) + a_prev = a_best = a_ema = None + inc_min, t_best = float('inf'), 0 + for t in range(1, T2max + 1): + with torch.no_grad(): + f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel()) # CE mean over the ORIGINAL batch + v = (Z - zs2).contiguous() + _, Jv = tf.jvp(fnc, (zs2,), (v,)) + JTv = tf.vjp(fnc, zs2)[1](v)[0] + Z = Z + eps * (f - (Jv - JTv)) + if li > 0: # lock-in integration (noisy physics) + a_t = (Z[B:] - Z[:B]) / (2 * r) + if not torch.isfinite(a_t).all(): + break + if t > T2max // 3: # let phases develop, then integrate + a_ema = a_t if a_ema is None else a_ema + (a_t - a_ema) / li + continue + if t % K == 0 or t == T2max: + a_t = (Z[B:] - Z[:B]) / (2 * r) # (z_- - z_+)/2r + if not torch.isfinite(a_t).all(): + break + if a_prev is not None: + inc = (a_t - a_prev).norm().item() + if inc < inc_min: + inc_min, a_best, t_best = inc, a_t, t + elif inc > exit_mult * inc_min and t >= 3 * K: + break + a_prev = a_t + if li > 0: + if a_ema is None: + a_ema = (Z[B:] - Z[:B]) / (2 * r) + return a_ema.detach(), T2max + if a_best is None: + a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r) + t_best = T2max + return a_best.detach(), t_best + + +def holo_a_track(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0): + """Common-mode-tracking AEP: linearize the antisymmetric correction at the instantaneous + common mode of the two phases — exact transposed differential dynamics, loose-tolerant, + no compounding linearization error.""" + import torch.func as tf + B = zs.size(0) + Z = torch.cat([zs, zs], 0) + X2 = torch.cat([xin, xin], 0) + y2 = torch.cat([y, 
y], 0) + sg = torch.cat([torch.full((B,1,1), r, device=zs.device), torch.full((B,1,1), -r, device=zs.device)], 0) + fnc = lambda zz: blk.nc_force(zz) + a_prev = a_best = None + inc_min, t_best = float('inf'), 0 + zs2a = torch.cat([zs, zs], 0) + kappa = getattr(blk, 'nbrake', 0.0) + for t in range(1, T2max + 1): + with torch.no_grad(): + zbar = 0.5 * (Z[:B] + Z[B:]) + zb2 = torch.cat([zbar, zbar], 0) + f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel()) + if kappa > 0: # measurement brake: Tikhonov-regularized adjoint + f = f - kappa * (Z - zs2a) + v = (Z - zb2).contiguous() + _, Jv = tf.jvp(fnc, (zb2,), (v,)) + JTv = tf.vjp(fnc, zb2)[1](v)[0] + Z = Z + eps * (f - (Jv - JTv)) + if t % K == 0 or t == T2max: + a_t = (Z[B:] - Z[:B]) / (2 * r) + if not torch.isfinite(a_t).all(): + break + if a_prev is not None: + inc = (a_t - a_prev).norm().item() + if inc < inc_min: + inc_min, a_best, t_best = inc, a_t, t + elif inc > exit_mult * inc_min and t >= 3 * K: + break + a_prev = a_t + if a_best is None: + a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r) + t_best = T2max + return a_best.detach(), t_best + + +def holo_a_lockin(blk, zs, xin, y, r, P, ncyc, eps): + """True oscillatory EP / lock-in estimator (Laborieux–Zenke taken literally) — the + noisy-physics form: ONE trajectory, sinusoidal nudge beta(t)=r·sin(2πt/P), in-phase + demodulation over ncyc periods (first period discarded as transient). 
Single-trajectory => + common-mode noise cancels in the quadrature; v=z−z* stays O(r·response) so the AEP + linearization never leaves its window; noise admitted only in the demodulation band.""" + z = zs.clone() + accI = torch.zeros_like(zs) + sI = 0.0 + T = P * (ncyc + 1) + for t in range(1, T + 1): + s = math.sin(2 * math.pi * t / P) + with torch.no_grad(): + f = rforce(blk, z, xin) - (r * s) * rgrad_ce(blk, z, y, denom=y.numel()) + v = (z - zs).contiguous() + Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v)[1] + JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v)[1] + z = z + eps * (f - (Jv - JTv)) + if not torch.isfinite(z).all(): + return None, t + if t > P: # demodulate after the transient period + accI = accI + z * s + sI += s * s + return (-(accI / (sI + 1e-12)) / r).detach(), T + + +def holo_grads(blk, idx, y, T1, T2, eps, N, r): + """Full holomorphic-EP gradient for block params (same VF readout as ep_step).""" + xin0 = blk.embed(idx).detach() + zs = relax(blk, xin0.clone(), xin0, T1, eps) + res = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9) + a, mg = holo_a(blk, zs, xin0, y, N, r, T2, eps) + with torch.enable_grad(): + xin = blk.embed(idx) + f = blk.force(zs.detach(), xin, cg=True) + gblk = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True) + return {id(p): g for p, g in zip(blk.block, gblk)}, res, mg + + +if __name__ == '__main__': + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + torch.manual_seed(0) + B, T, C, H = 16, 64, 128, 4 + blk = EQBlock(C, H, 256, T, attn_mode='thick') + for p, w in zip(blk.allp, torch.load('/tmp/lt_ep/probe_w.pt')): + with torch.no_grad(): + p.copy_(w.to(dev)) + print("loaded probe weights (300-step BPTT, thick, c=1)", flush=True) + + groups = {'all': blk.block, + 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO], + 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb], + 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b], + 'emb': [blk.tok, blk.pos]} + + def cos(ga, gb, ps): + keep = [p for p in ps if ga.get(id(p)) is not None and 
gb.get(id(p)) is not None] + if not keep: + return float('nan') + va = torch.cat([ga[id(p)].reshape(-1) for p in keep]) + vb = torch.cat([gb[id(p)].reshape(-1) for p in keep]) + return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item() + + hdr = f"{'estimator':>22} {'res':>9} {'max|dz|':>8} " + " ".join(f"{k:>6}" for k in groups) + for bi in range(3): + idx, y = get_batch('train', B, T) + ref = bptt_step(blk, idx, y, 400, 0.1) + print(("\n" if bi else "") + hdr, flush=True) + for T1 in (150, 400): + gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0) + print(f"{f'plain ep b=.02 T1={T1}':>22} {res:>9.1e} {'--':>8} " + + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True) + gep2, _ = ep_step(blk, idx, y, T1, 20, 0.1, 0.1, 0.0) + print(f"{f'plain ep b=.10 T1={T1}':>22} {res:>9.1e} {'--':>8} " + + " ".join(f"{cos(gep2, ref, ps):>6.3f}" for ps in groups.values()), flush=True) + for (N, r) in ((2, 0.02), (4, 0.05), (4, 0.1), (4, 0.2), (8, 0.2)): + gh, res2, mg = holo_grads(blk, idx, y, T1, 20, 0.1, N, r) + print(f"{f'holo N={N} r={r} T1={T1}':>22} {res2:>9.1e} {mg:>8.2f} " + + " ".join(f"{cos(gh, ref, ps):>6.3f}" for ps in groups.values()), flush=True) diff --git a/ep_run/jnc_scaling.py b/ep_run/jnc_scaling.py new file mode 100644 index 0000000..2126d9a --- /dev/null +++ b/ep_run/jnc_scaling.py @@ -0,0 +1,46 @@ +"""Causal probe for the stability onset: measure the non-conservative Jacobian norm ‖J_nc(z*)‖ +(Hutchinson, the same quantity the jacreg controller penalizes) vs width C, at init and after a +few training steps. 
If ‖J_nc‖ growth-per-lr-step crosses the contraction margin at the measured +critical scale, the lr_crit onset is DERIVED from dynamics, not a hyperparameter.""" +import math, torch +import lt_ep_train as M +from pathlib import Path +import pickle +M.DD = Path('/tmp/lt_ep/data/tinystories_bpe') +M.vocab = pickle.load(open(M.DD / 'meta.pkl', 'rb'))['vocab_size'] +from lt_ep_train import EQBlock, get_batch, bptt_step, relax +dev = 'cuda' + +def jnc_norm(blk, zs, n=8): # Hutchinson-style RMS gain ‖J_nc(z*)e‖/‖e‖ over random probes e, i.e. ≈ ‖J_nc(z*)‖_F / sqrt(numel(z)) + tot = 0.0 + for _ in range(n): + e = torch.randn_like(zs) + Jv = torch.autograd.functional.jvp(blk.nc_force, zs, e)[1] + tot += (Jv.pow(2).sum() / e.pow(2).sum()).item() + return math.sqrt(tot / n) + +print(f"{'C':>5} {'H':>3} {'init_res':>9} {'|Jnc|init':>10} {'|Jnc|@100':>10} {'growth/step':>11}") +for C in (256, 512, 768, 1024, 1536, 2048): + torch.manual_seed(0) + H = C // 32 + blk = EQBlock(C, H, 256, 256, attn_mode='thick') + blk.qknorm = True; blk.track = False; blk.li_avg = 0; blk.navg = 1; blk.fnoise = 0; blk.nbrake = 0; blk._cstep = None + with torch.no_grad(): + blk.WO.mul_(0.1); blk.pj.mul_(0.1) # resinit 0.1 (match the sweep) + idx, y = get_batch('train', 8, 256) + xin = blk.embed(idx).detach() + zs = relax(blk, xin.clone(), xin, 150, 0.1) + res0 = (relax(blk, zs, xin, 1, 0.1) - zs).norm().item() / zs.norm().item() + j0 = jnc_norm(blk, zs) + # 100 BPTT steps at a FIXED lr to see how fast ‖J_nc‖ grows (the destabilizing drive) + opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4) + for _ in range(100): + ix, yy = get_batch('train', 8, 256) + g = bptt_step(blk, ix, yy, 150, 0.1) + opt.zero_grad(set_to_none=True) + for p in blk.allp: + p.grad = g.get(id(p)) + torch.nn.utils.clip_grad_norm_(blk.allp, 5.0); opt.step() + zs2 = relax(blk, xin.clone(), xin, 150, 0.1) + j1 = jnc_norm(blk, zs2) + print(f"{C:>5} {H:>3} {res0:>9.1e} {j0:>10.3f} {j1:>10.3f} {(j1-j0)/100:>11.2e}", flush=True) diff --git a/ep_run/knockout_s3200.py
b/ep_run/knockout_s3200.py new file mode 100644 index 0000000..1c05ca4 --- /dev/null +++ b/ep_run/knockout_s3200.py @@ -0,0 +1,27 @@ +import torch, pickle, math +from pathlib import Path +import lt_ep_train as L +from lt_ep_train import EQBlock +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; eps=0.1; B=8; T=256; N=3000 +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +ck=torch.load('runs/redx_traj/s3200.pt',map_location=dev) +def relax_floor(alpha): + blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True + with torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + blk.WO.mul_(alpha) # scale attention OUTPUT contribution (alpha=0 -> no attention) + xin=blk.embed(idx).detach(); z=xin.clone(); ress=[] + for t in range(N): + z2=z+eps*blk.force(z,xin).detach() + r=(z2-z).norm().item()/(z.norm().item()+1e-9); ress.append(r); z=z2 + if not math.isfinite(r) or r>1e3: return (float('nan'),r,float('nan'),float('nan')) # diverged: NaN placeholders keep the 4-float shape the caller formats + tail=ress[-800:] + return (ress[149] if len(ress)>149 else None, ress[-1], min(tail), max(tail)) +print("=== knockout: scale attention output (WO*alpha) on redx s3200, eval_relax 3000 steps ===") +print("alpha=1 is the cycling operator; if cycle dies (res->0, monotone) as alpha falls -> attention asymmetry drives it") +for a in [1.0, 0.7, 0.4, 0.2, 0.0]: + r150,rlast,tmin,tmax=relax_floor(a) + osc = (tmax-tmin) + print(f" alpha={a}: res(150)={r150:.3e} res(3000)={rlast:.3e} tail[min={tmin:.2e},max={tmax:.2e}] osc={osc:.2e} {'DIVERGED' if not math.isfinite(osc) else 'CYCLE' if osc>1e-3 and rlast>1e-3 else 'converged' if rlast<1e-3 else 'floored'}") +print("=== DONE ===") diff --git a/ep_run/local_layers.py b/ep_run/local_layers.py new file mode 100644 index 0000000..db73fb8 --- /dev/null +++ b/ep_run/local_layers.py @@ -0,0 +1,305 @@ +"""Local-learning variants of Linear supporting BP / FA / DFA / sign-sym methods.
+ +LocalLinear is a drop-in replacement for nn.Linear that selects its backward +computation based on the `method` argument: + + bp: standard autograd (nn.Linear behavior) + fa: custom autograd, backward uses a fixed random matrix B in place of W.T + (Lillicrap-style Feedback Alignment, per projection) + sign_sym: custom autograd, backward uses sign(W) in place of W.T (Xiao 2018) + dfa: forward uses normal autograd (so upstream params like embeddings / + LayerNorm still get BP gradients). Input is cached during forward. + After loss.backward(), call `apply_dfa_update(model, e_L)` to + OVERWRITE LocalLinear .grad with DFA-computed update. LocalLinear + weights thus receive direct projection updates while non-LocalLinear + params (embeddings, LN) retain BP gradients (pragmatic hybrid). + +For DFA, call `initialize_dfa_targets(model, target_dim)` once after model +construction, then each training step: + 1. standard forward + 2. compute loss, loss.backward() (fills BP .grad on everything) + 3. compute e_L = dL/dlogits analytically (for LM: softmax(logits)-onehot) + 4. call `apply_dfa_update(model, e_L)` to overwrite LocalLinear .grad + 5. optimizer.step() +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LinearFA(torch.autograd.Function): + """Linear forward, FA backward (replace W.T with fixed random B).""" + + @staticmethod + def forward(ctx, x, W, B, bias): + ctx.save_for_backward(x, W, B) + ctx.has_bias = bias is not None + out = x @ W.t() + if bias is not None: + out = out + bias + return out + + @staticmethod + def backward(ctx, grad_out): + x, W, B = ctx.saved_tensors + # True BP would use W here (shape out x in). FA replaces with random B of same shape. 
+ grad_x = grad_out @ B + # grad_W is standard outer product of grad_out and x (summed over leading dims) + grad_W = grad_out.reshape(-1, grad_out.shape[-1]).t() @ x.reshape(-1, x.shape[-1]) + grad_B = None # B is fixed random + grad_bias = None + if ctx.has_bias: + grad_bias = grad_out.reshape(-1, grad_out.shape[-1]).sum(dim=0) + return grad_x, grad_W, grad_B, grad_bias + + +class LinearSignSym(torch.autograd.Function): + """Linear forward, sign-symmetric backward with rescaling. + + B = sign(W) · ||W||_F / sqrt(numel(W)) + The rescale matches sign(W)'s magnitude to W's typical element magnitude, + avoiding the 50x gradient blowup that pure sign(W) caused. + """ + + @staticmethod + def forward(ctx, x, W, bias): + ctx.save_for_backward(x, W) + ctx.has_bias = bias is not None + out = x @ W.t() + if bias is not None: + out = out + bias + return out + + @staticmethod + def backward(ctx, grad_out): + x, W = ctx.saved_tensors + # Rescaled sign: scale so that ||B||_F ≈ ||W||_F + scale = W.norm() / (W.numel() ** 0.5 + 1e-8) + sign_W_scaled = torch.sign(W) * scale + grad_x = grad_out @ sign_W_scaled + grad_W = grad_out.reshape(-1, grad_out.shape[-1]).t() @ x.reshape(-1, x.shape[-1]) + grad_bias = None + if ctx.has_bias: + grad_bias = grad_out.reshape(-1, grad_out.shape[-1]).sum(dim=0) + return grad_x, grad_W, grad_bias + + +class LocalLinear(nn.Module): + """nn.Linear drop-in with method-dispatched backward (bp/fa/dfa/sign_sym/dfa_block). + + fa_init_mode (only used when method='fa'): + gaussian: B ~ N(0, init_std) (Lillicrap default, current) + orthogonal: B = Haar orthogonal × scale (#1: JL-isometric, scaled to match BP grad norm) + ortho_he: B = Haar orthogonal × sqrt(2/out) (#2: He-init for backward signal) + sparse: B with k non-zeros per row, signs ±1, scaled (#4: structured sparse) + + fa_grape (only with method='fa'): if True, B is updated each step via cosine alignment + to the rank-1 JVP Jacobian estimate Ĵ = (W p) p^T. 
Implements GrAPE + (Caillon et al., ICLR 2026) per-layer. Forward only — no W^T transport. + """ + + def __init__(self, in_features, out_features, bias=False, method="bp", init_std=0.02, + fa_init_mode="gaussian", fa_sparse_k=None, fa_grape=False, fa_grape_n_probe=32): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.method = method + self._fa_grape = (method == "fa") and fa_grape + self._fa_grape_n_probe = fa_grape_n_probe + + self.weight = nn.Parameter(torch.empty(out_features, in_features)) + nn.init.normal_(self.weight, mean=0.0, std=init_std) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + + if method == "fa": + B = torch.empty(out_features, in_features) + if fa_init_mode == "gaussian": + nn.init.normal_(B, mean=0.0, std=init_std) + elif fa_init_mode == "orthogonal": + # Haar orthogonal (semi-orthogonal for non-square), scaled to match BP grad norm. + # BP backward: grad_x = grad_out @ W has norm ~ sqrt(in) * std(W) * ||grad_out||. + # Pure orthogonal preserves norm (|| || stays). Scale to match BP's natural shrinkage. + nn.init.orthogonal_(B) + scale = (in_features ** 0.5) * init_std + B.mul_(scale) + elif fa_init_mode == "ortho_he": + # He init for backward: variance = 2/out_features (matches ReLU-friendly backward) + nn.init.orthogonal_(B) + scale = (2.0 / out_features) ** 0.5 + B.mul_(scale) + elif fa_init_mode == "sparse": + # k non-zero entries per row, signs ±1, scaled so row L2 norm matches Gaussian B + k = fa_sparse_k if fa_sparse_k is not None else max(1, in_features // 16) + B.zero_() + for i in range(out_features): + idx = torch.randperm(in_features)[:k] + signs = (torch.randint(0, 2, (k,)).float() * 2 - 1) + B[i, idx] = signs + # Scale to match Gaussian B's row variance: row_var(Gaussian) = in_features * init_std^2 + # row_var(sparse) = k (after scale 1). 
To match: scale = init_std * sqrt(in_features/k) + B.mul_(init_std * (in_features / k) ** 0.5) + else: + raise ValueError(f"Unknown fa_init_mode: {fa_init_mode}") + # GrAPE: B is updated via JVP-cosine alignment (not via standard optimizer); + # store as Parameter with requires_grad=False so we can update in-place. + if self._fa_grape: + self.B = nn.Parameter(B, requires_grad=False) + else: + self.register_buffer("B", B) + + if method == "dfa": + # B_dfa shape (out_features, target_dim); set via initialize_dfa_targets + self.register_buffer("B_dfa", None) + self._dfa_cached_input = None + + if method == "dfa_block": + # B_dfa_block: (out_features, d_block) — projects block-output-error to layer output + # Set via initialize_dfa_block_targets + self.register_buffer("B_dfa_block", None) + self._dfa_block_cached_input = None + + def set_dfa_target_dim(self, target_dim, init_std=0.02): + assert self.method == "dfa" + B = torch.empty(self.out_features, target_dim, device=self.weight.device) + nn.init.normal_(B, mean=0.0, std=init_std) + self.B_dfa = B + + @torch.no_grad() + def grape_align_step(self, lr_b=0.01, normalize_columns=False): + """GrAPE: update B toward rank-1 JVP Jacobian estimate via cosine alignment. + + For linear y = W x: J = W. JVP at random p: g = W p. Estimate Ĵ = (1/N) Σ g_i p_i^T → W. + Forward only (uses W in forward computation, no W^T transport). + + normalize_columns: paper (Eq. 6) does column-normalize B. We default off because + our per-linear FA needs magnitude match (B → W in magnitude AND direction); + column-norm prevents matching W's row magnitudes. 
+ """ + if not self._fa_grape: + return + N = self._fa_grape_n_probe + device, dtype = self.weight.device, self.weight.dtype + # Random Gaussian perturbation p ~ N(0, I) in input space + p = torch.randn(N, self.in_features, device=device, dtype=dtype) + # Forward: g = p @ W^T (one matrix multiply) + g = F.linear(p, self.weight) # (N, out_features) + # Rank-1 Jacobian estimate: Ĵ = (1/N) g^T @ p, shape (out, in) + J_hat = (g.t() @ p) / N + # Cosine alignment gradient + B_norm = self.B.norm() + J_norm = J_hat.norm() + if B_norm < 1e-8 or J_norm < 1e-8: + return + cos_val = (self.B * J_hat).sum() / (B_norm * J_norm) + # ∂(1 - cos)/∂B = -Ĵ/(||B||·||Ĵ||) + cos · B/||B||² + grad = -J_hat / (B_norm * J_norm) + cos_val * self.B / (B_norm ** 2) + self.B.add_(grad, alpha=-lr_b) + if normalize_columns: + col_norms = self.B.norm(dim=0, keepdim=True).clamp_min(1e-8) + self.B.div_(col_norms) + + def forward(self, x): + if self.method == "bp": + return F.linear(x, self.weight, self.bias) + if self.method == "fa": + return LinearFA.apply(x, self.weight, self.B, self.bias) + if self.method == "sign_sym": + return LinearSignSym.apply(x, self.weight, self.bias) + if self.method == "dfa": + # Cache input for later manual DFA update (will overwrite BP .grad) + self._dfa_cached_input = x.detach() + return F.linear(x, self.weight, self.bias) + if self.method == "dfa_block": + self._dfa_block_cached_input = x.detach() + return F.linear(x, self.weight, self.bias) + raise ValueError(f"Unknown method: {self.method}") + + def dfa_compute_grad(self, e_L): + """Set self.weight.grad and self.bias.grad from global error e_L. + + e_L shape (..., target_dim). delta = e_L @ B_dfa.T, shape (..., out_features). + ΔW = sum_n delta_n outer input_n, where inputs are cached from forward. 
+ """ + assert self.method == "dfa" + assert self._dfa_cached_input is not None, "DFA forward not called or cache cleared" + assert self.B_dfa is not None, "DFA target_dim not set (call initialize_dfa_targets)" + + delta = e_L @ self.B_dfa.t() # (..., out_features) + delta_flat = delta.reshape(-1, self.out_features) + inp_flat = self._dfa_cached_input.reshape(-1, self.in_features) + + grad_W = delta_flat.t() @ inp_flat # (out_features, in_features) + + if self.weight.grad is None: + self.weight.grad = grad_W.clone() + else: + self.weight.grad.copy_(grad_W) + + if self.bias is not None: + grad_b = delta_flat.sum(dim=0) + if self.bias.grad is None: + self.bias.grad = grad_b.clone() + else: + self.bias.grad.copy_(grad_b) + + self._dfa_cached_input = None + + def extra_repr(self): + return f"in={self.in_features}, out={self.out_features}, method={self.method}" + + +def initialize_dfa_targets(model, target_dim): + """Must be called once after model construction and device placement, for DFA mode.""" + for module in model.modules(): + if isinstance(module, LocalLinear) and module.method == "dfa": + module.set_dfa_target_dim(target_dim) + + +def apply_dfa_update(model, e_L): + """Iterate over all LocalLinear(dfa) modules and populate their .grad from e_L.""" + for module in model.modules(): + if isinstance(module, LocalLinear) and module.method == "dfa": + module.dfa_compute_grad(e_L) + + +def initialize_dfa_block_targets(model, d_block, init_std=0.02): + """For dfa_block mode: each LocalLinear gets a random B_dfa_block of shape (out, d_block).""" + for module in model.modules(): + if isinstance(module, LocalLinear) and module.method == "dfa_block": + B = torch.empty(module.out_features, d_block, device=module.weight.device) + nn.init.normal_(B, mean=0.0, std=init_std) + module.B_dfa_block = B + + +def apply_dfa_block_update(block, block_output_error): + """Apply DFA-within-block updates to all LocalLinear(dfa_block) in `block`. 
+ + block_output_error: (B, T, d_block) — gradient at the block's output. + Each linear's grad: ΔW = (block_output_error @ B_dfa_block.T)^T @ cached_input + """ + err = block_output_error.detach() + err_flat = err.reshape(-1, err.size(-1)) # (N, d_block) + N = err_flat.size(0) + for module in block.modules(): + if isinstance(module, LocalLinear) and module.method == "dfa_block": + assert module.B_dfa_block is not None, "Call initialize_dfa_block_targets first" + assert module._dfa_block_cached_input is not None, "Forward not called" + # delta: (N, out_features) = err_flat @ B_dfa_block.T + delta = err_flat @ module.B_dfa_block.t() + inp_flat = module._dfa_block_cached_input.reshape(-1, module.in_features) + grad_W = (delta.t() @ inp_flat) / max(N, 1) + if module.weight.grad is None: + module.weight.grad = grad_W.clone() + else: + module.weight.grad.copy_(grad_W) + if module.bias is not None: + grad_b = delta.sum(dim=0) / max(N, 1) + if module.bias.grad is None: + module.bias.grad = grad_b.clone() + else: + module.bias.grad.copy_(grad_b) + module._dfa_block_cached_input = None diff --git a/ep_run/lt_ep_anderson.py b/ep_run/lt_ep_anderson.py new file mode 100644 index 0000000..7682c50 --- /dev/null +++ b/ep_run/lt_ep_anderson.py @@ -0,0 +1,54 @@ +"""Decisive test for the Anderson idea: at LOW damping (expressive attention), can a fixed-point +SOLVER (Anderson acceleration, DEQ-style) converge the free phase where plain fixed-step relaxation +cannot? 
If yes -> we get convergence from the solver, not from suppressing attention with damping.""" +import math, torch +from lt_ep_train import EQBlock, get_batch +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +B, T, C, H = 16, 64, 128, 4 +blk = EQBlock(C, H, 256, T, attn_mode='real') +idx, y = get_batch('train', B, T) +xin = blk.embed(idx).detach() +eps = 0.05 + + +def gmap(z): # relaxation map; its fixed point = the equilibrium + with torch.no_grad(): + return z + eps * blk.force(z, xin).detach() + + +def plain(z0, steps=200): + z = z0.clone() + for _ in range(steps): + z = gmap(z) + return ((gmap(z) - z).norm() / (z.norm() + 1e-9)).item() + + +def anderson(z0, m=6, max_iter=120, tol=1e-6, lam=1e-4): + Bs, d = z0.shape[0], z0[0].numel() + X = torch.zeros(Bs, m, d, device=dev); Fb = torch.zeros(Bs, m, d, device=dev) + X[:, 0] = z0.reshape(Bs, d); Fb[:, 0] = gmap(z0).reshape(Bs, d) + X[:, 1] = Fb[:, 0]; Fb[:, 1] = gmap(X[:, 1].view_as(z0)).reshape(Bs, d) + Hm = torch.zeros(Bs, m + 1, m + 1, device=dev); Hm[:, 0, 1:] = 1; Hm[:, 1:, 0] = 1 + yv = torch.zeros(Bs, m + 1, 1, device=dev); yv[:, 0] = 1 + r, k = 1.0, 2 + for k in range(2, max_iter): + n = min(k, m) + Gm = Fb[:, :n] - X[:, :n] + Hm[:, 1:n + 1, 1:n + 1] = torch.bmm(Gm, Gm.transpose(1, 2)) + lam * torch.eye(n, device=dev)[None] + alpha = torch.linalg.solve(Hm[:, :n + 1, :n + 1], yv[:, :n + 1])[:, 1:n + 1, 0] + X[:, k % m] = torch.bmm(alpha[:, None], Fb[:, :n])[:, 0] + Fb[:, k % m] = gmap(X[:, k % m].view_as(z0)).reshape(Bs, d) + r = ((Fb[:, k % m] - X[:, k % m]).norm() / (Fb[:, k % m].norm() + 1e-9)).item() + if r < tol or not math.isfinite(r): + break + return r, k + 1 + + +print("free-phase convergence: plain relax (200 steps) vs Anderson — real attention, eps=0.05") +print(f"{'damp c':>7} {'plain_res':>11} {'anderson_res':>13} {'and_iters':>10}") +for c in [0.0, 0.25, 0.5, 1.0, 2.0, 4.0]: + blk.c = c + pr = plain(xin.clone()) + ar, ak = anderson(xin.clone()) + print(f"{c:>7.2f} 
{pr:>11.2e} {ar:>13.2e} {ak:>10d}") diff --git a/ep_run/lt_ep_attention.py b/ep_run/lt_ep_attention.py new file mode 100644 index 0000000..c411de6 --- /dev/null +++ b/ep_run/lt_ep_attention.py @@ -0,0 +1,129 @@ +"""option 2 / MVP-A: AEP for the LM's CAUSAL attention (equilibrium reformulation), on real text. + +The LM's feedforward causal attention is reformulated as a damped equilibrium block: + state z (B,T,C); input embedding x_in is clamped (boundary); + force F(z) = -(z - x_in) + s*(causal_attn(z) - c*z) [c>0: contraction -> stable fixed pt] + causal_attn = O softmax(QK^T/sqrt d , causal) V [non-conservative: independent Q/K/V/O] +Relax to z*, read out logits = z* W_head, cost = next-token cross-entropy. +Train Q/K/V/O with AEP (free + +/-beta nudged, centered, correction clipped) vs naive-EP, and +compare each to the ground-truth BPTT gradient (cosine) on the attention params -- on Shakespeare. +""" +import math, pickle, numpy as np, torch, torch.nn.functional as F +from pathlib import Path + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +DD = Path('/tmp/lt_ep/data/shakespeare_char') +vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] +B, T, C, H = 16, 64, 128, 4 +dh = C // H +ATTN = ('WQ', 'WK', 'WV', 'WO') + + +def get_batch(): + data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r') + ix = torch.randint(len(data) - T - 1, (B,)) + x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix]) + return x.to(dev), y.to(dev) + + +P = dict(WQ=torch.randn(C, C, device=dev) / math.sqrt(C), + WK=torch.randn(C, C, device=dev) / math.sqrt(C), + WV=torch.randn(C, C, device=dev) / math.sqrt(C), + WO=torch.randn(C, C, device=dev) / math.sqrt(C)) +for v in P.values(): + v.requires_grad_(True) +Whead = (torch.randn(C, vocab, device=dev) / math.sqrt(C)).requires_grad_(True) +tok_emb = torch.randn(vocab, C, device=dev) 
* 0.02 # fixed embedding (we test Q/K/V/O grads) +pos_emb = torch.randn(T, C, device=dev) * 0.02 +CMASK = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev)) +DAMP, S = 1.0, 1.0 + + +def embed(idx): + return tok_emb[idx] + pos_emb[None, :, :] + + +def causal_attn(z): + q = (z @ P['WQ']).view(B, T, H, dh).transpose(1, 2) + k = (z @ P['WK']).view(B, T, H, dh).transpose(1, 2) + v = (z @ P['WV']).view(B, T, H, dh).transpose(1, 2) + a = (q @ k.transpose(-2, -1)) / math.sqrt(dh) + a = a.masked_fill(~CMASK[:T, :T], float('-inf')) + a = torch.softmax(a, dim=-1) + return (a @ v).transpose(1, 2).reshape(B, T, C) @ P['WO'] + + +def force(z, x_in): + return -(z - x_in) + S * (causal_attn(z) - DAMP * z) + + +def relax(z, x_in, steps, eps): + for _ in range(steps): + with torch.no_grad(): + z = z + eps * force(z, x_in) + return z.detach() + + +def ce(z, y): + return F.cross_entropy((z @ Whead).reshape(-1, vocab), y.reshape(-1)) + + +def aep_grad(x_in, y, T1, T2, eps, beta, aep): + zs = relax(embed_in.clone(), x_in, T1, eps) + def nudged(sign): + z = zs.clone() + for _ in range(T2): + with torch.enable_grad(): + zz = z.detach().requires_grad_(True) + g, = torch.autograd.grad(ce(zz, y), zz) + with torch.no_grad(): + f = force(z, x_in) - sign * beta * g + if aep: + v = (z - zs).detach() + Jv = torch.autograd.functional.jvp(causal_attn, zs, v)[1] + JTv = torch.autograd.functional.vjp(causal_attn, zs, v)[1] + corr = S * (Jv - JTv) + cn, fn = corr.norm(), f.norm() + 1e-8 + corr = corr * (fn / cn) if cn > fn else corr + f = f - corr + z = z + eps * f + return z.detach() + zp, zm = nudged(+1.0), nudged(-1.0) + a = ((zm - zp) / (2 * beta)).detach() + with torch.enable_grad(): + s = (a * force(zs.detach(), x_in)).sum() + return torch.autograd.grad(s, list(P.values()), allow_unused=True) + + +def bptt_grad(x_in, y, T1, eps): + z = embed_in.clone().requires_grad_(True) + for _ in range(T1): + z = z + eps * force(z, x_in) + return torch.autograd.grad(ce(z, y), list(P.values()), 
allow_unused=True) + + +def cos(g, gb): + cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() for a, b in zip(g, gb)] + return sum(cs) / len(cs), cs + + +idx, y = get_batch() +x_in = embed(idx).detach() +embed_in = x_in.clone() # init state = embedding + +# free-eq residual (is there a stable fixed point?) +zs = relax(embed_in.clone(), x_in, 200, 0.1) +r = (relax(zs, x_in, 1, 0.1) - zs).norm().item() / (zs.norm().item() + 1e-9) +print(f"LM causal damped-equilibrium attention on Shakespeare (B={B} T={T} C={C} H={H}, damp={DAMP})") +print(f" free-phase residual = {r:.2e} ({'stable fixed point' if r < 1e-2 else 'NOT converged'})") + +gb = bptt_grad(x_in, y, 120, 0.1) +gn = aep_grad(x_in, y, 120, 20, 0.1, 0.02, aep=False) +ga = aep_grad(x_in, y, 120, 20, 0.1, 0.02, aep=True) +mn, csn = cos(gn, gb) +ma, csa = cos(ga, gb) +print(f"\n attention-param gradient cosine vs BPTT:") +print(f" naive-EP : mean {mn:+.3f} per-param " + " ".join(f"{n}={c:+.2f}" for n, c in zip(ATTN, csn))) +print(f" AEP : mean {ma:+.3f} per-param " + " ".join(f"{n}={c:+.2f}" for n, c in zip(ATTN, csa))) diff --git a/ep_run/lt_ep_compare.py b/ep_run/lt_ep_compare.py new file mode 100644 index 0000000..0159855 --- /dev/null +++ b/ep_run/lt_ep_compare.py @@ -0,0 +1,69 @@ +"""option 2: compare local attention-gradient quality vs true BP on the LM's attention. +Uses the project's OWN LocalCausalSelfAttention: FA (feedback alignment) and fuse_attn_local +(the hand-derived SoftmaxValueMixLocalFn local backward). 
Reports attention-param grad cosine +vs BP, to set against the AEP result (0.993) from the equilibrium reformulation.""" +import math, pickle, numpy as np, torch, torch.nn.functional as F +from pathlib import Path +from model_local import LocalCausalSelfAttention, LocalGPTConfig + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +DD = Path('/tmp/lt_ep/data/shakespeare_char') +vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] +B, T, C, H = 16, 64, 128, 4 + + +def get_batch(): + data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r') + ix = torch.randint(len(data) - T - 1, (B,)) + x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix]) + return x.to(dev), y.to(dev) + + +idx, y = get_batch() +tok = torch.randn(vocab, C, device=dev) * 0.02 +pos = torch.randn(T, C, device=dev) * 0.02 +EMB = (tok[idx] + pos[None]).detach() +Whead = torch.randn(C, vocab, device=dev) / math.sqrt(C) + + +def make(method, fuse): + cfg = LocalGPTConfig(block_size=T, vocab_size=vocab, n_head=H, n_embd=C, + attn_mode='softmax', method=method, fuse_attn_local=fuse, dropout=0.0, bias=False) + return LocalCausalSelfAttention(cfg).to(dev) + + +bp = make('bp', False) +fa = make('fa', False) +fuse = make('bp', True) +# identical weights across the three so gradients are comparable +for m in (fa, fuse): + for p in ('q_proj', 'k_proj', 'v_proj', 'o_proj'): + getattr(m, p).weight.data.copy_(getattr(bp, p).weight.data) + + +def grads(model): + for p in model.parameters(): + if p.grad is not None: + p.grad = None + o = model(EMB) + F.cross_entropy((o @ Whead).reshape(-1, vocab), y.reshape(-1)).backward() + return [getattr(model, p).weight.grad for p in ('q_proj', 'k_proj', 'v_proj', 'o_proj')] + + +def cos(g, gb): + cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() for a, b in zip(g, gb)] + return sum(cs) / len(cs), cs + + +gb = 
grads(bp) +gfa = grads(fa) +gfuse = grads(fuse) +names = ('WQ', 'WK', 'WV', 'WO') +mfa, cfa = cos(gfa, gb) +mfu, cfu = cos(gfuse, gb) +print("FEEDFORWARD attention (project's own code), grad cosine vs BP on Shakespeare:") +print(f" FA (feedback align) : mean {mfa:+.3f} " + " ".join(f"{n}={c:+.2f}" for n, c in zip(names, cfa))) +print(f" fuse_attn_local (SoftmaxValueMixLocalFn): mean {mfu:+.3f} " + " ".join(f"{n}={c:+.2f}" for n, c in zip(names, cfu))) +print("\n(compare to AEP on equilibrium attention: mean +0.993)") diff --git a/ep_run/lt_ep_diag.py b/ep_run/lt_ep_diag.py new file mode 100644 index 0000000..1cec665 --- /dev/null +++ b/ep_run/lt_ep_diag.py @@ -0,0 +1,57 @@ +"""Diagnostic: WHY does EP training destabilize? Test the hypothesis (Ernoult 2019): +EP == BPTT iff the free phase has CONVERGED (+ small beta). So log, per training step: + - free-phase residual ||Pi(z*+eF)-z*||/||z*|| (is the fixed point still there?) + - cosine(EP-grad, BPTT-grad) over the block params (is EP still tracking the true grad?) +If cosine starts ~1 and stays ~1 until the residual blows up -> it's loss of convergence, not beta. 
+""" +import math, torch, torch.nn.functional as F +from lt_ep_train import EQBlock, ep_step, bptt_step, relax, get_batch, ce + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +T1, T2, eps, beta, B, T = 80, 15, 0.1, 0.02, 32, 64 +blk = EQBlock(128, 4, 256, T, s=1.0, c=1.0) +opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4) +BP = (blk.WQ, blk.WK, blk.WV, blk.WO, blk.Wm) + + +def resid(idx): + xin = blk.embed(idx).detach() + zs = relax(blk, xin.clone(), xin, T1, eps) + zn = relax(blk, zs, xin, 1, eps) + return (zn - zs).norm().item() / (zs.norm().item() + 1e-9) + + +def gcos(idx, y): + gep = ep_step(blk, idx, y, T1, T2, eps, beta) + gbp = bptt_step(blk, idx, y, T1, eps) + fa, fb = [], [] + for p in BP: + a, b = gep.get(id(p)), gbp.get(id(p)) + if a is not None and b is not None and torch.isfinite(a).all() and torch.isfinite(b).all(): + fa.append(a.flatten()); fb.append(b.flatten()) + if not fa: + return float('nan'), gep + return F.cosine_similarity(torch.cat(fa), torch.cat(fb), dim=0).item(), gep + + +print(f"{'step':>4} {'free_resid':>11} {'cos(EP,BPTT)':>13} {'val_CE':>8}") +for step in range(1, 161): + idx, y = get_batch('train', B, T) + r = resid(idx) + c, gep = gcos(idx, y) + if step % 10 == 0 or step <= 5: + with torch.no_grad(): + vi, vy = get_batch('val', B, T) + xin = blk.embed(vi).detach() + v = ce(blk, relax(blk, xin.clone(), xin, T1, eps), vy).item() + print(f"{step:>4} {r:>11.2e} {c:>13.3f} {v:>8.3f}", flush=True) + # apply EP grads (the actual unstable training) + if all((g is None) or torch.isfinite(g).all() for g in gep.values()): + opt.zero_grad(set_to_none=True) + for p in blk.allp: + p.grad = gep.get(id(p)) + torch.nn.utils.clip_grad_norm_(blk.allp, 5.0) + opt.step() + else: + print(f"{step:>4} NON-FINITE EP grad -> would skip", flush=True) diff --git a/ep_run/lt_ep_ffn.py b/ep_run/lt_ep_ffn.py new file mode 100644 index 0000000..8c060c7 --- /dev/null +++ b/ep_run/lt_ep_ffn.py @@ -0,0 +1,119 @@ +"""option 2 
/ H1: replace the FFN-FA (the abandonment reason) with EP's Hopfield-memory E^mem. +Compare attention-FFN gradient quality vs true BP: + FA-FFN : the project's LocalMLP(method='fa') -> expect FA's signature failure on the upstream layer + EP-FFN : Hopfield memory E_mem = 0.5a||h-x||^2 - sum relu(hW)^2 (CONSERVATIVE -> plain EP, no AEP) + trained by centered energy-EP from free/nudged per-token equilibria. +""" +import math, pickle, numpy as np, torch, torch.nn.functional as F +from pathlib import Path +from model_local import LocalMLP, LocalGPTConfig + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +DD = Path('/tmp/lt_ep/data/shakespeare_char') +vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] +B, T, C, Mm = 16, 64, 128, 256 + + +def get_batch(): + data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r') + ix = torch.randint(len(data) - T - 1, (B,)) + x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix]) + return x.to(dev), y.to(dev) + + +idx, y = get_batch() +tok = torch.randn(vocab, C, device=dev) * 0.02 +pos = torch.randn(T, C, device=dev) * 0.02 +XIN = (tok[idx] + pos[None]).detach() +Whead = torch.randn(C, vocab, device=dev) / math.sqrt(C) + + +# ---------- FA-FFN (project's own LocalMLP) ---------- +def make_mlp(method): + cfg = LocalGPTConfig(block_size=T, vocab_size=vocab, n_embd=C, method=method, dropout=0.0, bias=False) + return LocalMLP(cfg).to(dev) + + +bp = make_mlp('bp') +fa = make_mlp('fa') +fa.fc.weight.data.copy_(bp.fc.weight.data); fa.proj.weight.data.copy_(bp.proj.weight.data) + + +def mlp_grads(m): + for p in m.parameters(): + p.grad = None + out = m(XIN) + F.cross_entropy(((XIN + out) @ Whead).reshape(-1, vocab), y.reshape(-1)).backward() + return [m.fc.weight.grad, m.proj.weight.grad] + + +gb = mlp_grads(bp); gfa = mlp_grads(fa) +cfa = [F.cosine_similarity(a.flatten(), b.flatten(), 
0).item() for a, b in zip(gfa, gb)] + + +# ---------- EP-FFN (Hopfield memory E_mem, conservative) ---------- +ALPHA = 2.0 +W = (torch.randn(C, Mm, device=dev) * 0.3 / math.sqrt(C)).requires_grad_(True) + + +def E_mem(h, x): + return 0.5 * ALPHA * ((h - x) ** 2).sum() - (F.relu(h @ W) ** 2).sum() + + +def force(h, x, cg=False): + with torch.enable_grad(): + hr = h if h.requires_grad else h.detach().requires_grad_(True) + g, = torch.autograd.grad(E_mem(hr, x), hr, create_graph=cg) + return -g + + +def relax(h, x, steps, eps): + for _ in range(steps): + with torch.enable_grad(): + f = force(h, x).detach() + with torch.no_grad(): + h = h + eps * f + return h.detach() + + +def ce(h): + return F.cross_entropy((h @ Whead).reshape(-1, vocab), y.reshape(-1)) + + +def ep_grad(x, T1, T2, eps, beta): + hs = relax(x.clone(), x, T1, eps) + def nudge(sign): + h = hs.clone() + for _ in range(T2): + with torch.enable_grad(): + hh = h.detach().requires_grad_(True) + g, = torch.autograd.grad(ce(hh), hh) + with torch.no_grad(): + h = h + eps * (force(h, x).detach() - sign * beta * g) + return h.detach() + hp, hm = nudge(+1), nudge(-1) + with torch.enable_grad(): + gp, = torch.autograd.grad(E_mem(hp, x), W) + gm, = torch.autograd.grad(E_mem(hm, x), W) + return ((gp - gm) / (2 * beta)).detach(), hs + + +def bptt_grad(x, T1, eps): + h = x.clone().requires_grad_(True) + for _ in range(T1): + h = h + eps * force(h, x, cg=True) + return torch.autograd.grad(ce(h), W)[0] + + +hs = relax(XIN.clone(), XIN, 200, 0.1) +r = (relax(hs, XIN, 1, 0.1) - hs).norm().item() / (hs.norm().item() + 1e-9) +gep, _ = ep_grad(XIN, 120, 20, 0.1, 0.02) +gbp_ep = bptt_grad(XIN, 120, 0.1) +cep = F.cosine_similarity(gep.flatten(), gbp_ep.flatten(), 0).item() + +print("H1: FFN gradient quality vs true BP, on Shakespeare LM block") +print(f" FA-FFN (LocalMLP, method=fa) : fc={cfa[0]:+.3f} proj={cfa[1]:+.3f} mean={sum(cfa)/2:+.3f}") +print(f" EP-FFN (Hopfield E_mem) : W_mem cosine = {cep:+.3f} (free-phase residual 
{r:.1e})") +print(f"\n -> FA fails on the upstream FFN layer (fc); EP-memory gives a faithful local gradient.") diff --git a/ep_run/lt_ep_stack.py b/ep_run/lt_ep_stack.py new file mode 100644 index 0000000..327b143 --- /dev/null +++ b/ep_run/lt_ep_stack.py @@ -0,0 +1,165 @@ +"""Spring-coupled equilibrium STACK — EP through depth with no protocol. +Inter-block coupling is a conservative spring energy sum_k gamma/2 ||z_k - z_{k-1}||^2 (z_0 +sprung to the input clamp x). The cost pulls z_K; spring REACTION forces (Newton's 3rd law) +carry the tension down the chain; EP/VF + AEP correction on the non-conservative block internals +is unchanged — the stack is just one bigger force field. +Probe: per-block gradient cosine vs BPTT-through-the-joint-relaxation. The decisive number is +block-0's cosine: did the tension reach the bottom?""" +import math, time, torch, torch.nn.functional as F +from lt_ep_train import get_batch, vocab +dev = 'cuda' if torch.cuda.is_available() else 'cpu' + + +class EQStack: + def __init__(self, K, C, H, T, gamma=1.0, c=1.0): + g = lambda *sh, sc: (torch.randn(*sh, device=dev) * sc).requires_grad_(True) + z1 = lambda n, v: torch.full((n,), float(v), device=dev).requires_grad_(True) + self.K, self.C, self.H, self.dh, self.T = K, C, H, C // H, T + self.gamma, self.c = gamma, c + self.tok = g(vocab, C, sc=0.02); self.pos = g(T, C, sc=0.02) + self.blocks = [] + for _ in range(K): + self.blocks.append(dict( + WQ=g(C, C, sc=1 / math.sqrt(C)), WK=g(C, C, sc=1 / math.sqrt(C)), + WV=g(C, C, sc=1 / math.sqrt(C)), WO=g(C, C, sc=1 / math.sqrt(C)), + ln1g=z1(C, 1), ln1b=z1(C, 0), ln2g=z1(C, 1), ln2b=z1(C, 0), + fc=g(C, 4 * C, sc=1 / math.sqrt(C)), fcb=z1(4 * C, 0), + pj=g(4 * C, C, sc=1 / math.sqrt(4 * C)), pjb=z1(C, 0))) + self.Wh = g(C, vocab, sc=1 / math.sqrt(C)) + self.cmask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev)) + self.block = [self.tok, self.pos] + [p for b in self.blocks for p in b.values()] + self.allp = self.block + [self.Wh] + 
+ def embed(self, idx): + return self.tok[idx] + self.pos[None] + + def battn(self, b, z): + B, T, H, dh, C = z.size(0), self.T, self.H, self.dh, self.C + h = F.layer_norm(z, (C,), b['ln1g'], b['ln1b']) + q = (h @ b['WQ']).view(B, T, H, dh).transpose(1, 2) + k = (h @ b['WK']).view(B, T, H, dh).transpose(1, 2) + v = (h @ b['WV']).view(B, T, H, dh).transpose(1, 2) + a = torch.softmax(((q @ k.transpose(-2, -1)) / math.sqrt(dh)).masked_fill(~self.cmask, float('-inf')), -1) + return (a @ v).transpose(1, 2).reshape(B, T, C) @ b['WO'] + + def bffn(self, b, z): + h = F.layer_norm(z, (self.C,), b['ln2g'], b['ln2b']) + return F.gelu(h @ b['fc'] + b['fcb']) @ b['pj'] + b['pjb'] + + def nc_force(self, zc): # non-conservative internals, state (K,B,T,C) + return torch.stack([self.battn(b, zc[k]) + self.bffn(b, zc[k]) + for k, b in enumerate(self.blocks)], 0) + + def force(self, zc, xin, cg=False): + zr = zc if (cg and zc.requires_grad) else zc.detach().requires_grad_(True) + below = torch.cat([xin[None], zr[:-1]], 0) + f = -self.gamma * (zr - below) + self.nc_force(zr) - self.c * zr + up = self.gamma * (zr[1:] - zr[:-1]) # reaction of the spring above (Newton's 3rd law) + return f + torch.cat([up, torch.zeros_like(zr[:1])], 0) + + +def relax(st, z, xin, steps, eps): + for _ in range(steps): + with torch.no_grad(): + z = z + eps * st.force(z, xin).detach() + return z.detach() + + +def ce(st, z, y): + return F.cross_entropy((z[-1] @ st.Wh).reshape(-1, vocab), y.reshape(-1)) + + +def grad_ce_state(st, z, y): # closed-form dCE/dz: only the top block feels y + p = torch.softmax(z[-1] @ st.Wh, -1) + gK = (p - F.one_hot(y, p.size(-1)).to(z.dtype)) @ st.Wh.t() / y.numel() + g = torch.zeros_like(z) + g[-1] = gK + return g + + +def ep_a(st, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0): + """N=2 real phases, clamp-free, AEP corr on the stack's nc part, hindsight selection.""" + Zp, Zm = zs.clone(), zs.clone() + a_prev = a_best = None + inc_min, t_best = float('inf'), 0 + for t in 
range(1, T2max + 1): + with torch.no_grad(): + for Z, sg in ((Zp, +r), (Zm, -r)): + f = st.force(Z, xin) - sg * grad_ce_state(st, Z, y) + v = (Z - zs).contiguous() + Jv = torch.autograd.functional.jvp(st.nc_force, zs, v)[1] + JTv = torch.autograd.functional.vjp(st.nc_force, zs, v)[1] + Z += eps * (f - (Jv - JTv)) + if t % K == 0 or t == T2max: + a_t = (Zm - Zp) / (2 * r) + if not torch.isfinite(a_t).all(): + break + if a_prev is not None: + inc = (a_t - a_prev).norm().item() + if inc < inc_min: + inc_min, a_best, t_best = inc, a_t, t + elif inc > exit_mult * inc_min and t >= 3 * K: + break + a_prev = a_t + return (a_best if a_best is not None else a_prev).detach(), t_best + + +def ep_grads(st, idx, y, T1, eps, r, T2max): + xin = st.embed(idx).detach() + z0 = xin[None].repeat(st.K, 1, 1, 1) + zs = relax(st, z0, xin, T1, eps) + res = (relax(st, zs, xin, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9) + a, tb = ep_a(st, zs, xin, y, r, T2max, eps) + with torch.enable_grad(): + x2 = st.embed(idx) + f = st.force(zs.detach(), x2, cg=True) + g = torch.autograd.grad((a * f).sum(), st.block, allow_unused=True) + return {id(p): gv for p, gv in zip(st.block, g)}, res, tb + + +def bptt_grads(st, idx, y, T1, eps): + xin = st.embed(idx) + z = (xin.detach().requires_grad_(True) * 0 + xin)[None].repeat(st.K, 1, 1, 1) + for _ in range(T1): + z = z + eps * st.force(z, xin, cg=True) + g = torch.autograd.grad(ce(st, z, y), st.allp, allow_unused=True) + return {id(p): gv for p, gv in zip(st.allp, g)} + + +if __name__ == '__main__': + torch.manual_seed(0) + K, B, T, C, H = 2, 16, 64, 128, 4 + st = EQStack(K, C, H, T, gamma=1.0, c=1.0) + opt = torch.optim.AdamW(st.allp, lr=1e-3, weight_decay=1e-4) + for step in range(200): # short BPTT pretrain -> realistic operating point + idx, y = get_batch('train', B, T) + g = bptt_grads(st, idx, y, 120, 0.1) + opt.zero_grad(set_to_none=True) + for p in st.allp: + p.grad = g.get(id(p)) + torch.nn.utils.clip_grad_norm_(st.allp, 5.0) + 
opt.step() + print(f"pretrained 200 BPTT steps (K={K} spring stack, gamma={st.gamma})", flush=True) + + groups = {'all': st.block, + 'blk0': list(st.blocks[0].values()), + 'blk1': list(st.blocks[1].values()), + 'emb': [st.tok, st.pos]} + + def cos(ga, gb, ps): + keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None] + if not keep: + return float('nan') + va = torch.cat([ga[id(p)].reshape(-1) for p in keep]) + vb = torch.cat([gb[id(p)].reshape(-1) for p in keep]) + return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item() + + hdr = f"{'config':>22} {'res':>9} {'t_best':>7} " + " ".join(f"{k:>6}" for k in groups) + for bi in range(3): + idx, y = get_batch('train', B, T) + ref = bptt_grads(st, idx, y, 400, 0.1) + print(("\n" if bi else "") + hdr, flush=True) + for T1 in (150, 400): + gep, res, tb = ep_grads(st, idx, y, T1, 0.1, 0.02, 120) + print(f"{f'ep T1={T1}':>22} {res:>9.1e} {tb:>7} " + + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True) diff --git a/ep_run/lt_ep_train.py b/ep_run/lt_ep_train.py new file mode 100644 index 0000000..9974bd8 --- /dev/null +++ b/ep_run/lt_ep_train.py @@ -0,0 +1,630 @@ +"""option 2 / H2: train a full EP equilibrium transformer block on Shakespeare char-LM. + +One block = token state z relaxed to a fixed point of + F(z) = -(z - x_in) (input clamp; x_in = embed(idx)) + - dE_mem/dz (Hopfield memory E_mem = -sum relu(z Wm)^2 ; CONSERVATIVE = FFN) + + s*(causal_attn(z) - c*z) (damped causal attention ; NON-conservative) +Readout logits = z* Whead ; loss = next-token cross-entropy. + +Train modes: + ep : free + +/-beta nudged equilibria. Block+embed params via vector-field gradient + <a, dF/dtheta(z*)> with the AEP correction (clipped) on the attention part; readout + head via its own local gradient dCE/dWhead. NO backprop through the relaxation. + bptt : backprop through the unrolled relaxation (exact-gradient reference, same architecture). 
+Stabilisation from G: damping c, clipped AEP correction, weight-norm caps, best-val checkpoint. +""" +import argparse, math, pickle, time, json, os, numpy as np, torch, torch.nn.functional as F +from pathlib import Path + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +DD = Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe') +vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] + + +def get_batch(split, B, T): + data = np.memmap(DD / ('train.bin' if split == 'train' else 'val.bin'), dtype=np.uint16, mode='r') + ix = torch.randint(len(data) - T - 1, (B,)) + x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix]) + return x.to(dev), y.to(dev) + + +class EQBlock: + def __init__(self, C, H, Mm, T, s=1.0, c=1.0, attn_mode='real', gamma=0.25): + self.C, self.H, self.dh, self.s, self.c, self.T = C, H, C // H, s, c, T + self.attn_mode, self.gamma = attn_mode, gamma + self.fnoise = 0.0 # optics model: mult. 
noise per force eval + g = lambda *sh, sc: (torch.randn(*sh, device=dev) * sc).requires_grad_(True) + self.tok = g(vocab, C, sc=0.02); self.pos = g(T, C, sc=0.02) + self.WQ = g(C, C, sc=1 / math.sqrt(C)); self.WK = g(C, C, sc=1 / math.sqrt(C)) + self.WV = g(C, C, sc=1 / math.sqrt(C)); self.WO = g(C, C, sc=1 / math.sqrt(C)) + self.Wm = g(C, Mm, sc=0.3 / math.sqrt(C)); self.Wh = g(C, vocab, sc=1 / math.sqrt(C)) + self.P = g(C, C, sc=1 / math.sqrt(C)); self.Q = g(C, C, sc=1 / math.sqrt(C)) # monDEQ monotone op + self.mono_m = 1.0 + z1 = lambda n, v: torch.full((n,), float(v), device=dev).requires_grad_(True) + self.ln1g = z1(C, 1); self.ln1b = z1(C, 0); self.ln2g = z1(C, 1); self.ln2b = z1(C, 0) # LN affine + self.fc = g(C, 4 * C, sc=1 / math.sqrt(C)); self.fcb = z1(4 * C, 0) # untied 4x FFN + self.pj = g(4 * C, C, sc=1 / math.sqrt(4 * C)); self.pjb = z1(C, 0) + self.cmask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev)) + self.block = [self.tok, self.pos, self.WQ, self.WK, self.WV, self.WO, self.Wm, self.P, self.Q, + self.ln1g, self.ln1b, self.ln2g, self.ln2b, self.fc, self.fcb, self.pj, self.pjb] # in the force + self.allp = self.block + [self.Wh] + self.capw = (self.WQ, self.WK, self.WV, self.WO, self.Wm, self.Wh, self.fc, self.pj) + self.caps = {id(w): w.detach().norm().item() * 3.0 for w in self.capw} + + def embed(self, idx): + return self.tok[idx] + self.pos[None] + + def attn(self, z): + B = z.size(0) + q = (z @ self.WQ).view(B, self.T, self.H, self.dh).transpose(1, 2) + k = (z @ self.WK).view(B, self.T, self.H, self.dh).transpose(1, 2) + v = (z @ self.WV).view(B, self.T, self.H, self.dh).transpose(1, 2) + if getattr(self, 'qknorm', False): # Qwen3-style q/k RMSNorm: bounds logits, tames J + q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6) + k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + 1e-6) + a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh) + a = torch.softmax(a.masked_fill(~self.cmask, float('-inf')), -1) + return (a @ 
v).transpose(1, 2).reshape(B, self.T, self.C) @ self.WO + + def attn_energy(self, z): # conservative LSE attention energy (tied value) + B = z.size(0) + q = (z @ self.WQ).view(B, self.T, self.H, self.dh).transpose(1, 2) + k = (z @ self.WK).view(B, self.T, self.H, self.dh).transpose(1, 2) + a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh) + a = a.masked_fill(~self.cmask, float('-inf')) + return -(1.0 / self.gamma) * torch.logsumexp(self.gamma * a, dim=-1).sum() + + def Emem(self, z): + return -(F.relu(z @ self.Wm) ** 2).sum() + + def tforce(self, z, xin): # pure thick force (no grad machinery) -> torch.compile + h1 = F.layer_norm(z, (self.C,), self.ln1g, self.ln1b) + h2 = F.layer_norm(z, (self.C,), self.ln2g, self.ln2b) + ff = F.gelu(h2 @ self.fc + self.fcb, approximate='tanh') @ self.pj + self.pjb + return -(z - xin) + self.attn(h1) + ff - self.c * z + + def _noisy(self, t): # optics model: per-pass multiplicative noise + if self.fnoise > 0: + return t * (1 + self.fnoise * torch.randn_like(t)) + return t + + def nc_force(self, z): # non-conservative part of the force (for AEP/jacreg) + if self.attn_mode == 'thick': + h1 = F.layer_norm(z, (self.C,), self.ln1g, self.ln1b) + h2 = F.layer_norm(z, (self.C,), self.ln2g, self.ln2b) + return self._noisy(self.attn(h1) + (F.gelu(h2 @ self.fc + self.fcb, approximate='tanh') @ self.pj + self.pjb)) + return self._noisy(self.s * self.attn(z)) + + def force(self, z, xin, cg=False): + with torch.enable_grad(): + zr = z if (cg and z.requires_grad) else z.detach().requires_grad_(True) + if self.attn_mode == 'thick': # DEQ-transformer block: LN + untied 4x FFN + residual + h1 = F.layer_norm(zr, (self.C,), self.ln1g, self.ln1b) + h2 = F.layer_norm(zr, (self.C,), self.ln2g, self.ln2b) + ff = F.gelu(h2 @ self.fc + self.fcb, approximate='tanh') @ self.pj + self.pjb + return -(zr - xin) + self._noisy(self.attn(h1) + ff) - self.c * zr # -c*z: initial contraction + if self.attn_mode == 'mono': # monDEQ: structurally-monotone contraction + 
gm, = torch.autograd.grad(self.Emem(zr), zr, create_graph=cg) + PtP = self.P.t() @ self.P # PSD -> sym(J) = -(mI+PtP) < 0 (guaranteed) + f = (-(self.mono_m * zr + zr @ PtP) + zr @ (self.Q - self.Q.t()).t() + + xin - gm + self.s * self.attn(zr)) + return f + E = 0.5 * ((zr - xin) ** 2).sum() + self.Emem(zr) + if self.attn_mode == 'energy': # attention folded into the energy (conservative) + E = E + self.attn_energy(zr) + 0.5 * self.c * (zr ** 2).sum() # confinement -> bounded below + gz, = torch.autograd.grad(E, zr, create_graph=cg) + f = -gz + if self.attn_mode == 'real': # non-conservative attention + damping + f = f + self.s * (self.attn(zr) - self.c * zr) + return f + + +def relax(blk, z, xin, steps, eps): + cstep = getattr(blk, '_cstep', None) + if cstep is not None and blk.fnoise == 0.0: # compiled pure-thick free-phase fast path + with torch.no_grad(): + for _ in range(steps): + z = cstep(z, xin) + return z.detach() + for _ in range(steps): + with torch.no_grad(): + z = z + eps * blk.force(z, xin).detach() + return z.detach() + + +def ce(blk, z, y): + return F.cross_entropy((z @ blk.Wh).reshape(-1, vocab), y.reshape(-1)) + + +def ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, + corr_every=1, res_gate=0.0, resreg=0.0, eigreg=0.0, eig_margin=1.0): + xin0 = blk.embed(idx).detach() + zs = relax(blk, xin0.clone(), xin0, T1, eps) + res = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9) + res_used = res + zT, resT1 = zs, res # the T1 free-phase state (what eval/BPTT use), BEFORE refinement + if t1max > T1: # estimator refinement: relax further until tight + rnow, t = res, T1 # (controller signal `res` stays measured at T1) + while t < t1max and rnow > res_est: + zs = relax(blk, zs, xin0, 50, eps); t += 50 + rnow = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9) + res_used = rnow + if res_gate > 0 and res_used > res_gate: # validity gate: off-equilibrium 
the EP update is + grads = {} # undefined -> apply ONLY the homeostat (jacreg) and + if jacreg > 0: # skip the nudge entirely (fast recovery steps) + er = torch.randn_like(zs) + with torch.enable_grad(): + Jv = torch.autograd.functional.jvp(blk.nc_force, zs.detach(), er, create_graph=True)[1] + R = jacreg * (Jv ** 2).sum() / (er ** 2).sum() + gr = torch.autograd.grad(R, blk.block, allow_unused=True) + grads = {id(p): g for p, g in zip(blk.block, gr) if g is not None} + return grads, res + def nudge(sign): + z = zs.clone() + for _ in range(T2): + with torch.enable_grad(): + zz = z.detach().requires_grad_(True) + g, = torch.autograd.grad(ce(blk, zz, y), zz) + g = g.clamp(-2.0, 2.0) # clip nudge so it can't blow up relax + with torch.no_grad(): + f = blk.force(z, xin0).detach() - sign * beta * g + if blk.attn_mode in ('real', 'thick'): # AEP correction (full non-conservative part) + v = (z - zs).detach() + Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v)[1] + JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v)[1] + corr = Jv - JTv + cn, fn = corr.norm(), f.norm() + 1e-8 + f = f - (corr * (fn / cn) if cn > fn else corr) + z = z + eps * f + return z.detach() + if holo == 2 and t2sel > 0: # adaptive-T2, phase-batched fast path (validated ==) + from holo_ep import holo_a_select2, holo_a_track + K = max(1, getattr(blk, 'navg', 1)) # restart-averaging: noise / sqrt(K) + acc = None + for _ in range(K): + if getattr(blk, 'track', False): # common-mode-tracking AEP (loose-tolerant) + ai, _ = holo_a_track(blk, zs, xin0, y, hr, t2sel, eps) + else: + ai, _ = holo_a_select2(blk, zs, xin0, y, hr, t2sel, eps, li=getattr(blk, 'li_avg', 0)) + acc = ai if acc is None else acc + ai + a = acc / K + elif holo > 0 and t2sel > 0: # adaptive-T2 via hindsight snapshot selection + from holo_ep import holo_a_select + a, _ = holo_a_select(blk, zs, xin0, y, holo, hr, t2sel, eps, corr_every=corr_every) + elif holo > 0: # holomorphic nudge (clamp-free, Cauchy readout) + from holo_ep 
import holo_a + a, _ = holo_a(blk, zs, xin0, y, holo, hr, T2, eps) + else: + zp, zm = nudge(+1), nudge(-1) + a = ((zm - zp) / (2 * beta)).detach() + grads = {} + with torch.enable_grad(): + xin = blk.embed(idx) # live (for tok/pos grad through clamp) + f = blk.force(zs.detach(), xin, cg=True) + gblk = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True) + for p, gv in zip(blk.block, gblk): + grads[id(p)] = gv + with torch.enable_grad(): + gh, = torch.autograd.grad(ce(blk, zs.detach(), y), blk.Wh) # readout local gradient + grads[id(blk.Wh)] = gh + if jacreg > 0: # soft Lyapunov: penalize non-conservative Jacobian norm + er = torch.randn_like(zs) + with torch.enable_grad(): + Jv = torch.autograd.functional.jvp(blk.nc_force, zs.detach(), er, create_graph=True)[1] + R = jacreg * (Jv ** 2).sum() / (er ** 2).sum() # Hutchinson est of ||J_nc||_F^2 + gr = torch.autograd.grad(R, blk.block, allow_unused=True) + for p, g in zip(blk.block, gr): + if g is not None: + grads[id(p)] = g if grads.get(id(p)) is None else grads[id(p)] + g + if resreg > 0 and resT1 > 7e-4: # defend z_T1 (BPTT gets this implicitly; EP at z* doesn't) + with torch.enable_grad(): + Fz = blk.tforce(zT, xin0) # deterministic thick force at z_T1 (params live, zT/xin0 detached) + Rr = (eps * Fz).pow(2).sum() / (zT.pow(2).sum() + 1e-9) # ~ (T1 residual)^2 + grr = torch.autograd.grad(Rr, blk.block, allow_unused=True) + ratio = resreg * min(1.0, resT1 / 2e-2) # ramp 0->resreg as res 7e-4->2e-2, capped + gtask = math.sqrt(sum(float((grads[id(p)] ** 2).sum()) for p in blk.block if grads.get(id(p)) is not None) + 1e-20) + gres = math.sqrt(sum(float((g ** 2).sum()) for g in grr if g is not None) + 1e-20) + lam = ratio * gtask / gres # scale penalty to `ratio` of the task-grad norm + for p, g in zip(blk.block, grr): + if g is not None: + grads[id(p)] = g * lam if grads.get(id(p)) is None else grads[id(p)] + lam * g + if eigreg > 0: # #2: leading-abscissa control (surgical, one-sided; alt to jacreg) + 
from eig_control import eig_penalty + ge, _om = eig_penalty(blk, zs, eigreg, eig_margin, blk.__dict__.setdefault('_eigcache', {})) + for pid, g in ge.items(): + grads[pid] = g if grads.get(pid) is None else grads[pid] + g + return grads, res + + +class Lion(torch.optim.Optimizer): + """Chen et al. 2023. Analog-hardware rationale: sign updates = fixed-amplitude pulses + (kills device write-nonlinearity), magnitude-noise immune, one momentum cap per weight.""" + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0, lars=False): + super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay, lars=lars)) + + @torch.no_grad() + def step(self): + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + b1, b2 = group['betas'] + st = self.state.setdefault(p, {}) + if 'm' not in st: + st['m'] = torch.zeros_like(p) + u = (b1 * st['m'] + (1 - b1) * p.grad).sign() + lr = group['lr'] + if group['lars']: # per-tensor trust ratio: one gain line per array + lr = lr * (p.norm() / (u.norm() + 1e-12)).item() + p.mul_(1 - lr * group['weight_decay']) + p.add_(u, alpha=-lr) + st['m'].mul_(b2).add_(p.grad, alpha=1 - b2) + + +def bptt_step(blk, idx, y, T1, eps, jacreg=0.0): + xin = blk.embed(idx) + z = xin.detach().requires_grad_(True) * 0 + xin # init = embedding (keeps graph to emb) + for _ in range(T1): + z = z + eps * blk.force(z, xin, cg=True) + g = torch.autograd.grad(ce(blk, z, y), blk.allp, allow_unused=True) + gd = {id(p): gv for p, gv in zip(blk.allp, g)} + if jacreg > 0: # same soft Lyapunov penalty as ep mode (fair control) + er = torch.randn_like(z) + with torch.enable_grad(): + Jv = torch.autograd.functional.jvp(blk.nc_force, z.detach(), er, create_graph=True)[1] + R = jacreg * (Jv ** 2).sum() / (er ** 2).sum() + gr = torch.autograd.grad(R, blk.block, allow_unused=True) + for p, gv in zip(blk.block, gr): + if gv is not None: + gd[id(p)] = gv if gd.get(id(p)) is None else gd[id(p)] + gv + return gd + + 
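The `ep` and `bptt` modes above are meant to estimate the same gradient. The following is a minimal, dependency-free sketch of why that should hold in the conservative case, assuming a hypothetical scalar energy rather than the repository's `EQBlock`. All names in it (`relax_toy`, `dE_dw`, `ep_estimates`, `exact_grad`) are illustrative and do not appear in the commit. It compares the centered energy contrast used in the H1 script, the vector-field form `<a, dF/dw>` used by `ep_step`, and the exact gradient of the loss at the free equilibrium.

```python
# Toy model (an assumption for illustration, not the repo's block):
#   energy E(h; w) = 0.5 * (h - w*x)**2   -> free equilibrium h* = w*x
#   cost   C(h)    = 0.5 * (h - y)**2
# Both EP forms below use only relaxed states; nothing is backpropagated
# through the relaxation loop.

def relax_toy(h, w, x, y, beta, steps=500, eps=0.1):
    """Explicit Euler relaxation of h under the force -dE/dh - beta * dC/dh."""
    for _ in range(steps):
        force = -(h - w * x)      # -dE/dh
        nudge = beta * (h - y)    # beta * dC/dh
        h = h + eps * (force - nudge)
    return h


def dE_dw(h, w, x):
    """Partial derivative of the toy energy with respect to the weight."""
    return -(h - w * x) * x


def ep_estimates(w, x, y, beta=0.02):
    """Return (energy-contrast estimate, vector-field estimate) of dC/dw."""
    h_free = relax_toy(w * x, w, x, y, 0.0)
    h_plus = relax_toy(h_free, w, x, y, +beta)
    h_minus = relax_toy(h_free, w, x, y, -beta)
    # Centered energy contrast, as in the H1 script's ep_grad.
    g_energy = (dE_dw(h_plus, w, x) - dE_dw(h_minus, w, x)) / (2 * beta)
    # Vector-field form, as in ep_step: a = (z_minus - z_plus) / (2*beta),
    # contracted with dF/dw. Here F = -(h - w*x), so dF/dw = x.
    a = (h_minus - h_plus) / (2 * beta)
    g_field = a * x
    return g_energy, g_field


def exact_grad(w, x, y):
    """dC/dw evaluated at the free equilibrium h* = w*x."""
    return (w * x - y) * x


beta = 0.02
w, x, y = 0.7, 1.5, 2.0
g_energy, g_field = ep_estimates(w, x, y, beta)
g_exact = exact_grad(w, x, y)
rel_err = abs(g_energy - g_exact) / abs(g_exact)
print(f"energy {g_energy:+.6f}  field {g_field:+.6f}  exact {g_exact:+.6f}  rel.err {rel_err:.1e}")
```

In this toy the two EP forms coincide and differ from the exact gradient only by a factor of `1 / (1 - beta**2)`, so the bias shrinks quadratically with the nudge size. The non-conservative attention term in the real block is what the AEP correction in `ep_step` is there to handle; this sketch does not model it.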
+@torch.no_grad() +def evaluate(blk, T1, eps, nb=8, B=32): + tot = 0.0 + for _ in range(nb): + idx, y = get_batch('val', B, blk.T) + xin = blk.embed(idx).detach() + z = relax(blk, xin.clone(), xin, T1, eps) + tot += ce(blk, z, y).item() + return tot / nb + + +def specnorm_weight_items(blk): + items = [] + qkv = None + for name in ('WQKV', 'Wqkv', 'W_qkv', 'qkv'): + if hasattr(blk, name): + qkv = (name, getattr(blk, name)) + break + if qkv is not None: + items.append(qkv) + else: + items.extend((name, getattr(blk, name)) for name in ('WQ', 'WK', 'WV') if hasattr(blk, name)) + items.extend((name, getattr(blk, name)) for name in ('WO', 'fc', 'pj') if hasattr(blk, name)) + return [(name, w) for name, w in items if w.ndim >= 2] + + +@torch.no_grad() +def power_sigma(W, u, iters=2): + M = W.detach().reshape(W.shape[0], -1) + if u is None or u.shape != (M.shape[0],) or u.device != W.device or u.dtype != W.dtype: + u = F.normalize(torch.randn(M.shape[0], device=W.device, dtype=W.dtype), dim=0, eps=1e-12) + for _ in range(iters): + v = F.normalize(M.t().mv(u), dim=0, eps=1e-12) + u = F.normalize(M.mv(v), dim=0, eps=1e-12) + sigma = u.dot(M.mv(v)).abs() + return sigma, u.detach() + + +@torch.no_grad() +def project_specnorm_(items, cache, bound): + max_before, max_after, clamped = 0.0, 0.0, [] + for name, W in items: + sigma, u = power_sigma(W, cache.get(id(W))) + cache[id(W)] = u + before = float(sigma) + after = before + if before > bound: + scale = bound / (before + 1e-12) + W.mul_(scale) + after = bound + clamped.append(name) + max_before = max(max_before, before) + max_after = max(max_after, after) + return max_before, max_after, clamped + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--mode', choices=['ep', 'bptt'], required=True) + ap.add_argument('--steps', type=int, default=2000); ap.add_argument('--B', type=int, default=32) + ap.add_argument('--T', type=int, default=64); ap.add_argument('--C', type=int, default=128) + ap.add_argument('--H', 
type=int, default=4); ap.add_argument('--Mm', type=int, default=256) + ap.add_argument('--T1', type=int, default=80); ap.add_argument('--T2', type=int, default=15) + ap.add_argument('--eps', type=float, default=0.1); ap.add_argument('--beta', type=float, default=0.02) + ap.add_argument('--lr', type=float, default=1e-3); ap.add_argument('--log', type=int, default=100) + ap.add_argument('--warmup', type=int, default=0) # linear lr warmup steps (big-model stability) + ap.add_argument('--state', type=str, default='') # periodic FULL-state path (weights+opt+sched+step) + ap.add_argument('--resume', action='store_true') # resume from --state if it exists (Colab timeouts) + ap.add_argument('--save_every', type=int, default=0) # full-state save cadence (0=every --log); set small on Colab + ap.add_argument('--c', type=float, default=1.0); ap.add_argument('--capx', type=float, default=3.0) + ap.add_argument('--attn_mode', choices=['real', 'energy', 'mono', 'thick'], default='real') + ap.add_argument('--ccap', type=float, default=8.0) + ap.add_argument('--specnorm', type=float, default=0.0) # hard post-step spectral projection bound; 0=off + ap.add_argument('--jacreg', type=float, default=0.0) # Bai 2021: soft Jacobian-norm (Lyapunov) penalty + ap.add_argument('--jr_max', type=float, default=16.0) # adaptive jacreg ceiling (ramps up vs residual) + ap.add_argument('--res_target', type=float, default=5e-3) # continuous controller target residual + ap.add_argument('--jr_floor', type=float, default=None) # controller floor; default=--jacreg (legacy: never off) + ap.add_argument('--res_ema', type=float, default=0.0) # EMA on residual signal (0=off); kills controller thrash + ap.add_argument('--jr_lrcouple', action='store_true') # anneal the λ floor with the lr schedule (late-drift fix) + ap.add_argument('--holo', type=int, default=0) # holomorphic EP: N circle points (0=off) + ap.add_argument('--hr', type=float, default=0.02) # holomorphic nudge radius |beta| + 
ap.add_argument('--pema', type=float, default=0.0) # parameter EMA decay (0=off); tames late wander + ap.add_argument('--t1max', type=int, default=0) # adaptive free phase: extend up to t1max... + ap.add_argument('--res_est', type=float, default=1e-4) # ...until this residual (estimator validity) + ap.add_argument('--t2sel', type=int, default=0) # adaptive T2: snapshot-selection cap (0=off) + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--data', type=str, default='/tmp/lt_ep/data/shakespeare_char') + ap.add_argument('--ckpt', type=str, default='') # save best weights (raw+ema) here + ap.add_argument('--corr_every', type=int, default=1) # recompute AEP corr every k nudge steps + ap.add_argument('--tf32', action='store_true') # tf32 matmuls (check res floor first!) + ap.add_argument('--abort_res', type=float, default=0.1) # kill switch: res above this 100 steps straight + ap.add_argument('--res_gate', type=float, default=0.0) # validity gate: skip task grads above this res + ap.add_argument('--wsd', type=float, default=0.0) # WSD: hold peak lr, cosine-decay only the last wsd fraction + ap.add_argument('--resreg', type=float, default=0.0) # T1-residual penalty: defend z_T1 (cap ratio vs task grad); run res_gate=0 + ap.add_argument('--eigreg', type=float, default=0.0) # #2: leading-abscissa (numerical-abscissa) control — surgical alt to jacreg + ap.add_argument('--eig_margin', type=float, default=1.0) # penalize omega(J_nc) above this (free-phase Hopf boundary ~ 1+c) + ap.add_argument('--diag_cos', type=int, default=0) # #1: every N steps, log cos(EP grad, exact BPTT grad) + res + ap.add_argument('--fingerprint', action='store_true') # load --init_ckpt, print (res,cos,abscissa,val) fingerprint, exit + ap.add_argument('--opt', choices=['adamw', 'lion', 'lionlars', 'sgdm', 'sgdsai'], default='adamw') + ap.add_argument('--wd', type=float, default=1e-4) + ap.add_argument('--fnoise', type=float, default=0.0) # optics/device twin: mult. 
noise per force eval + ap.add_argument('--wq_bits', type=int, default=0) # weights projected to N bits each step (0=off) + ap.add_argument('--wmis', type=float, default=0.0) # static per-device mismatch sigma (0=off) + ap.add_argument('--li_avg', type=int, default=0) # lock-in integration window (0=snapshot mode) + ap.add_argument('--navg', type=int, default=1) # restart-averaged contrast estimates per update + ap.add_argument('--track', action='store_true') # common-mode-tracking AEP correction + ap.add_argument('--rt_final', type=float, default=0.0) # anneal res_target to this (0=off), 25%-75% of run + ap.add_argument('--nudge_brake', type=float, default=0.0) # kappa: anchor spring during nudge (Tikhonov adjoint) + ap.add_argument('--init_ckpt', type=str, default='') # warm-start weights from a saved ckpt + ap.add_argument('--qknorm', action='store_true') # Qwen3-style q/k RMSNorm in attention + ap.add_argument('--compile', action='store_true') # torch.compile the free-phase relaxation (thick) + ap.add_argument('--resinit', type=float, default=1.0) # scale WO,pj at init (ReZero/Fixup: small=near-identity block) + cfg = ap.parse_args() + if cfg.specnorm < 0: + raise SystemExit("--specnorm must be non-negative") + global DD, vocab + DD = Path(cfg.data) + vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] + if cfg.tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(cfg.seed) + blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=cfg.c, attn_mode=cfg.attn_mode) + for w in blk.capw: + blk.caps[id(w)] = w.detach().norm().item() * cfg.capx + if cfg.opt in ('lion', 'lionlars'): + opt = Lion(blk.allp, lr=cfg.lr, weight_decay=cfg.wd, lars=(cfg.opt == 'lionlars')) + elif cfg.opt == 'sgdm': + opt = torch.optim.SGD(blk.allp, lr=cfg.lr, momentum=0.95, weight_decay=cfg.wd, nesterov=True) + elif cfg.opt == 'sgdsai': + # EP-SaI (SGD-SaI, arXiv:2412.11768, adapted to EP gradients): per-tensor lr from the + # 
init-time gradient SNR, frozen — hardware: one gain line per array, set at calibration. + gs = {id(p): [] for p in blk.allp} + for _ in range(12): + idx0, y0 = get_batch('train', cfg.B, cfg.T) + g0, _ = ep_step(blk, idx0, y0, cfg.T1, cfg.T2, cfg.eps, cfg.beta, 0.0, cfg.holo, cfg.hr, + cfg.t1max, cfg.res_est, cfg.t2sel, cfg.corr_every, 0.0) + for p in blk.allp: + if g0.get(id(p)) is not None: + gs[id(p)].append(g0[id(p)].detach().clone()) + sc = {} + for p in blk.allp: + if gs[id(p)]: + S = torch.stack(gs[id(p)]) + sc[id(p)] = (S.mean(0).norm() / (S.std(0).norm() + 1e-12)).item() + mx = max(sc.values()) + print("[sgdsai] per-tensor lr scales: " + " ".join(f"{v/mx:.3f}" for v in sc.values()), flush=True) + opt = torch.optim.SGD([dict(params=[p], lr=cfg.lr * sc.get(id(p), mx) / mx) for p in blk.allp], + momentum=0.95, weight_decay=cfg.wd, nesterov=True) + else: + opt = torch.optim.AdamW(blk.allp, lr=cfg.lr, weight_decay=cfg.wd) + if cfg.warmup > 0 or cfg.wsd > 0: # warmup -> (WSD hold peak) -> cosine decay + _w = cfg.warmup # contraction before large steps kick weights out of basin + def _lrl(s): + if s < _w: + return (s + 1) / _w + _ds = int((1 - cfg.wsd) * cfg.steps) if cfg.wsd > 0 else _w # WSD decay-start: hold peak lr until here + if s < _ds: + return 1.0 + p = (s - _ds) / max(1, cfg.steps - _ds) + return 0.05 + 0.475 * (1 + math.cos(math.pi * min(1.0, p))) + sched = torch.optim.lr_scheduler.LambdaLR(opt, _lrl) + else: + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr * 0.05) + if cfg.init_ckpt: + ckl = torch.load(cfg.init_ckpt, map_location=dev) + with torch.no_grad(): + for p, w in zip(blk.allp, ckl['allp']): + p.copy_(w.to(dev)) + print(f"[init] warm-start from {cfg.init_ckpt} (step {ckl.get('step')}, best {ckl.get('best', float('nan')):.4f})", flush=True) + xin = blk.embed(*get_batch('train', cfg.B, cfg.T)[:1]).detach() + r = (relax(blk, relax(blk, xin.clone(), xin, 200, cfg.eps), xin, 1, cfg.eps) - relax(blk, xin.clone(), 
xin, 200, cfg.eps)).norm().item() + print(f"[{cfg.mode}] residual~{r:.1e} | C={cfg.C} H={cfg.H} Mm={cfg.Mm} T1={cfg.T1} T2={cfg.T2}", flush=True) + best, t0, jr, rs = 9.9, time.time(), cfg.jacreg, None + pema = [p.detach().clone() for p in blk.allp] if cfg.pema > 0 else None + badct = 0 + blk.fnoise = cfg.fnoise + blk.li_avg = cfg.li_avg + blk.navg = cfg.navg + blk.track = cfg.track + blk.nbrake = cfg.nudge_brake + blk.qknorm = cfg.qknorm + if cfg.resinit != 1.0: # near-identity block at init (contractive) -> stable big-width start + with torch.no_grad(): + blk.WO.mul_(cfg.resinit); blk.pj.mul_(cfg.resinit) + spec_items = specnorm_weight_items(blk) + spec_cache = {} + if cfg.specnorm > 0: + shapes = " ".join(f"{name}{tuple(W.shape)}" for name, W in spec_items) + print(f"[specnorm] hard post-step projection: sigma_max <= {cfg.specnorm:g} on {shapes}", flush=True) + blk._cstep = None + if cfg.compile and cfg.attn_mode == 'thick': + _ee = cfg.eps + blk._cstep = torch.compile(lambda z, xin: z + _ee * blk.tforce(z, xin)) + mis = None + if cfg.wmis > 0: # fixed fabrication mismatch (same devices all run) + gm = torch.Generator().manual_seed(1234) + mis = [(1 + cfg.wmis * torch.randn(p.shape, generator=gm)).clamp(0.2, 5.0).to(dev) for p in blk.allp] + hw_on = cfg.wq_bits > 0 or mis is not None + + def hw_swap(): # measure physics on the imperfect device copy; + saved = [p.detach().clone() for p in blk.allp] # masters stay fp32 (program-verify model) + with torch.no_grad(): + for i, p in enumerate(blk.allp): + w = p * mis[i] if mis is not None else p.detach().clone() + if cfg.wq_bits > 0: + d = w.abs().max() / (2 ** (cfg.wq_bits - 1) - 1) + 1e-12 + w = torch.round(w / d) * d + p.copy_(w) + return saved + + def hw_restore(saved): + with torch.no_grad(): + for p, s in zip(blk.allp, saved): + p.copy_(s) + + start_step = 1 + if cfg.resume and cfg.state and os.path.exists(cfg.state): # Colab-timeout resume: full state + st = torch.load(cfg.state, map_location=dev) + with 
torch.no_grad(): + for p, w in zip(blk.allp, st['allp']): + p.copy_(w.to(dev)) + if pema is not None and st.get('pema') is not None: + pema = [s.to(dev) for s in st['pema']] + opt.load_state_dict(st['opt']); sched.load_state_dict(st['sched']) + start_step = st['step'] + 1; jr = st['jr']; rs = st['rs']; best = st['best'] + print(f"[resume] from {cfg.state}: step {start_step}, best {best:.4f}, jr {jr:.1f}", flush=True) + + def save_state(step): + if not cfg.state: + return + torch.save({'allp': [p.detach().cpu() for p in blk.allp], + 'pema': [s.cpu() for s in pema] if pema is not None else None, + 'opt': opt.state_dict(), 'sched': sched.state_dict(), + 'step': step, 'jr': jr, 'rs': rs, 'best': best}, cfg.state + '.tmp') + os.replace(cfg.state + '.tmp', cfg.state) # atomic: survive a mid-write timeout + + if cfg.fingerprint: # study s2000 vs other ckpts: print the operator's 4-D fingerprint + from diag_cos import fingerprint + fp = fingerprint(blk, cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.holo, cfg.hr, cfg.t1max, cfg.res_est, cfg.t2sel) + print(f"[fingerprint] ckpt={cfg.init_ckpt or 'scratch'} | res={fp['res']:.2e} cos(EP,BPTT)={fp['cos']:.4f} " + f"num_abscissa={fp['num_abscissa']:+.4f} val={fp['val']:.4f}", flush=True) + return + for step in range(start_step, cfg.steps + 1): + idx, y = get_batch('train', cfg.B, cfg.T) + if cfg.mode == 'ep': + sw = hw_swap() if hw_on else None + grads, res = ep_step(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, jr, cfg.holo, cfg.hr, + cfg.t1max, cfg.res_est, cfg.t2sel, cfg.corr_every, cfg.res_gate, cfg.resreg, + cfg.eigreg, cfg.eig_margin) + if sw is not None: + hw_restore(sw) + if cfg.jacreg > 0: # continuous controller: drive residual -> res_target (smooth) + flo = cfg.jacreg if cfg.jr_floor is None else cfg.jr_floor + if cfg.jr_lrcouple: + flo *= sched.get_last_lr()[0] / cfg.lr + rtgt = cfg.res_target + if cfg.rt_final > 0: # stiffness anneal: tight start -> loose mid/late + u = min(1.0, max(0.0, (step / cfg.steps - 0.25) / 0.5)) 
+ rtgt = math.exp((1 - u) * math.log(cfg.res_target) + u * math.log(cfg.rt_final)) + rs = res if rs is None else cfg.res_ema * rs + (1 - cfg.res_ema) * res + jr = min(cfg.jr_max, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / rtgt)))) + else: # damping feedback (no jacreg) + if res > 1e-3: + blk.c = min(cfg.ccap, blk.c * 1.3) + elif res < 2e-4: + blk.c = max(0.5, blk.c * 0.97) + else: + grads = bptt_step(blk, idx, y, cfg.T1, cfg.eps, jr if cfg.jacreg > 0 else 0.0) + with torch.no_grad(): # is BPTT's optimum contractive? (free-phase residual) + xinb = blk.embed(idx).detach() + zsb = relax(blk, xinb.clone(), xinb, cfg.T1, cfg.eps) + res = (relax(blk, zsb, xinb, 1, cfg.eps) - zsb).norm().item() / (zsb.norm().item() + 1e-9) + if cfg.jacreg > 0: # same residual-driven λ controller as ep mode + flo = cfg.jacreg if cfg.jr_floor is None else cfg.jr_floor + if cfg.jr_lrcouple: + flo *= sched.get_last_lr()[0] / cfg.lr + rs = res if rs is None else cfg.res_ema * rs + (1 - cfg.res_ema) * res + jr = min(cfg.jr_max, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / cfg.res_target)))) + badct = badct + 1 if (cfg.abort_res > 0 and res > cfg.abort_res) else 0 + if badct >= 100: # containment lost and not recovering: stop, keep best ckpt + print(f" ABORT at step {step}: res>{cfg.abort_res} for 100 consecutive steps (best {best:.4f})", flush=True) + break + ok = all((g is None) or torch.isfinite(g).all() for g in grads.values()) + if not ok: + print(f" step {step}: non-finite, skip", flush=True); continue + opt.zero_grad(set_to_none=True) + for p in blk.allp: + p.grad = grads.get(id(p)) + torch.nn.utils.clip_grad_norm_(blk.allp, 5.0) + opt.step() + spec_stats = None + with torch.no_grad(): + if cfg.specnorm > 0: + spec_stats = project_specnorm_(spec_items, spec_cache, cfg.specnorm) + else: + for p in blk.capw: + pn = p.norm(); cap = blk.caps[id(p)] + if pn > cap: + p.mul_(cap / pn) + sched.step() + if spec_stats is not None and (step == start_step or step % cfg.log == 0): + 
sb, sa, names = spec_stats + cname = ",".join(names) if names else "none" + print(f" specnorm step {step}: max sigma before={sb:.4f} after={sa:.4f} bound={cfg.specnorm:.4f} clamped={cname}", flush=True) + if pema is not None: + with torch.no_grad(): + for s, p in zip(pema, blk.allp): + s.mul_(cfg.pema).add_(p.detach(), alpha=1 - cfg.pema) + if cfg.save_every and step % cfg.save_every == 0 and step % cfg.log != 0: + save_state(step) # mid-interval state save (Colab: cap worst-case loss) + if cfg.diag_cos and step % cfg.diag_cos == 0: # #1: gradient-alignment trajectory (scratch vs warm) + from diag_cos import cos_ep_bptt + _c, _r = cos_ep_bptt(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.holo, cfg.hr, + cfg.t1max, cfg.res_est, cfg.t2sel) + print(f" [diag] step {step}: cos(EP,BPTT)={_c:.4f} res={_r:.1e}", flush=True) + if step % cfg.log == 0: + prevb = best + sw = hw_swap() if hw_on else None + v = evaluate(blk, cfg.T1, cfg.eps) + if sw is not None: + hw_restore(sw) + best = min(best, v) + etag = "" + if pema is not None: + with torch.no_grad(): + raw = [p.detach().clone() for p in blk.allp] + for p, s in zip(blk.allp, pema): + p.copy_(s) + ve = evaluate(blk, cfg.T1, cfg.eps) + for p, r in zip(blk.allp, raw): + p.copy_(r) + best = min(best, ve); etag = f" ema={ve:.4f}" + if cfg.ckpt and best < prevb: + torch.save({'allp': [p.detach().cpu() for p in blk.allp], + 'pema': [s.cpu() for s in pema] if pema is not None else None, + 'step': step, 'best': best}, cfg.ckpt) + print(f"step {step:4d}/{cfg.steps} | val CE {v:.4f}{etag} (best {best:.4f}) | jr={jr:.1f} res={res:.1e} | {step/(time.time()-t0):.2f} it/s", flush=True) + save_state(step) # full-state checkpoint each log interval (Colab resume) + print(f"[{cfg.mode}] DONE best val CE {best:.4f} (random baseline ln({vocab})={math.log(vocab):.3f})", flush=True) + out_dir = Path('runs') + out_dir.mkdir(exist_ok=True) + json.dump({'mode': cfg.mode, 'best_val_ce': best}, open(out_dir / f'H2_{cfg.mode}.json', 'w')) + + 
+if __name__ == '__main__': + main() diff --git a/ep_run/mdpi_paper.html b/ep_run/mdpi_paper.html new file mode 100644 index 0000000..13341aa --- /dev/null +++ b/ep_run/mdpi_paper.html @@ -0,0 +1,10 @@ +<HTML><HEAD> +<TITLE>Access Denied</TITLE> +</HEAD><BODY> +<H1>Access Denied</H1> + +You don't have permission to access "http://www.mdpi.com/2072-666X/14/7/1367" on this server.<P> +Reference #18.8a1c2117.1782048650.540b0bef +<P>https://errors.edgesuite.net/18.8a1c2117.1782048650.540b0bef</P> +</BODY> +</HTML> diff --git a/ep_run/model.py b/ep_run/model.py new file mode 100644 index 0000000..149724b --- /dev/null +++ b/ep_run/model.py @@ -0,0 +1,156 @@ +"""Tiny GPT with switchable softmax/sigmoid causal self-attention. + +Architecture follows nanoGPT (Karpathy), trimmed to a single file for this +Sigmoid Attention reproduction experiment (Ramapuram et al. 2024). +""" +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +@dataclass +class GPTConfig: + block_size: int = 256 + vocab_size: int = 65 + n_layer: int = 6 + n_head: int = 6 + n_embd: int = 384 + dropout: float = 0.2 + bias: bool = False # bias in linear layers + attn_mode: str = "softmax" # "softmax" or "sigmoid" + sigmoid_bias_mode: str = "neg_log_n" # "zero" | "neg_log_n" | "learned" + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: GPTConfig): + super().__init__() + assert config.n_embd % config.n_head == 0 + self.n_head = config.n_head + self.n_embd = config.n_embd + self.head_dim = config.n_embd // config.n_head + self.block_size = config.block_size + self.attn_mode = config.attn_mode + self.sigmoid_bias_mode = config.sigmoid_bias_mode + + self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.attn_drop = nn.Dropout(config.dropout) + self.resid_drop = nn.Dropout(config.dropout) + + causal = 
torch.tril(torch.ones(config.block_size, config.block_size, dtype=torch.bool)) + self.register_buffer("causal_mask", causal, persistent=False) + + if config.attn_mode == "sigmoid": + if config.sigmoid_bias_mode == "zero": + init_b = 0.0 + else: + init_b = -math.log(config.block_size) + if config.sigmoid_bias_mode == "learned": + self.sig_bias = nn.Parameter(torch.tensor(init_b)) + else: + self.register_buffer("sig_bias", torch.tensor(init_b), persistent=False) + + def forward(self, x): + B, T, C = x.shape + q, k, v = self.qkv(x).split(self.n_embd, dim=-1) + q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) + k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) + v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) + + scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5) + mask = self.causal_mask[:T, :T] + scores = scores.masked_fill(~mask, float("-inf")) + + if self.attn_mode == "softmax": + attn = F.softmax(scores, dim=-1) + else: + # sigmoid(scores + b). masked -> sigmoid(-inf) = 0 naturally. 
+ attn = torch.sigmoid(scores + self.sig_bias) + + attn = self.attn_drop(attn) + out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C) + return self.resid_drop(self.proj(out)) + + +class MLP(nn.Module): + def __init__(self, config: GPTConfig): + super().__init__() + self.fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.drop = nn.Dropout(config.dropout) + + def forward(self, x): + return self.drop(self.proj(F.gelu(self.fc(x)))) + + +class Block(nn.Module): + def __init__(self, config: GPTConfig): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.ln2 = nn.LayerNorm(config.n_embd) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + + +class GPT(nn.Module): + def __init__(self, config: GPTConfig): + super().__init__() + self.config = config + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + self.pos_emb = nn.Embedding(config.block_size, config.n_embd) + self.drop = nn.Dropout(config.dropout) + self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith("proj.weight"): + nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + + def num_params(self) -> int: + return sum(p.numel() for p in self.parameters()) + + def forward(self, idx, targets=None): + B, T = idx.shape + assert T <= self.config.block_size, f"seq len {T} > block_size 
{self.config.block_size}" + pos = torch.arange(T, device=idx.device) + x = self.drop(self.tok_emb(idx) + self.pos_emb(pos)) + for blk in self.blocks: + x = blk(x) + x = self.ln_f(x) + logits = self.head(x) + if targets is None: + return logits, None + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + return logits, loss + + @torch.no_grad() + def generate(self, idx, max_new_tokens: int, temperature: float = 1.0, top_k=None): + for _ in range(max_new_tokens): + idx_cond = idx[:, -self.config.block_size :] + logits, _ = self(idx_cond) + logits = logits[:, -1, :] / temperature + if top_k is not None: + v, _ = torch.topk(logits, top_k) + logits[logits < v[:, [-1]]] = -float("inf") + probs = F.softmax(logits, dim=-1) + nxt = torch.multinomial(probs, 1) + idx = torch.cat([idx, nxt], dim=1) + return idx diff --git a/ep_run/model_local.py b/ep_run/model_local.py new file mode 100644 index 0000000..a84c692 --- /dev/null +++ b/ep_run/model_local.py @@ -0,0 +1,470 @@ +"""Sigmoid GPT with split Q/K/V projections and LocalLinear for method dispatch. + +Derived from model.py but uses LocalLinear for every linear layer and splits +the fused qkv into separate q_proj, k_proj, v_proj so that each projection has +its own feedback matrix for FA / DFA / sign_sym. 
+""" +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from local_layers import LocalLinear + + +class SigmoidSTE(torch.autograd.Function): + """Sigmoid forward, straight-through backward (skip A(1-A) derivative).""" + @staticmethod + def forward(ctx, x): + return torch.sigmoid(x) + @staticmethod + def backward(ctx, grad_out): + return grad_out + + +class GELUSTE(torch.autograd.Function): + """GELU forward, straight-through backward (skip gelu' derivative).""" + @staticmethod + def forward(ctx, x): + return F.gelu(x) + @staticmethod + def backward(ctx, grad_out): + return grad_out + + +class HardTopK(torch.autograd.Function): + """k-WTA: zero out all but top-k (by abs value) along last dim, in BOTH forward and backward. + + Forward: keep top-k entries, zero rest. + Backward: gradient mask = forward mask (only winners get gradient). + + This enforces strict sparsity — non-selected channels never update. + """ + @staticmethod + def forward(ctx, x, k): + topk_vals, topk_idx = x.abs().topk(k, dim=-1) + mask = torch.zeros_like(x).scatter_(-1, topk_idx, 1.0) + ctx.save_for_backward(mask) + return x * mask + + @staticmethod + def backward(ctx, grad_out): + (mask,) = ctx.saved_tensors + return grad_out * mask, None + + +class FrozenSubspace(nn.Module): + """Project h to fixed r-dim orthonormal subspace via Q Q^T h. + + Q ∈ R^{d × r} is a random orthonormal basis, frozen at init. + Output lives in span(Q) ⊂ R^d. (d-r) directions are killed. + + With same seed across blocks, all layers share the same subspace — + so the residual stream is constrained to span(Q) throughout the network. + + Differentiable (no STE): grad_h = Q Q^T grad_out (same projection). + For BPfree: residual codebook subspace ≈ span(Q) is exactly what BPfree + delivers gradient on, so feedback geometry is matched by construction. 
+ """ + def __init__(self, d_model, rank, seed=42): + super().__init__() + self.rank = rank + gen = torch.Generator() + gen.manual_seed(seed) + Q, _ = torch.linalg.qr(torch.randn(d_model, rank, generator=gen)) + self.register_buffer("Q", Q) # (d, r) + + def forward(self, h): + # h @ Q → (..., r) coefficients in basis + # @ Q.t() → back to (..., d), now in span(Q) + return h @ self.Q @ self.Q.t() + + +class VQResidualDir(nn.Module): + """Directional quantization to fixed codebook with STE backward. + + Forward: replace h's direction with nearest of K fixed unit-norm codebook entries + (per token). Magnitude is preserved. + Backward: identity through h (STE — no gradient on the codebook lookup itself). + + Codebook is initialized with random unit-norm directions and FROZEN (registered buffer). + The "feature directions" are predefined — the network only learns *which code to land on* + per token per layer. Discrete bottleneck: log2(K) bits per token per layer. + + For BPfree: the gradient signal needed to switch between codes is in + {radial, low-rank residual} subspace, matching BPfree exit's bandwidth. 
+ """ + def __init__(self, d_model, n_codes, seed=None): + super().__init__() + self.n_codes = n_codes + gen = torch.Generator() + if seed is not None: + gen.manual_seed(seed) + codes = torch.randn(n_codes, d_model, generator=gen) + codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-8) + self.register_buffer("codebook", codes) + + def forward(self, h): + h_norm = h.norm(dim=-1, keepdim=True).clamp_min(1e-8) + h_hat = h / h_norm + sims = h_hat @ self.codebook.t() # (..., K) + idx = sims.argmax(dim=-1) + z_q_dir = self.codebook[idx] # (..., d) unit-norm direction + z_q = z_q_dir * h_norm # restore magnitude + # STE: forward = z_q, backward = identity through h + return h + (z_q - h).detach() + + +class LayerNormSTE(nn.Module): + """LayerNorm forward, straight-through backward (gradient passes through as identity).""" + def __init__(self, normalized_shape): + super().__init__() + self.ln = nn.LayerNorm(normalized_shape) + def forward(self, x): + with torch.no_grad(): + out = self.ln(x) + return x + (out - x).detach() + + +class _ProjectedSurrogateLNFn(torch.autograd.Function): + """Core autograd function for projected surrogate LN backward. 
+ mode='projected': full P_z(v) = v - mean(v) - z*mean(v*z), scaled by 1/σ + mode='center_scale': only v - mean(v), scaled by 1/σ (no radial removal) + """ + @staticmethod + def forward(ctx, x, eps, mode): + x_f = x.float() if x.dtype in (torch.float16, torch.bfloat16) else x + mu = x_f.mean(dim=-1, keepdim=True) + xc = x_f - mu + var = (xc * xc).mean(dim=-1, keepdim=True) + rsigma = torch.rsqrt(var + eps) + z = xc * rsigma + ctx.save_for_backward(z, rsigma) + ctx.mode = mode + ctx.input_dtype = x.dtype + return z.to(dtype=x.dtype) + + @staticmethod + def backward(ctx, g_tilde): + z, rsigma = ctx.saved_tensors + v = g_tilde.float() if g_tilde.dtype in (torch.float16, torch.bfloat16) else g_tilde + v = v.to(dtype=z.dtype) + v_mean = v.mean(dim=-1, keepdim=True) + if ctx.mode == "projected": + vz_mean = (v * z).mean(dim=-1, keepdim=True) + p_v = v - v_mean - z * vz_mean + else: # center_scale + p_v = v - v_mean + g_x = p_v * rsigma + return g_x.to(dtype=ctx.input_dtype), None, None + + +class LayerNormProjectedSurrogate(nn.Module): + """LN forward = standard normalization. LN backward = projected surrogate (not BP). + mode='projected': full mean-center + radial removal + 1/σ scaling + mode='center_scale': mean-center + 1/σ only (no radial removal) + Affine (γ, β) handled outside the custom Function so g̃ = ∂L/∂z exactly. 
+ """ + def __init__(self, normalized_shape, eps=1e-5, mode="projected", + elementwise_affine=False, bias=True): + super().__init__() + self.normalized_shape = normalized_shape + self.eps = eps + self.mode = mode + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias_param = nn.Parameter(torch.zeros(normalized_shape)) if bias else None + else: + self.weight = None + self.bias_param = None + + def forward(self, x): + z = _ProjectedSurrogateLNFn.apply(x, self.eps, self.mode) + if self.weight is not None: + z = z * self.weight + if self.bias_param is not None: + z = z + self.bias_param + return z + + +class SoftmaxValueMixLocalFn(torch.autograd.Function): + """Fused softmax(S) @ V with local backward. + + Forward: A = softmax(S), O = A @ V + Backward: g_S_{i,j} = A_{ij} * <δO_i, V_j - O_i> (no lateral sum!) + δV = A^T @ δO (attention-weighted gather) + + The softmax Jacobian's "lateral sum" Σ_j A_ij g_ij collapses to a per-query + scalar baseline <δO_i, O_i> when composed with A@V — pure algebra, not approximation. 
+ """ + @staticmethod + def forward(ctx, scores, v): + attn = F.softmax(scores, dim=-1) + out = torch.einsum("bhtk,bhkd->bhtd", attn, v) + ctx.save_for_backward(attn.detach(), out.detach(), v.detach()) + return out + + @staticmethod + def backward(ctx, delta_out): + attn, out, v = ctx.saved_tensors + # g_A_{i,j} = <δO_i, V_j> + g_a = torch.einsum("bhtd,bhkd->bhtk", delta_out, v) + # baseline = <δO_i, O_i> per query (the "lateral sum" collapsed to this) + baseline = (delta_out * out).sum(dim=-1, keepdim=True) + # g_S_{i,j} = A_{ij} * (<δO_i, V_j> - <δO_i, O_i>) + g_scores = attn * (g_a - baseline) + # δV = A^T @ δO (value gradient) + delta_v = torch.einsum("bhtk,bhtd->bhkd", attn, delta_out) + return g_scores, delta_v + + +@dataclass +class LocalGPTConfig: + block_size: int = 256 + vocab_size: int = 65 + n_layer: int = 6 + n_head: int = 6 + n_embd: int = 384 + dropout: float = 0.2 + bias: bool = False + attn_mode: str = "sigmoid" + sigmoid_bias_mode: str = "neg_log_n" + method: str = "bp" # bp | fa | dfa | sign_sym + # STE ablation flags + ste_sigmoid: bool = False # skip A(1-A) in sigmoid attention backward + ste_gelu: bool = False # skip gelu' in FFN backward + freeze_emb: bool = False # freeze token + position embeddings + # LN backward mode: "bp" (standard), "ste" (identity), "center_scale", "projected" + ln_mode: str = "bp" + fuse_attn_local: bool = False # fuse softmax+A@V with local backward (no lateral sum) + # Sparsity options for SparseFormer experiments + mlp_topk: int = 0 # if > 0, apply hard top-k (k-WTA) to MLP hidden activation (4*n_embd dim) + resid_topk: int = 0 # if > 0, apply hard top-k to residual stream output of each block (n_embd dim) + # FrozenCodeFormer: directional VQ to fixed codebook at residual stream end + vq_codes: int = 0 # if > 0, apply VQResidualDir with K=vq_codes fixed unit-norm codebook entries + # FrozenSubspace: continuous r-dim subspace constraint (shared Q across all blocks) + subspace_rank: int = 0 # if > 0, project residual 
stream to fixed r-dim subspace at each block + # FA B-init mode (only used when method='fa'): gaussian | orthogonal | ortho_he | sparse + fa_init_mode: str = "gaussian" + fa_sparse_k: int = 0 # for fa_init_mode='sparse': non-zero entries per row (0 = auto in/16) + # GrAPE: per-step JVP-based cosine alignment of B toward true Jacobian (forward-only, no W^T) + fa_grape: bool = False + fa_grape_n_probe: int = 32 # batch size for JVP probes + # Path IV: learned per-block residual gates. Each block: x + α_attn·attn(x) + α_mlp·mlp(x) + gated_blocks: bool = False # if True, add learnable scalar gate per (block, sublayer) + + +class LocalCausalSelfAttention(nn.Module): + def __init__(self, config: LocalGPTConfig): + super().__init__() + assert config.n_embd % config.n_head == 0 + self.n_head = config.n_head + self.n_embd = config.n_embd + self.head_dim = config.n_embd // config.n_head + self.block_size = config.block_size + self.attn_mode = config.attn_mode + self.ste_sigmoid = config.ste_sigmoid + self.fuse_attn_local = config.fuse_attn_local + + self.q_proj = LocalLinear(config.n_embd, config.n_embd, bias=config.bias, method=config.method, + fa_init_mode=config.fa_init_mode, + fa_sparse_k=(config.fa_sparse_k or None), + fa_grape=config.fa_grape, + fa_grape_n_probe=config.fa_grape_n_probe) + self.k_proj = LocalLinear(config.n_embd, config.n_embd, bias=config.bias, method=config.method, + fa_init_mode=config.fa_init_mode, + fa_sparse_k=(config.fa_sparse_k or None), + fa_grape=config.fa_grape, + fa_grape_n_probe=config.fa_grape_n_probe) + self.v_proj = LocalLinear(config.n_embd, config.n_embd, bias=config.bias, method=config.method, + fa_init_mode=config.fa_init_mode, + fa_sparse_k=(config.fa_sparse_k or None), + fa_grape=config.fa_grape, + fa_grape_n_probe=config.fa_grape_n_probe) + self.o_proj = LocalLinear(config.n_embd, config.n_embd, bias=config.bias, method=config.method, + fa_init_mode=config.fa_init_mode, + fa_sparse_k=(config.fa_sparse_k or None), + 
fa_grape=config.fa_grape, + fa_grape_n_probe=config.fa_grape_n_probe) + self.attn_drop = nn.Dropout(config.dropout) + self.resid_drop = nn.Dropout(config.dropout) + + causal = torch.tril(torch.ones(config.block_size, config.block_size, dtype=torch.bool)) + self.register_buffer("causal_mask", causal, persistent=False) + + if config.attn_mode == "sigmoid": + init_b = 0.0 if config.sigmoid_bias_mode == "zero" else -math.log(config.block_size) + if config.sigmoid_bias_mode == "learned": + self.sig_bias = nn.Parameter(torch.tensor(init_b)) + else: + self.register_buffer("sig_bias", torch.tensor(init_b), persistent=False) + + def forward(self, x): + B, T, C = x.shape + q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) + + scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5) + mask = self.causal_mask[:T, :T] + scores = scores.masked_fill(~mask, float("-inf")) + + if self.fuse_attn_local and self.attn_mode == "softmax": + # Fused softmax+A@V with local backward: + # g_S_{i,j} = A_{ij} * <δO_i, V_j - O_i> (no lateral sum) + out = SoftmaxValueMixLocalFn.apply(scores, v) + else: + if self.attn_mode == "softmax": + attn = F.softmax(scores, dim=-1) + elif self.ste_sigmoid: + attn = SigmoidSTE.apply(scores + self.sig_bias) + else: + attn = torch.sigmoid(scores + self.sig_bias) + attn = self.attn_drop(attn) + out = attn @ v + + out = out.transpose(1, 2).contiguous().view(B, T, C) + return self.resid_drop(self.o_proj(out)) + + +class LocalMLP(nn.Module): + def __init__(self, config: LocalGPTConfig): + super().__init__() + self.fc = LocalLinear(config.n_embd, 4 * config.n_embd, bias=config.bias, method=config.method, + fa_init_mode=config.fa_init_mode, + fa_sparse_k=(config.fa_sparse_k or None), + fa_grape=config.fa_grape, + fa_grape_n_probe=config.fa_grape_n_probe) + self.proj = LocalLinear(4 * 
config.n_embd, config.n_embd, bias=config.bias, method=config.method, + fa_init_mode=config.fa_init_mode, + fa_sparse_k=(config.fa_sparse_k or None), + fa_grape=config.fa_grape, + fa_grape_n_probe=config.fa_grape_n_probe) + self.drop = nn.Dropout(config.dropout) + self.ste_gelu = config.ste_gelu + self.mlp_topk = config.mlp_topk + + def forward(self, x): + h = self.fc(x) + if self.ste_gelu: + h = GELUSTE.apply(h) + else: + h = F.gelu(h) + if self.mlp_topk > 0: + h = HardTopK.apply(h, self.mlp_topk) + return self.drop(self.proj(h)) + + +def _make_ln(config): + """Build the right LN variant based on config.ln_mode.""" + if config.ln_mode == "bp": + return nn.LayerNorm(config.n_embd) + if config.ln_mode == "ste": + return LayerNormSTE(config.n_embd) + if config.ln_mode in ("center_scale", "projected"): + return LayerNormProjectedSurrogate( + config.n_embd, mode=config.ln_mode, elementwise_affine=True, + ) + raise ValueError(f"Unknown ln_mode: {config.ln_mode}") + + +class LocalBlock(nn.Module): + def __init__(self, config: LocalGPTConfig): + super().__init__() + self.ln1 = _make_ln(config) + self.ln2 = _make_ln(config) + self.attn = LocalCausalSelfAttention(config) + self.mlp = LocalMLP(config) + self.resid_topk = config.resid_topk + self.vq = VQResidualDir(config.n_embd, config.vq_codes) if config.vq_codes > 0 else None + # FrozenSubspace uses fixed seed=42 so all blocks share the same Q (same subspace). + self.subspace = FrozenSubspace(config.n_embd, config.subspace_rank, seed=42) \ + if config.subspace_rank > 0 else None + # Path IV: per-sublayer learned residual gates. Init to 1.0 (no initial gating). + # If a sublayer is "noise net" under BPfree, its α can drive toward 0. 
+ if config.gated_blocks: + self.alpha_attn = nn.Parameter(torch.ones(1)) + self.alpha_mlp = nn.Parameter(torch.ones(1)) + else: + self.alpha_attn = None + self.alpha_mlp = None + + def forward(self, x): + if self.alpha_attn is not None: + x = x + self.alpha_attn * self.attn(self.ln1(x)) + x = x + self.alpha_mlp * self.mlp(self.ln2(x)) + else: + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + if self.resid_topk > 0: + x = HardTopK.apply(x, self.resid_topk) + if self.vq is not None: + x = self.vq(x) + if self.subspace is not None: + x = self.subspace(x) + return x + + +class LocalGPT(nn.Module): + def __init__(self, config: LocalGPTConfig): + super().__init__() + self.config = config + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + self.pos_emb = nn.Embedding(config.block_size, config.n_embd) + if config.freeze_emb: + self.tok_emb.weight.requires_grad_(False) + self.pos_emb.weight.requires_grad_(False) + self.drop = nn.Dropout(config.dropout) + self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)]) + self.ln_f = _make_ln(config) + # Output head: also a LocalLinear (last linear layer before logits) + self.head = LocalLinear(config.n_embd, config.vocab_size, bias=False, method=config.method, + fa_init_mode=config.fa_init_mode, + fa_sparse_k=(config.fa_sparse_k or None)) + + self.apply(self._init_weights) + # Scale projection weights to reduce residual stream growth + for pn, p in self.named_parameters(): + if pn.endswith("o_proj.weight") or pn.endswith("mlp.proj.weight"): + nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) + + def _init_weights(self, m): + if isinstance(m, (nn.Linear, LocalLinear)): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + if getattr(m, "bias", None) is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + + def num_params(self) -> int: + return sum(p.numel() for p in self.parameters()) + + def 
forward(self, idx, targets=None): + B, T = idx.shape + assert T <= self.config.block_size + pos = torch.arange(T, device=idx.device) + x = self.drop(self.tok_emb(idx) + self.pos_emb(pos)) + for blk in self.blocks: + x = blk(x) + x = self.ln_f(x) + logits = self.head(x) + if targets is None: + return logits, None + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + return logits, loss + + @torch.no_grad() + def generate(self, idx, max_new_tokens: int, temperature: float = 1.0, top_k=None): + for _ in range(max_new_tokens): + idx_cond = idx[:, -self.config.block_size :] + logits, _ = self(idx_cond) + logits = logits[:, -1, :] / temperature + if top_k is not None: + v, _ = torch.topk(logits, top_k) + logits[logits < v[:, [-1]]] = -float("inf") + probs = F.softmax(logits, dim=-1) + nxt = torch.multinomial(probs, 1) + idx = torch.cat([idx, nxt], dim=1) + return idx diff --git a/ep_run/oracle_adjoint_train.py b/ep_run/oracle_adjoint_train.py new file mode 100644 index 0000000..0503a66 --- /dev/null +++ b/ep_run/oracle_adjoint_train.py @@ -0,0 +1,368 @@ +"""Oracle exact-equilibrium-adjoint training from a redx pre-drift checkpoint. + +This deliberately trains the equilibrium objective L(z*) with an exact +matrix-free implicit adjoint: + + F_z(z*)^T lambda = -L_z(z*) + dL/dtheta = L_theta + lambda^T F_theta + +Block-parameter adjoints reuse asym_probe.py. The readout head keeps the same +local dCE/dWh path used by lt_ep_train.py. 
+""" + +import argparse +import math +import os +import pickle +import time +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import torch + +import lt_ep_train as L +from asym_probe import ( + Operators, + block_param_list, + ce_state_grad, + cos, + exact_transpose_grad, + flat_grad_by_param_id, + norm, + set_param_requires_grad, + solve_exact_adjoint, +) +from lt_ep_train import EQBlock, ce, relax + + +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("--ckpt", default="runs/redx_traj/s2000.pt") + ap.add_argument("--data", default="data/tinystories_bpe") + ap.add_argument("--log-file", default="runs/oracle_adjoint.log") + ap.add_argument("--save", default="runs/oracle_adjoint.pt") + ap.add_argument("--device", default="cuda", choices=["cuda", "cpu"]) + ap.add_argument("--steps", type=int, default=1500) + ap.add_argument("--B", type=int, default=24) + ap.add_argument("--T", type=int, default=256) + ap.add_argument("--C", type=int, default=512) + ap.add_argument("--H", type=int, default=16) + ap.add_argument("--Mm", type=int, default=256) + ap.add_argument("--T1", type=int, default=150) + ap.add_argument("--eps", type=float, default=0.1) + ap.add_argument("--lr", type=float, default=6e-4) + ap.add_argument("--wd", type=float, default=1e-4) + ap.add_argument("--wsd", type=float, default=0.2) + ap.add_argument("--warmup", type=int, default=0) + ap.add_argument("--log-every", type=int, default=50) + ap.add_argument("--eval-batches", type=int, default=8) + ap.add_argument("--eval-B", type=int, default=32) + ap.add_argument("--rho-B", type=int, default=8) + ap.add_argument("--rho-steps", type=int, default=800) + ap.add_argument("--res-est", type=float, default=1e-5) + ap.add_argument("--t1max", type=int, default=6000) + ap.add_argument("--relax-chunk", type=int, default=50) + ap.add_argument("--abort-res", type=float, default=0.3) + ap.add_argument("--grad-clip", type=float, default=5.0) + ap.add_argument("--capx", 
type=float, default=3.0) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--adjoint-iters", type=int, default=200) + ap.add_argument("--adjoint-tol", type=float, default=1e-5) + ap.add_argument("--adjoint-mu", type=float, default=1e-4) + ap.add_argument("--solve-iters", type=int, default=80) + ap.add_argument("--solve-tol", type=float, default=1e-5) + ap.add_argument("--sanity-cos-min", type=float, default=0.999) + ap.add_argument("--tf32", action="store_true") + return ap.parse_args() + + +def require_cuda(device): + if device != "cuda": + return + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + torch.cuda.set_device(0) + return + raise SystemExit( + "ERROR: CUDA unavailable for requested GPU0 run; " + f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r}" + ) + + +def resolve_path(path): + p = Path(path) + if p.is_absolute(): + return p + return Path.cwd() / p + + +def configure_globals(cfg, dev): + L.dev = dev + L.DD = resolve_path(cfg.data) + L.vocab = pickle.load(open(L.DD / "meta.pkl", "rb"))["vocab_size"] + torch.backends.cuda.matmul.allow_tf32 = bool(cfg.tf32) + torch.backends.cudnn.allow_tf32 = bool(cfg.tf32) + + +def build_block(cfg, dev): + torch.manual_seed(cfg.seed) + blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=1.0, attn_mode="thick") + for w in blk.capw: + blk.caps[id(w)] = w.detach().norm().item() * cfg.capx + blk.qknorm = True + blk.fnoise = 0.0 + blk._cstep = None + blk.navg = 1 + blk.li_avg = 0 + blk.track = True + blk.nbrake = 0.0 + ckpt_path = resolve_path(cfg.ckpt) + ck = torch.load(ckpt_path, map_location=dev) + with torch.no_grad(): + for p, w in zip(blk.allp, ck["allp"]): + p.copy_(w.to(dev)) + return blk, ck, ckpt_path + + +@torch.no_grad() +def one_step_residual(blk, z, xin, eps): + z1 = relax(blk, z, xin, 1, eps) + return (z1 - z).norm().item() / (z.norm().item() + 1e-12) + + +def relax_refine(blk, xin, cfg): + z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps) + finite_t1_res = 
one_step_residual(blk, z, xin, cfg.eps) + res = finite_t1_res + steps = cfg.T1 + while steps < cfg.t1max and res > cfg.res_est: + chunk = min(cfg.relax_chunk, cfg.t1max - steps) + z = relax(blk, z, xin, chunk, cfg.eps) + steps += chunk + res = one_step_residual(blk, z, xin, cfg.eps) + if not math.isfinite(res): + break + return z.detach(), finite_t1_res, res, steps + + +def oracle_grad(blk, idx, y, cfg): + xin0 = blk.embed(idx).detach() + zstar, finite_t1_res, zstar_res, relax_steps = relax_refine(blk, xin0, cfg) + + set_param_requires_grad(blk, False) + op = Operators(blk, zstar, xin0, cfg, mu=0.0) + ell, loss_zstar = ce_state_grad(blk, zstar, y) + lam, gmres_rel, gmres_info, gmres_iters, adj_mu = solve_exact_adjoint(op, ell, cfg) + if adj_mu != 0.0 or gmres_info != 0 or (not math.isfinite(gmres_rel)) or gmres_rel > max(10.0 * cfg.adjoint_tol, 1e-4): + set_param_requires_grad(blk, True) + raise RuntimeError( + "exact adjoint GMRES failed " + f"(rel={gmres_rel:.3e}, info={gmres_info}, iters={gmres_iters}, tikhonov_mu={adj_mu:.3e})" + ) + + params = block_param_list(blk) + block_grads = exact_transpose_grad(blk, idx, zstar, xin0, lam, params) + grads = dict(block_grads) + with torch.enable_grad(): + (gh,) = torch.autograd.grad(ce(blk, zstar.detach(), y), blk.Wh) + grads[id(blk.Wh)] = gh + set_param_requires_grad(blk, True) + + block_flat = flat_grad_by_param_id(grads, params) + gt_flat = flat_grad_by_param_id(block_grads, params) + sanity_cos = cos(block_flat, gt_flat) + head_norm = norm(gh.detach()).item() + block_norm = norm(block_flat).item() + return grads, { + "loss_zstar": loss_zstar, + "finite_t1_res": finite_t1_res, + "zstar_res": zstar_res, + "relax_steps": relax_steps, + "gmres_rel": gmres_rel, + "gmres_info": gmres_info, + "gmres_iters": gmres_iters, + "adj_mu": adj_mu, + "sanity_cos": sanity_cos, + "block_grad_norm": block_norm, + "head_grad_norm": head_norm, + } + + +def make_optimizer_and_sched(blk, cfg): + opt = torch.optim.AdamW(blk.allp, lr=cfg.lr, 
weight_decay=cfg.wd) + + def lr_lambda(step): + if cfg.warmup > 0 and step < cfg.warmup: + return (step + 1) / cfg.warmup + decay_start = int((1.0 - cfg.wsd) * cfg.steps) if cfg.wsd > 0 else cfg.warmup + if step < decay_start: + return 1.0 + p = (step - decay_start) / max(1, cfg.steps - decay_start) + return 0.05 + 0.475 * (1.0 + math.cos(math.pi * min(1.0, p))) + + sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda) + return opt, sched + + +@torch.no_grad() +def apply_weight_caps(blk): + for p in blk.capw: + pn = p.norm() + cap = blk.caps[id(p)] + if pn > cap: + p.mul_(cap / pn) + + +@torch.no_grad() +def evaluate_ce(blk, cfg): + total = 0.0 + for _ in range(cfg.eval_batches): + idx, y = L.get_batch("val", cfg.eval_B, cfg.T) + xin = blk.embed(idx).detach() + z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps) + total += ce(blk, z, y).item() + return total / max(1, cfg.eval_batches) + + +@torch.no_grad() +def finite_residual_on_batch(blk, idx, cfg): + xin = blk.embed(idx).detach() + z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps) + return one_step_residual(blk, z, xin, cfg.eps) + + +@torch.no_grad() +def rho_decay_probe(blk, idx, cfg): + xin = blk.embed(idx).detach() + z = xin.clone() + residuals = [] + for _ in range(cfg.rho_steps): + z2 = z + cfg.eps * blk.force(z, xin).detach() + r = (z2 - z).norm().item() / (z.norm().item() + 1e-12) + residuals.append(r) + z = z2 + if (not math.isfinite(r)) or r > 1e2: + break + window = [r for r in residuals if 1e-6 < r < 1e-1] or residuals[-200:] + ratios = [window[i + 1] / window[i] for i in range(len(window) - 1) if window[i] > 0 and window[i + 1] > 0] + rho = math.exp(sum(math.log(x) for x in ratios) / len(ratios)) if ratios else float("nan") + return rho, residuals[-1] if residuals else float("nan"), len(residuals) + + +def log_line(path, line): + with open(path, "a", encoding="utf-8") as f: + f.write(line + "\n") + print(line, flush=True) + + +def track(blk, cfg, fixed_idx, step, info, t0): + val = evaluate_ce(blk, 
cfg) + val_res = finite_residual_on_batch(blk, fixed_idx, cfg) + rho, rho_final, rho_n = rho_decay_probe(blk, fixed_idx, cfg) + lr = info.get("lr", float("nan")) + line = ( + f"step {step:4d}/{cfg.steps} | val CE {val:.4f} | finite_T1_res {val_res:.3e} " + f"| rho800 {rho:.4f} final_res {rho_final:.2e} n={rho_n} " + f"| train_res {info.get('finite_t1_res', float('nan')):.3e} " + f"| zstar_res {info.get('zstar_res', float('nan')):.3e} relax {info.get('relax_steps', -1)} " + f"| gmres {info.get('gmres_rel', float('nan')):.2e}/{info.get('gmres_iters', -1)} " + f"| lr {lr:.3e} | {max(step, 1) / max(time.time() - t0, 1e-9):.4f} it/s" + ) + log_line(cfg.log_file, line) + return val, val_res, rho + + +def save_ckpt(blk, cfg, step, best): + if not cfg.save: + return + path = resolve_path(cfg.save) + path.parent.mkdir(parents=True, exist_ok=True) + torch.save( + {"allp": [p.detach().cpu() for p in blk.allp], "step": step, "best": best}, + str(path) + ".tmp", + ) + os.replace(str(path) + ".tmp", path) + + +def main(): + cfg = parse_args() + require_cuda(cfg.device) + dev = torch.device("cuda:0" if cfg.device == "cuda" else "cpu") + configure_globals(cfg, dev) + cfg.log_file = str(resolve_path(cfg.log_file)) + Path(cfg.log_file).parent.mkdir(parents=True, exist_ok=True) + cfg_for_op = SimpleNamespace(**vars(cfg)) + + blk, ck, ckpt_path = build_block(cfg_for_op, dev) + opt, sched = make_optimizer_and_sched(blk, cfg_for_op) + + torch.manual_seed(1234) + fixed_idx, _ = L.get_batch("val", cfg.rho_B, cfg.T) + torch.manual_seed(cfg.seed + 1) + + header = ( + f"# oracle_adjoint_train device={dev} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r} " + f"ckpt={ckpt_path} ckpt_step={ck.get('step')} ckpt_best={ck.get('best')} " + f"B={cfg.B} T={cfg.T} C={cfg.C} H={cfg.H} Mm={cfg.Mm} T1={cfg.T1} " + f"lr={cfg.lr} wd={cfg.wd} wsd={cfg.wsd} res_est={cfg.res_est} t1max={cfg.t1max}" + ) + with open(cfg.log_file, "w", encoding="utf-8") as f: + f.write(header + "\n") + 
print(header, flush=True) + + t0 = time.time() + idx0, y0 = L.get_batch("train", cfg.B, cfg.T) + grads0, info0 = oracle_grad(blk, idx0, y0, cfg_for_op) + sanity = info0["sanity_cos"] + sanity_line = ( + f"step 0 sanity: cos(oracle_block_grad, asym_probe g_transpose)={sanity:+.6f} " + f"gmres_rel={info0['gmres_rel']:.3e} gmres_iters={info0['gmres_iters']} " + f"zstar_res={info0['zstar_res']:.3e} finite_T1_res={info0['finite_t1_res']:.3e}" + ) + log_line(cfg.log_file, sanity_line) + if (not math.isfinite(sanity)) or sanity < cfg.sanity_cos_min: + bug = f"STOP: step-0 oracle/asym_probe sanity cosine {sanity:+.6f} < {cfg.sanity_cos_min:.6f}" + log_line(cfg.log_file, bug) + raise SystemExit(3) + + info0["lr"] = sched.get_last_lr()[0] + best, _, _ = track(blk, cfg_for_op, fixed_idx, 0, info0, t0) + save_ckpt(blk, cfg_for_op, 0, best) + + for step in range(1, cfg.steps + 1): + idx, y = L.get_batch("train", cfg.B, cfg.T) + try: + grads, info = oracle_grad(blk, idx, y, cfg_for_op) + except RuntimeError as err: + log_line(cfg.log_file, f"ABORT step {step}: {err}") + break + if info["finite_t1_res"] > cfg.abort_res: + log_line( + cfg.log_file, + f"ABORT step {step}: finite_T1_res {info['finite_t1_res']:.3e} > {cfg.abort_res:.3e}", + ) + break + if not all((g is None) or torch.isfinite(g).all() for g in grads.values()): + log_line(cfg.log_file, f"ABORT step {step}: non-finite oracle gradient") + break + + opt.zero_grad(set_to_none=True) + for p in blk.allp: + p.grad = grads.get(id(p)) + torch.nn.utils.clip_grad_norm_(blk.allp, cfg.grad_clip) + opt.step() + apply_weight_caps(blk) + sched.step() + + if step % cfg.log_every == 0: + info["lr"] = sched.get_last_lr()[0] + val, _, _ = track(blk, cfg_for_op, fixed_idx, step, info, t0) + best = min(best, val) + save_ckpt(blk, cfg_for_op, step, best) + + save_ckpt(blk, cfg_for_op, step if cfg.steps > 0 else 0, best) + log_line(cfg.log_file, f"DONE best_val_CE={best:.4f}") + + +if __name__ == "__main__": + main() diff --git 
a/ep_run/prepare_tinystories.py b/ep_run/prepare_tinystories.py
new file mode 100644
index 0000000..d7305a3
--- /dev/null
+++ b/ep_run/prepare_tinystories.py
@@ -0,0 +1,40 @@
+"""Char-level TinyStories -> train.bin/val.bin (uint16) + meta.pkl, same format as
+shakespeare_char so lt_ep_train.py consumes it via --data. Top-127 chars by train-set
+frequency; everything else maps to '?' (keeps the vocab clean of rare unicode)."""
+import collections, pickle
+import numpy as np
+from pathlib import Path
+
+D = Path('/tmp/lt_ep/data/tinystories')
+cnt = collections.Counter()
+with open(D / 'train.txt', encoding='utf-8', errors='replace') as f:
+    while True:
+        chunk = f.read(1 << 24)
+        if not chunk:
+            break
+        cnt.update(chunk)
+keep = sorted(c for c, _ in cnt.most_common(127))
+stoi = {c: i for i, c in enumerate(keep)}
+UNK = stoi.get('?', 0)
+table = {ord(c): i for c, i in stoi.items()}
+
+
+def enc_file(src, dst):
+    out = open(dst, 'wb')
+    n = 0
+    with open(src, encoding='utf-8', errors='replace') as f:
+        while True:
+            chunk = f.read(1 << 24)
+            if not chunk:
+                break
+            arr = np.fromiter((table.get(ord(c), UNK) for c in chunk), dtype=np.uint16, count=len(chunk))
+            arr.tofile(out)
+            n += len(arr)
+    out.close()
+    return n
+
+
+nt = enc_file(D / 'train.txt', D / 'train.bin')
+nv = enc_file(D / 'valid.txt', D / 'val.bin')
+pickle.dump({'vocab_size': len(stoi), 'stoi': stoi}, open(D / 'meta.pkl', 'wb'))
+print(f"vocab={len(stoi)} train_tokens={nt / 1e6:.1f}M val_tokens={nv / 1e6:.1f}M", flush=True)
diff --git a/ep_run/prepare_tinystories_bpe.py b/ep_run/prepare_tinystories_bpe.py
new file mode 100644
index 0000000..9b03a83
--- /dev/null
+++ b/ep_run/prepare_tinystories_bpe.py
@@ -0,0 +1,49 @@
+"""TinyStories -> 4k BPE -> train.bin/val.bin (uint16) + meta.pkl + tokenizer.json.
+Same bin format as the char pipeline so lt_ep_train consumes it via --data."""
+import pickle
+import numpy as np
+from pathlib import Path
+from tokenizers import Tokenizer
+from tokenizers.models import BPE
+from tokenizers.trainers import BpeTrainer
+from tokenizers.pre_tokenizers import ByteLevel
+from tokenizers.decoders import ByteLevel as ByteLevelDec
+
+SRC = Path('/home/yurenh2/ept/ep_run/data/tsrc')
+D = Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe')
+D.mkdir(parents=True, exist_ok=True)
+VOCAB = 4096
+
+tok = Tokenizer(BPE(unk_token=None))
+tok.pre_tokenizer = ByteLevel(add_prefix_space=False)
+tok.decoder = ByteLevelDec()
+trainer = BpeTrainer(vocab_size=VOCAB, special_tokens=[], show_progress=True)
+tok.train([str(SRC / 'train.txt')], trainer)
+tok.save(str(D / 'tokenizer.json'))
+print(f"trained BPE vocab={tok.get_vocab_size()}", flush=True)
+
+
+def enc_file(src, dst):
+    out = open(dst, 'wb')
+    n = 0
+    buf = []
+    with open(src, encoding='utf-8', errors='replace') as f:
+        for line in f:
+            buf.append(line)
+            if len(buf) >= 20000:
+                ids = [i for e in tok.encode_batch([''.join(buf)]) for i in e.ids]
+                np.array(ids, dtype=np.uint16).tofile(out)
+                n += len(ids)
+                buf = []
+    if buf:
+        ids = [i for e in tok.encode_batch([''.join(buf)]) for i in e.ids]
+        np.array(ids, dtype=np.uint16).tofile(out)
+        n += len(ids)
+    out.close()
+    return n
+
+
+nt = enc_file(SRC / 'train.txt', D / 'train.bin')
+nv = enc_file(SRC / 'valid.txt', D / 'val.bin')
+pickle.dump({'vocab_size': tok.get_vocab_size()}, open(D / 'meta.pkl', 'wb'))
+print(f"vocab={tok.get_vocab_size()} train_tokens={nt/1e6:.1f}M val_tokens={nv/1e6:.1f}M", flush=True)
diff --git a/ep_run/probe_geometry.py b/ep_run/probe_geometry.py
new file mode 100644
index 0000000..971e908
--- /dev/null
+++ b/ep_run/probe_geometry.py
@@ -0,0 +1,162 @@
+"""Geometric probes for trained models: CKA(model_a, model_b) and effective rank per layer.
+ +Usage: + python3 probe_geometry.py \ + --ckpts bp:runs_local/probe_bp/ckpt.pt bpfree:runs_local/probe_bpfree/ckpt.pt \ + --data_dir data/tinystories --batch_size 32 --n_batches 4 \ + --out probes/probe_results.json + +Outputs JSON with per-layer effective rank, and CKA matrix between every pair of named ckpts. +""" +import argparse +import json +import pickle +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F + +from model_local import LocalGPTConfig +from train_local_ce import LocalCETransformer + + +def load_model(ckpt_path, device): + blob = torch.load(ckpt_path, map_location=device, weights_only=False) + cfg_dict = blob["config"] + args = blob["args"] + cfg = LocalGPTConfig(**{k: v for k, v in cfg_dict.items() + if k in LocalGPTConfig.__dataclass_fields__}) + model = LocalCETransformer(cfg, translator_rank=args.get("translator_rank", 0)).to(device) + model.load_state_dict(blob["model_state"], strict=False) + model.eval() + return model, cfg, args + + +def get_fixed_batches(data_dir, block_size, batch_size, n_batches, device, seed=12345): + fn = data_dir / "val.bin" + data = np.memmap(fn, dtype=np.uint16, mode="r") + g = torch.Generator().manual_seed(seed) + batches = [] + for _ in range(n_batches): + ix = torch.randint(len(data) - block_size - 1, (batch_size,), generator=g) + x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]) + batches.append(x.to(device, non_blocking=True)) + return batches + + +@torch.no_grad() +def collect_activations(model, batches): + """Run model on each batch, return list of per-layer activations stacked across batches. 
+ Returns: list of length L+1, each element is tensor of shape (N_total_tokens, d_model) + """ + per_layer = None + for X in batches: + acts = model.forward_activations(X) # list of (B, T, d), length L+1 + if per_layer is None: + per_layer = [[] for _ in range(len(acts))] + for l, a in enumerate(acts): + per_layer[l].append(a.reshape(-1, a.size(-1)).float().cpu()) + return [torch.cat(parts, dim=0) for parts in per_layer] # list of (N, d) + + +def linear_cka(X, Y, center=True): + """Linear CKA between (N, d_x) and (N, d_y) matrices. + CKA(X,Y) = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F) + """ + if center: + X = X - X.mean(dim=0, keepdim=True) + Y = Y - Y.mean(dim=0, keepdim=True) + XtY = X.T @ Y # (d_x, d_y) + num = (XtY ** 2).sum().item() + XtX = X.T @ X + YtY = Y.T @ Y + den = ((XtX ** 2).sum().sqrt() * (YtY ** 2).sum().sqrt()).item() + return num / max(den, 1e-12) + + +def effective_rank(X, eps=1e-12): + """Effective rank = exp(entropy of normalized eigenvalues).""" + X_centered = X - X.mean(dim=0, keepdim=True) + # Use eigendecomp of covariance for stability + cov = (X_centered.T @ X_centered) / max(X_centered.size(0) - 1, 1) + eigvals = torch.linalg.eigvalsh(cov).clamp_min(0.0) + s = eigvals.sum().item() + if s < eps: + return 0.0 + p = eigvals / s + p = p[p > eps] + H = -(p * p.log()).sum().item() + return float(np.exp(H)) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--ckpts", type=str, nargs="+", required=True, + help="space-separated NAME:PATH pairs, e.g. 
bp:runs/bp/ckpt.pt bpfree:runs/bpfree/ckpt.pt") + p.add_argument("--data_dir", type=str, default="data/tinystories") + p.add_argument("--batch_size", type=int, default=32) + p.add_argument("--n_batches", type=int, default=4) + p.add_argument("--block_size", type=int, default=512) + p.add_argument("--out", type=str, default="probes/probe_results.json") + args = p.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + data_dir = Path(args.data_dir) + out_path = Path(args.out) + out_path.parent.mkdir(parents=True, exist_ok=True) + + # Parse ckpt names + name_paths = [s.split(":", 1) for s in args.ckpts] + print(f"Probing {len(name_paths)} ckpts on {data_dir}, {args.n_batches} batches × {args.batch_size}") + + # Use block_size from first ckpt's training config + first_blob = torch.load(name_paths[0][1], map_location="cpu", weights_only=False) + block_size = first_blob["config"].get("block_size", args.block_size) + + batches = get_fixed_batches(data_dir, block_size, args.batch_size, args.n_batches, device) + + # Collect activations per ckpt + all_acts = {} # name → list of (N, d) + eff_ranks = {} # name → list of float per layer + for name, ckpt_path in name_paths: + print(f" loading {name} from {ckpt_path}") + model, cfg, train_args = load_model(ckpt_path, device) + acts = collect_activations(model, batches) + all_acts[name] = acts + eff_ranks[name] = [effective_rank(a) for a in acts] + print(f" layers: {len(acts)}, d_model: {acts[0].size(1)}, N: {acts[0].size(0)}") + print(f" eff_rank per layer: {[f'{r:.1f}' for r in eff_ranks[name]]}") + del model + torch.cuda.empty_cache() + + # CKA matrices: for each pair, compute (L+1) × (L+1) CKA matrix + cka_matrices = {} # "name_a:name_b" → list of lists + names = list(all_acts.keys()) + for i in range(len(names)): + for j in range(i, len(names)): + a, b = names[i], names[j] + La, Lb = len(all_acts[a]), len(all_acts[b]) + mat = [[0.0] * Lb for _ in range(La)] + for la in range(La): + for lb in range(Lb): + 
mat[la][lb] = linear_cka(all_acts[a][la], all_acts[b][lb])
+            cka_matrices[f"{a}::{b}"] = mat
+            # Show diagonal (corresponding layer pairs) if same depth
+            if La == Lb:
+                diag = [mat[k][k] for k in range(La)]
+                print(f" CKA({a},{b}) diag: {[f'{v:.3f}' for v in diag]}")
+
+    results = {
+        "ckpts": dict(name_paths),
+        "data_dir": str(data_dir),
+        "n_total_tokens": all_acts[names[0]][0].size(0),
+        "effective_rank": eff_ranks,
+        "cka": cka_matrices,
+    }
+    out_path.write_text(json.dumps(results, indent=2))
+    print(f"\nWrote {out_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ep_run/profile_ep.log b/ep_run/profile_ep.log
new file mode 100644
index 0000000..d9812a4
--- /dev/null
+++ b/ep_run/profile_ep.log
@@ -0,0 +1,21 @@
+=== full + component toggles (ms/step, B=24, C512) ===
+/home/yurenh2/miniconda3/lib/python3.13/site-packages/torch/autograd/graph.py:865: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /pytorch/aten/src/ATen/cuda/CublasHandlePool.cpp:330.)
+ return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
+FULL ep_step: 7266
+ -jacreg: 7242
+ -resreg: 7312
+ -t1max(no refine): 5886
+ t2sel=80: 7384
+ t2sel=40: 4485
+ plain nudge holo=0 T2=20: 3179
+ free relax T1=150 alone: 740
+ free relax T1=300 alone: 1480
+=== batch sweep (full) ===
+ B=8: 2353 ms (294.1/sample)
+ B=24: 7405 ms (308.5/sample)
+ B=48: 14496 ms (302.0/sample)
+=== compile free relax ===
+ free relax T1=150 COMPILED: 507
+=== bf16 full ===
+ full bf16: ERR RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
+DONE
diff --git a/ep_run/profile_ep.py b/ep_run/profile_ep.py
new file mode 100644
index 0000000..f76c4f2
--- /dev/null
+++ b/ep_run/profile_ep.py
@@ -0,0 +1,40 @@
+import torch, time, math
+import lt_ep_train as LT
+torch.manual_seed(0)
+def mk():
+    blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick')
+    blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0
+    return blk
+def T(fn,reps=3,warm=1):
+    try:
+        torch.cuda.empty_cache()
+        for _ in range(warm): fn()
+        torch.cuda.synchronize(); t0=time.time()
+        for _ in range(reps): fn()
+        torch.cuda.synchronize(); return round((time.time()-t0)/reps*1000)
+    except Exception as e:
+        return f"ERR {type(e).__name__}: {str(e)[:70]}"
+blk=mk(); idx,y=LT.get_batch('train',24,256)
+base=dict(T1=150,T2=20,eps=0.1,beta=0.02,jacreg=0.1,holo=2,hr=0.2,t2sel=160,t1max=300,res_est=1e-4,resreg=0.2)
+S=lambda **kw: (lambda: LT.ep_step(blk,idx,y,**{**base,**kw}))
+print("=== full + component toggles (ms/step, B=24, C512) ===",flush=True)
+full=T(S()); print(f"FULL ep_step: {full}",flush=True)
+for n,kw in [("-jacreg",dict(jacreg=0)),("-resreg",dict(resreg=0)),("-t1max(no refine)",dict(t1max=0)),
+              ("t2sel=80",dict(t2sel=80)),("t2sel=40",dict(t2sel=40)),("plain nudge holo=0 T2=20",dict(holo=0,t2sel=0))]:
+    print(f" {n}: {T(S(**kw))}",flush=True)
+xin=blk.embed(idx).detach()
+print(f" free relax T1=150 alone: {T(lambda: LT.relax(blk,xin.clone(),xin,150,0.1))}",flush=True)
+print(f" free relax T1=300 alone: {T(lambda: LT.relax(blk,xin.clone(),xin,300,0.1))}",flush=True)
+print("=== batch sweep (full) ===",flush=True)
+for B in [8,24,48]:
+    ib,yb=LT.get_batch('train',B,256); t=T(lambda: LT.ep_step(blk,ib,yb,**base))
+    print(f" B={B}: {t} ms"+(f" ({t/B:.1f}/sample)" if isinstance(t,(int,float)) else ""),flush=True)
+print("=== compile free relax ===",flush=True)
+try:
+    blk._cstep=torch.compile(lambda z,xn: z+0.1*blk.tforce(z,xn))
+    print(f" free relax T1=150 COMPILED: {T(lambda: LT.relax(blk,xin.clone(),xin,150,0.1))}",flush=True); del blk._cstep
+except Exception as e: print(f" compile ERR {e}",flush=True)
+print("=== bf16 full ===",flush=True)
+def bf():
+    with torch.autocast('cuda',dtype=torch.bfloat16): LT.ep_step(blk,idx,y,**base)
+print(f" full bf16: {T(bf)}",flush=True); print("DONE",flush=True)
diff --git a/ep_run/ra_mlp.py b/ep_run/ra_mlp.py
new file mode 100644
index 0000000..5e3de91
--- /dev/null
+++ b/ep_run/ra_mlp.py
@@ -0,0 +1,287 @@
+"""
+Reciprocal Alignment (RA) exploration on MLP + MNIST. 5 arms.
+
+arms:
+  bp: W via BP, B ignored
+  fa: W via FA with B=random, B fixed
+  ra_rev: W via FA with B, B via reverse task FA with W (lambda=0)
+  ra_recon: W via FA with B, B via layer-local reconstruction (lambda=1)
+  ra_comb: W via FA with B, B = (1-lam)*rev + lam*recon
+
+model: pure linear MLP with LayerNorm between layers (whitening for RA-recon fixpoint).
+task: MNIST classification, cross-entropy.
+diag: test acc, per-layer ||B_l - W_l.T|| / ||W_l|| alignment, losses.
+"""
+import argparse
+import json
+import time
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+
+def make_dims(in_dim, hidden, depth, out_dim):
+    """depth = number of weight matrices. 
hidden count = depth - 1.""" + if depth < 2: + raise ValueError("depth must be >= 2") + return [in_dim] + [hidden] * (depth - 1) + [out_dim] + + +def init_mats(dims, device, seed, activation="linear"): + g = torch.Generator(device="cpu").manual_seed(seed) + gain = 1.0 if activation in ("linear", "tanh") else (2.0 ** 0.5) # Kaiming for ReLU, Xavier otherwise + W, B = [], [] + for i in range(len(dims) - 1): + d_in, d_out = dims[i], dims[i + 1] + w = torch.empty(d_out, d_in).normal_(std=gain * (1.0 / d_in) ** 0.5, generator=g).to(device) + b = torch.empty(d_in, d_out).normal_(std=gain * (1.0 / d_out) ** 0.5, generator=g).to(device) + W.append(w) + B.append(b) + return W, B + + +def apply_act(z, activation): + if activation == "linear": + return z + if activation == "relu": + return F.relu(z) + if activation == "tanh": + return torch.tanh(z) + raise ValueError(activation) + + +def act_deriv(z, activation): + """d(act)/dz evaluated at z.""" + if activation == "linear": + return torch.ones_like(z) + if activation == "relu": + return (z > 0).float() + if activation == "tanh": + return 1 - torch.tanh(z).pow(2) + raise ValueError(activation) + + +def forward_w(x, W, activation="linear", ln=True): + """MLP forward with activation + optional LN between hidden layers. Returns (acts[0..L], pre[0..L]). + pre[l] for l in 1..L is the pre-activation z[l] = W[l-1] @ a[l-1]. pre[0] = None. + acts[l] for l in 0..L is the post-activation (or input at l=0, logits at l=L). + """ + acts = [x] + pre = [None] + h = x + L = len(W) + for l in range(L): + z = h @ W[l].T # (N, d_{l+1}) + pre.append(z) + if l < L - 1: + h = apply_act(z, activation) + if ln: + h = F.layer_norm(h, h.shape[-1:]) + else: + h = z # logits, no activation + acts.append(h) + return acts, pre + + +def w_grads(acts, pre, grad_top, W, B, use_bp, activation="linear"): + """W update via BP or FA (with B as feedback). Applies activation derivative for hidden layers. 
+ LN derivative is ignored (approximate, consistent BP/FA comparison).""" + L = len(W) + grads = [None] * L + delta = grad_top # (N, d_L), at the final (logit) layer + for l in reversed(range(L)): + grads[l] = delta.T @ acts[l] # (d_{l+1}, d_l) + if l > 0: + if use_bp: + grad_a = delta @ W[l] # (N, d_l) in a-space + else: + grad_a = delta @ B[l].T # FA feedback via B + # Convert grad_a (in activation output space) to grad_z (in pre-activation space) + delta = grad_a * act_deriv(pre[l], activation) + return grads + + +def b_grads_rev(acts, x, W, B, ln_b=True): + """Reverse task FA: B pathway reconstructs x from top, W as feedback matrix. + b[L] = a[L], b[l-1] = LN(B[l-1] @ b[l]) (LN on B pathway mirrors W's LN for stability). + L_rev = ||b[0] - x||^2 (mean over batch) + FA for B: eta[0] = (b[0]-x)/N; eta[l] = eta[l-1] @ W[l-1].T + Update: dB[l] = eta[l].T @ b[l+1] + Note: LN derivative in B pathway is ignored (approximate, consistent with W pathway treatment). + """ + L = len(B) + N = x.shape[0] + b = [None] * (L + 1) + b[L] = acts[L] + for l in range(L, 0, -1): + h = b[l] @ B[l - 1].T # (N, d_{l-1}) + if ln_b and l > 1: # don't LN the last (bottom) output since it targets x directly + h = F.layer_norm(h, h.shape[-1:]) + b[l - 1] = h + eta = [None] * (L + 1) + eta[0] = (b[0] - x) / N # batch-averaged loss gradient + for l in range(1, L + 1): + eta[l] = eta[l - 1] @ W[l - 1].T # (N, d_l) + grads_B = [None] * L + for l in range(L): + grads_B[l] = eta[l].T @ b[l + 1] # (d_l, d_{l+1}) + rev_loss = (b[0] - x).pow(2).sum(-1).mean().item() + return grads_B, rev_loss + + +def b_grads_recon(acts, B): + """Layer-local reconstruction: L_B^l = ||a[l+1] @ B[l].T - a[l]||^2 per layer (mean over batch). 
+ dB[l] = (r.T @ a[l+1]) / N, r = a[l+1] @ B[l].T - a[l] (GD on that quadratic) + """ + L = len(B) + N = acts[0].shape[0] + grads_B = [None] * L + total_loss = 0.0 + for l in range(L): + X = acts[l + 1] # (N, d_{l+1}) + Y = acts[l] # (N, d_l) + r = X @ B[l].T - Y # (N, d_l) + grads_B[l] = (r.T @ X) / N # (d_l, d_{l+1}) + total_loss += r.pow(2).sum(-1).mean().item() + return grads_B, total_loss + + +def alignment(W, B): + """Per-layer ||B[l] - W[l].T||_F / ||W[l]||_F.""" + out = [] + for l in range(len(W)): + diff = B[l] - W[l].T + out.append((diff.norm() / (W[l].norm() + 1e-9)).item()) + return out + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--arm", required=True, choices=["bp", "fa", "ra_rev", "ra_recon", "ra_comb"]) + p.add_argument("--lam", type=float, default=0.5, help="mixing coef for ra_comb (0=pure rev, 1=pure recon)") + p.add_argument("--epochs", type=int, default=10) + p.add_argument("--batch_size", type=int, default=128) + p.add_argument("--lr", type=float, default=0.05) + p.add_argument("--lr_b", type=float, default=None, help="separate LR for B (defaults to --lr)") + p.add_argument("--seed", type=int, default=42) + p.add_argument("--out", type=str, required=True) + p.add_argument("--data_dir", type=str, default="data/mnist") + p.add_argument("--no_ln", action="store_true", help="disable LayerNorm in forward") + p.add_argument("--log_every", type=int, default=100) + p.add_argument("--depth", type=int, default=4, help="number of weight matrices (>=2)") + p.add_argument("--hidden", type=int, default=256) + p.add_argument("--activation", choices=["linear", "relu", "tanh"], default="linear") + args = p.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + torch.manual_seed(args.seed) + + out_dir = Path(args.out) + out_dir.mkdir(parents=True, exist_ok=True) + (out_dir / "config.json").write_text(json.dumps(vars(args), indent=2)) + log_path = out_dir / "log.jsonl" + log_path.write_text("") # truncate + + # Data + tfm 
= transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + transforms.Lambda(lambda t: t.view(-1)), + ]) + train_ds = datasets.MNIST(args.data_dir, train=True, download=True, transform=tfm) + test_ds = datasets.MNIST(args.data_dir, train=False, download=True, transform=tfm) + train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True) + test_loader = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=2, pin_memory=True) + + dims = make_dims(784, args.hidden, args.depth, 10) + W, B = init_mats(dims, device, args.seed, activation=args.activation) + lr_b = args.lr_b if args.lr_b is not None else args.lr + ln = not args.no_ln + + t0 = time.time() + + def log(rec): + rec["t"] = time.time() - t0 + with open(log_path, "a") as f: + f.write(json.dumps(rec) + "\n") + + log({"event": "start", "arm": args.arm, "dims": dims, "activation": args.activation, + "lr": args.lr, "lr_b": lr_b, "ln": ln, "lam": args.lam}) + print(f"[{args.arm}] device={device} dims={dims} act={args.activation} ln={ln} lr={args.lr} lr_b={lr_b} lam={args.lam}") + + def test_acc(): + correct = 0 + total = 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + acts, _ = forward_w(x, W, activation=args.activation, ln=ln) + logits = acts[-1] + pred = logits.argmax(-1) + correct += (pred == y).sum().item() + total += y.shape[0] + return correct / total + + step = 0 + use_bp_for_W = (args.arm == "bp") + for epoch in range(args.epochs): + for x, y in train_loader: + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + + acts, pre = forward_w(x, W, activation=args.activation, ln=ln) + logits = acts[-1] + ce = F.cross_entropy(logits, y).item() + N = x.shape[0] + probs = F.softmax(logits, dim=-1) + onehot = F.one_hot(y, num_classes=10).float() + grad_top = (probs - onehot) / N + + # W grads (uses W, B snapshot) + gW = 
w_grads(acts, pre, grad_top, W, B, use_bp_for_W, activation=args.activation) + + # B grads + gB = None + rev_loss = None + recon_loss = None + if args.arm == "ra_rev": + gB, rev_loss = b_grads_rev(acts, x, W, B) + elif args.arm == "ra_recon": + gB, recon_loss = b_grads_recon(acts, B) + elif args.arm == "ra_comb": + gB_rev, rev_loss = b_grads_rev(acts, x, W, B) + gB_rec, recon_loss = b_grads_recon(acts, B) + gB = [(1 - args.lam) * gB_rev[l] + args.lam * gB_rec[l] for l in range(len(B))] + + # apply updates + for l in range(len(W)): + W[l] -= args.lr * gW[l] + if gB is not None: + for l in range(len(B)): + B[l] -= lr_b * gB[l] + + step += 1 + if step % args.log_every == 0: + align = alignment(W, B) + log({ + "event": "step", "step": step, "epoch": epoch, + "loss_ce": ce, + "rev_loss": rev_loss, "recon_loss": recon_loss, + "alignment": align, + }) + + acc = test_acc() + align = alignment(W, B) + log({"event": "eval", "epoch": epoch, "step": step, "test_acc": acc, "alignment": align}) + print(f"[{args.arm}] epoch {epoch:2d} step {step:5d} test_acc {acc:.4f} align {[f'{a:.3f}' for a in align]}") + + log({"event": "done", "step": step, "final_acc": acc, "final_alignment": align}) + print(f"[{args.arm}] done in {time.time() - t0:.1f}s final_acc={acc:.4f}") + + +if __name__ == "__main__": + main() diff --git a/ep_run/rearm_203.sh b/ep_run/rearm_203.sh new file mode 100644 index 0000000..759c08b --- /dev/null +++ b/ep_run/rearm_203.sh @@ -0,0 +1,8 @@ +L=/home/yurenh2/ept/ep_run/runs/ep_resreg_warm.log +while true; do + b=$(grep -oE "\(best [0-9.]+\)" "$L" | tail -1 | grep -oE "[0-9.]+") + [ -z "$b" ] && { sleep 300; continue; } + if awk "BEGIN{exit !($b < 2.03)}"; then echo "resreg_warm REACHED 2.03: best=$b | $(grep 'val CE' "$L" | tail -1)"; break; fi + pgrep -f "ckpt runs/ep_resreg_warm.pt" >/dev/null || { echo "resreg_warm DIED at best=$b"; break; } + sleep 600 +done diff --git a/ep_run/redx_freezer.py b/ep_run/redx_freezer.py new file mode 100644 index 0000000..962faf2 
--- /dev/null +++ b/ep_run/redx_freezer.py @@ -0,0 +1,19 @@ +import time, os, re, shutil +os.chdir("/home/yurenh2/ept/ep_run") +LOG, CK = "runs/ep_redx.log", "runs/ep_redx.pt" +os.makedirs("runs/redx_traj", exist_ok=True) +seen=set(); t0=time.time() +while time.time()-t0 < 8*3600: + time.sleep(20) + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: continue + if not ls: continue + m=re.search(r"step (\d+)/.*res=([\d.eE+-]+)", ls[-1]) + if not m: continue + step=int(m.group(1)); res=float(m.group(2)) + if step not in seen and os.path.exists(CK) and os.path.getsize(CK)>1e6: + try: shutil.copy2(CK, f"runs/redx_traj/s{step}.pt"); seen.add(step); print(f"froze s{step} res={res:.2e}", flush=True) + except Exception: pass + if res>0.25: + print(f"DIVERGED step {step} res={res:.2e}; frozen={sorted(seen)}", flush=True); break +print("freezer done; frozen:", sorted(seen)) diff --git a/ep_run/redx_freezer2.py b/ep_run/redx_freezer2.py new file mode 100644 index 0000000..269d1e5 --- /dev/null +++ b/ep_run/redx_freezer2.py @@ -0,0 +1,21 @@ +import time, os, re, shutil +os.chdir("/home/yurenh2/ept/ep_run") +LOG, CK = "runs/ep_redx.log", "runs/ep_redx.pt" +os.makedirs("runs/redx_traj", exist_ok=True) +seen = set(); t0 = time.time() +while time.time() - t0 < 8 * 3600: + time.sleep(15) + try: ls = [l for l in open(LOG) if l.startswith("step")] + except Exception: continue + if not ls: continue + m = re.search(r"step\s+(\d+)/.*res=([\d.eE+-]+)", ls[-1]) # \s+ : log right-aligns the step number + if not m: continue + step = int(m.group(1)); res = float(m.group(2)) + if step not in seen and os.path.exists(CK) and os.path.getsize(CK) > 1e6: + try: + shutil.copy2(CK, f"runs/redx_traj/s{step}.pt"); seen.add(step) + print(f"froze s{step} res={res:.2e}", flush=True) + except Exception: pass + if res > 0.25: + print(f"DIVERGED step {step} res={res:.2e}; frozen={sorted(seen)}", flush=True); break +print("freezer done; frozen:", sorted(seen)) diff --git 
a/ep_run/redx_trajprobe.py b/ep_run/redx_trajprobe.py new file mode 100644 index 0000000..ee88098 --- /dev/null +++ b/ep_run/redx_trajprobe.py @@ -0,0 +1,47 @@ +"""Live trajectory prober for ep_redx: probe each frozen ckpt for cos(g_EP,g_transpose), +overlay with val/res from the training log -> the step|res|cos ordering through the divergence. +Probes live (shares GPU0, slows the run ~1.4x); finishes when ep_redx exits (diverges).""" +import time, os, re, subprocess, glob +os.chdir("/home/yurenh2/ept/ep_run") +LOG, OUT, PID = "runs/ep_redx.log", "runs/redx_traj.log", 2497442 +def alive(): + try: os.kill(PID, 0); return True + except Exception: return False +def resmap(): + M = {} + try: + for l in open(LOG): + if l.startswith("step"): + m = re.search(r"step\s+(\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)", l) # \s+ : log right-aligns the step number (same fix as redx_freezer2.py) + if m: M[int(m.group(1))] = (m.group(2), m.group(3)) + except Exception: pass + return M +def probe(ck, done): + step = int(re.search(r"s(\d+)", ck).group(1)) + if step in done: return + done.add(step) + val, res = resmap().get(step, ("?", "?")) + cosv, zres = "?", "?" + env = dict(os.environ, CUDA_VISIBLE_DEVICES="0", PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True") + try: + r = subprocess.run(["python3", "asym_probe.py", "--ckpt", ck, "--B", "8"], + env=env, capture_output=True, text=True, timeout=400) + out = r.stdout + r.stderr + m = re.search(r"cos\(g_EP, ?g_transpose\)=([+-][0-9.]+)", out); cosv = m.group(1) if m else "?" + z = re.search(r"step_res=([0-9.eE+-]+)", out); zres = z.group(1) if z else "?" 
+ except Exception: cosv = "err" + line = f" {step:5d} | val {val} | res(log) {res} | cos {cosv} | z*res {zres}" + open(OUT, "a").write(line + "\n"); print(line, flush=True) +def cks(): + return sorted(glob.glob("runs/redx_traj/s*.pt"), key=lambda p: int(re.search(r"s(\d+)", p).group(1))) +open(OUT, "a").write("# step | val | res(log) | cos(g_EP,g_transpose) | z*res(probe)\n") +done = set(); t0 = time.time() +while time.time() - t0 < 6 * 3600: + time.sleep(30) + for ck in cks(): probe(ck, done) + if not alive(): + time.sleep(45) # let freezer catch the final ckpts + for ck in cks(): probe(ck, done) + break +print("=== TRAJPROBE DONE — full trajectory ===") +for l in open(OUT): print(l.rstrip()) diff --git a/ep_run/resreg_probe.py b/ep_run/resreg_probe.py new file mode 100644 index 0000000..8a5ba0a --- /dev/null +++ b/ep_run/resreg_probe.py @@ -0,0 +1,62 @@ +"""Does resreg CONTAMINATE the gradient or add back BPTT's missing residual-defense? +At a ckpt compute the true grad g_BPTT, the pure EP estimate g_VF, and the resreg grad g_R (at the +training scale lam). If cos(g_VF + lam*g_R, g_BPTT) >= cos(g_VF, g_BPTT), resreg moves EP TOWARD the +true gradient (correction). 
If it drops, resreg is contaminating.""" +import argparse, pickle, math, torch +from pathlib import Path +import lt_ep_train as L +from lt_ep_train import EQBlock, ep_step, bptt_step, relax + +ap = argparse.ArgumentParser() +ap.add_argument('--ckpt', required=True); ap.add_argument('--data', default='data/tinystories_bpe') +ap.add_argument('--gelu', default='erf'); ap.add_argument('--C', type=int, default=512) +ap.add_argument('--H', type=int, default=16); ap.add_argument('--Mm', type=int, default=256) +ap.add_argument('--T', type=int, default=256); ap.add_argument('--B', type=int, default=8) +ap.add_argument('--T1', type=int, default=150); ap.add_argument('--T2', type=int, default=20) +ap.add_argument('--eps', type=float, default=0.1); ap.add_argument('--beta', type=float, default=0.02) +ap.add_argument('--resreg', type=float, default=0.2); ap.add_argument('--t1max', type=int, default=300) +ap.add_argument('--res_est', type=float, default=1e-4); ap.add_argument('--t2sel', type=int, default=40) +ap.add_argument('--hr', type=float, default=0.02) +cfg = ap.parse_args(); dev = 'cuda' +L.DD = Path(cfg.data); L.vocab = pickle.load(open(L.DD / 'meta.pkl', 'rb'))['vocab_size'] +torch.manual_seed(0) +blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=1.0, attn_mode='thick') +blk.qknorm = True; blk.fnoise = 0.0; blk._cstep = None; blk.navg = 1; blk.li_avg = 0; blk.track = True +blk.nbrake = 0.0; blk.gelu = cfg.gelu +ck = torch.load(cfg.ckpt, map_location=dev) +with torch.no_grad(): + for p, w in zip(blk.allp, ck['allp']): p.copy_(w.to(dev)) +idx, y = L.get_batch('train', cfg.B, cfg.T) + +def flat(gd, ps): + return torch.cat([gd[id(p)].reshape(-1) if gd.get(id(p)) is not None + else torch.zeros(p.numel(), device=dev) for p in ps]) +def cos(a, b): return (a @ b / (a.norm() * b.norm() + 1e-20)).item() + +gB = bptt_step(blk, idx, y, cfg.T1, cfg.eps, 0.0) # TRUE gradient +gVF, _ = ep_step(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, 0.0, 2, cfg.hr, cfg.t1max, cfg.res_est, 
cfg.t2sel, 1, 0.0) +xin0 = blk.embed(idx).detach() +zT = relax(blk, xin0.clone(), xin0, cfg.T1, cfg.eps) +resT1 = (relax(blk, zT, xin0, 1, cfg.eps) - zT).norm().item() / (zT.norm().item() + 1e-9) +with torch.enable_grad(): + Fz = blk.tforce(zT, xin0); Rr = (cfg.eps * Fz).pow(2).sum() / (zT.pow(2).sum() + 1e-9) + grr = torch.autograd.grad(Rr, blk.block, allow_unused=True) +gR = {id(p): (g if g is not None else torch.zeros_like(p)) for p, g in zip(blk.block, grr)} +B, VF, R = flat(gB, blk.block), flat(gVF, blk.block), flat(gR, blk.block) +ratio = cfg.resreg * min(1.0, resT1 / 2e-2) +lam = ratio * VF.norm() / (R.norm() + 1e-20) +TOT = VF + lam * R +print(f"# ckpt step {ck.get('step')} best {ck.get('best')} resT1={resT1:.2e} ratio={ratio:.3f}") +print(f"|VF|={VF.norm():.2e} |lam*R|={ (lam*R.norm()).item():.2e} realized ratio={(lam*R.norm()/VF.norm()).item():.3f}") +print(f"cos(VF, BPTT) = {cos(VF, B):+.4f} <- EP estimate, NO resreg") +print(f"cos(VF+lam*R, BPTT) = {cos(TOT, B):+.4f} <- WITH resreg (training grad)") +d = cos(TOT, B) - cos(VF, B) +print(f" delta = {d:+.4f} => {'resreg ADDS alignment (correction, not contamination)' if d >= -1e-3 else 'resreg HURTS alignment (CONTAMINATION)'}") +print(f"cos(R, BPTT) = {cos(R, B):+.4f} <- resreg dir vs true grad (aligned? >0 means resreg points toward BPTT)") +print(f"cos(R, VF) = {cos(R, VF):+.4f}") +# --- M = g_BPTT - g_EP : the finite-horizon stabilizer BPTT HAS and EP LACKS --- +M = B - VF +print(f"--- M = g_BPTT - g_EP : what EP is missing vs BPTT ---") +print(f"|M|/|BPTT| = {(M.norm()/(B.norm()+1e-20)).item():.3f} |M|/|EP| = {(M.norm()/(VF.norm()+1e-20)).item():.3f}") +print(f"cos(M, resreg R) = {cos(M, R):+.4f} <- does resreg point where the MISSING term is? 
(>0 = resreg's intent correct)") +print(f"cos(M, BPTT) = {cos(M, B):+.4f} cos(M, EP) = {cos(M, VF):+.4f}") diff --git a/ep_run/resreg_warm_probe_loop.py b/ep_run/resreg_warm_probe_loop.py new file mode 100644 index 0000000..21c73a2 --- /dev/null +++ b/ep_run/resreg_warm_probe_loop.py @@ -0,0 +1,49 @@ +import time, os, re, torch, pickle, numpy as np, subprocess +from pathlib import Path +from scipy.sparse.linalg import eigs, LinearOperator +import lt_ep_train as L +from lt_ep_train import EQBlock, relax +os.chdir("/home/yurenh2/ept/ep_run") +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; B=24; T=256; eps=0.1 +CK='runs/ep_resreg_warm.pt'; LOG='runs/ep_resreg_warm.log'; OUT='runs/resreg_warm_probe.log' +torch.manual_seed(0); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx + +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_resreg_warm.pt"],capture_output=True).returncode==0 +def curinfo(): + try: + ls=[l for l in open(LOG) if l.startswith("step")] + m=re.search(r"step\s+(\d+)",ls[-1]); b=re.search(r"best ([\d.]+)",ls[-1]); v=re.search(r"val CE ([\d.]+)",ls[-1]) + return (m.group(1), b.group(1), v.group(1)) + except Exception: return ("?","?","?") +def probe(): + ck=torch.load(CK,map_location=dev) + blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True; blk.fnoise=0.0 + with torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + xin=blk.embed(idx).detach() + z=relax(blk,xin.clone(),xin,300,eps).detach() + g=(blk.force(z,xin).detach().norm()/z.norm()).item() + F0=blk.force(z,xin).detach(); h=1e-3*z.norm().item(); N=z.numel() + def Jv(v): + vt=torch.from_numpy(v).float().to(dev).view_as(z); nv=vt.norm().item() + if nv<1e-20: return np.zeros_like(v) + return ((blk.force(z+h*(vt/nv),xin).detach()-F0)/h*nv).view(-1).cpu().numpy() + lam=eigs(LinearOperator((N,N),matvec=lambda 
v:v+eps*Jv(v),dtype=np.float32),k=6,which='LM',return_eigenvectors=False,maxiter=300,tol=1e-3) + mu=(lam-1)/eps; top=mu[np.argmax(mu.real)] + return g, float(top.real), float(abs(top.imag)) + +t0=time.time() +hdr="=== resreg_warm g/Reμ probe loop (every 20min) — track if g holds (bounded) or grows (toward blow) ===" +print(hdr,flush=True); open(OUT,'a').write(hdr+"\n") +while time.time()-t0 < 24*3600: + if not alive(): + line=f"[+{int(time.time()-t0)}s] ep_resreg_warm DEAD/done -> stop"; print(line,flush=True); open(OUT,'a').write(line+"\n"); break + try: + st,bb,vv=curinfo(); g,re_mu,im_mu=probe() + line=f"[+{int((time.time()-t0)/60)}min] step{st} best{bb} valCE{vv} | g_floor={g:.2e} Re_mu={re_mu:+.3f} Im={im_mu:.3f}" + print(line,flush=True); open(OUT,'a').write(line+"\n") + except Exception as e: + print("probe skip:",repr(e),flush=True) + time.sleep(1200) +print("loop end",flush=True) diff --git a/ep_run/sample_eq.py b/ep_run/sample_eq.py new file mode 100644 index 0000000..9a6552a --- /dev/null +++ b/ep_run/sample_eq.py @@ -0,0 +1,70 @@ +"""Autoregressive sampling from an equilibrium LM checkpoint: causal coupling means the prefix +settles independently of padding, so we relax the padded sequence and read the last position's +logits each step (simple full-re-settle variant; incremental KV-style settling is the optimized +version, not needed at this scale).""" +import argparse, pickle, torch +import lt_ep_train as M +from pathlib import Path + +ap = argparse.ArgumentParser() +ap.add_argument('--ckpt', required=True) +ap.add_argument('--data', default='/tmp/lt_ep/data/tinystories') +ap.add_argument('--C', type=int, default=256) +ap.add_argument('--H', type=int, default=8) +ap.add_argument('--T', type=int, default=256) +ap.add_argument('--T1', type=int, default=150) +ap.add_argument('--eps', type=float, default=0.1) +ap.add_argument('--temp', type=float, default=0.8) +ap.add_argument('--topk', type=int, default=40) +ap.add_argument('--c', type=float, default=1.0) 
+ap.add_argument('--qknorm', action='store_true') +ap.add_argument('--n', type=int, default=3) +ap.add_argument('--prompt', default='Once upon a time') +ap.add_argument('--use_pema', action='store_true') +cfg = ap.parse_args() + +M.DD = Path(cfg.data) +meta = pickle.load(open(M.DD / 'meta.pkl', 'rb')) +M.vocab = meta['vocab_size'] +_tokj = M.DD / 'tokenizer.json' +if _tokj.exists(): # BPE: decode via tokenizers + from tokenizers import Tokenizer + _tk = Tokenizer.from_file(str(_tokj)) + def encode(s): return _tk.encode(s).ids + def decode(ids): return _tk.decode(ids) +else: # char: decode via stoi/itos + stoi = meta['stoi']; itos = {i: c for c, i in stoi.items()} + def encode(s): return [stoi.get(c, 0) for c in s] + def decode(ids): return ''.join(itos.get(int(i), '?') for i in ids) +from lt_ep_train import EQBlock, relax + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(1234) +blk = EQBlock(cfg.C, cfg.H, 256, cfg.T, attn_mode='thick') +blk.c = cfg.c # MUST match training c (force identity) +blk.qknorm = cfg.qknorm # MUST match training qknorm (else wrong fixed point) +ck = torch.load(cfg.ckpt) +src = ck['pema'] if (cfg.use_pema and ck.get('pema') is not None) else ck['allp'] +with torch.no_grad(): + for p, w in zip(blk.allp, src): + p.copy_(w.to(dev)) +print(f"loaded {cfg.ckpt} (step {ck.get('step')}, best {ck.get('best'):.4f}, " + f"{'pema' if cfg.use_pema else 'raw'} weights)\n", flush=True) + +for s in range(cfg.n): + ids = encode(cfg.prompt) + idx = torch.zeros(1, cfg.T, dtype=torch.long, device=dev) + idx[0, :len(ids)] = torch.tensor(ids, device=dev) + pos = len(ids) + while pos < cfg.T: + xin = blk.embed(idx).detach() + z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps) + logits = (z[0, pos - 1] @ blk.Wh) / cfg.temp + if cfg.topk > 0: + kth = torch.topk(logits, cfg.topk).values[-1] + logits[logits < kth] = float('-inf') + nxt = torch.multinomial(torch.softmax(logits, -1), 1).item() + idx[0, pos] = nxt + pos += 1 + text = decode(idx[0, 
:pos].tolist()) + print(f"--- sample {s+1} ---\n{text}\n", flush=True) diff --git a/ep_run/scurria_nonconservative.txt b/ep_run/scurria_nonconservative.txt new file mode 100644 index 0000000..fa6d620 --- /dev/null +++ b/ep_run/scurria_nonconservative.txt @@ -0,0 +1,1708 @@ + Equilibrium Propagation for Non-Conservative Systems + + + Antonino Emanuele Scurria 1 Dimitri Vanden Abeele 1 Bortolo Matteo Mognetti 2 Serge Massar 1 + + + Abstract from inference, the transmission of nonlocal error signals, + and synchronous layer-wise computations with explicit gra- + Equilibrium Propagation (EP) is a physics- + dient storage. These constraints have no clear analog in + inspired learning algorithm that uses stationary + physical systems, making backpropagation challenging to + states of a dynamical system both for inference + implement in neuromorphic or analog hardware. Conse- +arXiv:2602.03670v2 [cs.LG] 1 Jun 2026 + + + + + and learning. In its original formulation it is + quently, understanding how credit assignment can instead + limited to conservative systems, i.e. to dynam- + emerge from intrinsic system dynamics, through local inter- + ics which derive from an energy function. Given + actions and continuous relaxation, is a central question in + their applications, it is important to extend EP + neuroscience and machine learning. + to non-conservative systems, i.e. systems with + non-reciprocal interactions. Previous attempts to Equilibrium Propagation (EP) (Scellier &Bengio, 2017) + generalize EP to such systems failed to compute represents one of the most promising advances in this direc- + the exact gradient of the cost function. Here we tion. It formulates supervised learning as a contrast between + propose a framework that extends EP to arbitrary two stationary states of a dynamical system: a ‘free’ phase + non-conservative systems, including feedforward where the system evolves autonomously, and a ‘nudged’ + networks. 
We keep the key property of equilib- phase where outputs are weakly pushed toward their targets. + rium propagation, namely the use of stationary The local change in neural states between these phases re- + states both for inference and learning. However, covers the exact gradient of the cost function with respect to + we modify the dynamics in the learning phase by parameters. This enables spatially local learning exploiting + a term proportional to the non-reciprocal part of the continuous relaxation of the system without a distinct + the interaction so as to obtain the exact gradient backward circuit or explicit weight transport. + of the cost function. This algorithm can also be + Since its introduction, several works have sought to im- + derived using a variational formulation that gen- + prove the practicality and biological realism of EP. Algo- + erates the learning dynamics through an energy + rithmic adaptations include enforcing temporal locality to + function defined over an augmented state space. + avoid state storage (Ernoult et al., 2020; Falk et al., 2025), + Numerical experiments show that this algorithm + deriving agnostic updates for black-box energies (Scel- + achieves better performance and learns faster than + lier et al., 2022), and substituting nudging with clamping + previous proposals. + (Stern et al., 2021). Theoretically, the framework has been + extended to stochastic systems (Scellier &Bengio, 2017; + Massar &Mognetti, 2025) and Lagrangian dynamics for + 1. Introduction time-varying inputs (Massar, 2025; Pourcel et al., 2025; + Standard neural network optimization relies on error back- Berneman &Hexner, 2025). In parallel, simulations have + propagation, an algorithm whose computational mechanism explored suitable substrates, ranging from spiking (Mar- + is difficult to reconcile with biological (Crick, 1989) and tin et al., 2021; O’Connor et al., 2019) and resistive net- + physical implementations (Indiveri &Liu, 2015). 
Specif- works (Kendall et al., 2020) to coupled oscillators (Wang + ically, backpropagation requires a backward pass distinct et al., 2024; Rageau &Grollier, 2025), as well as quantum + systems (Wanjura &Marquardt, 2025; Massar &Mognetti, + 1 + Laboratoire d’Information Quantique (LIQ) CP224, Université 2025; Scellier, 2024). Experimental realizations have been + libre de Bruxelles (ULB), Av. F. D. Roosevelt 50, 1050 Bruxelles, demonstrated in memristor crossbars (Yi et al., 2023), self- + Belgium 2 Interdisciplinary Center for Nonlinear Phenomena and + Complex Systems CP231, Université libre de Bruxelles (ULB), Av. adjusting electrical circuits (Dillavou et al., 2022; 2024), + F. D. Roosevelt 50, 1050 Bruxelles, Belgium. Correspondence to: elastic networks (Altman et al., 2024), and classical Ising + Antonino Emanuele Scurria <antonino.scurria@ulb.be>. models trained on quantum annealers (Laydevant et al., + 2024). + Proceedings of the 43 rd International Conference on Machine + Learning, Seoul, South Korea. PMLR 306, 2026. Copyright 2026 Despite these recent developments and the theoretical el- + by the author(s). + + 1 + Equilibrium Propagation for Non-Conservative Systems + +egance of EP, its standard formulation remains restricted a framework where the original dynamics serve for infer- +to conservative systems. In these systems, dynamics are ence, while a new augmented dynamic is used to compute +derived from an energy function, which inherently enforces gradients of the cost Eq. (2). In this augmented phase, the +symmetry (e.g., symmetric synaptic connections Jij = Jji ) output neurons are nudged towards their targets (as in stan- +through the action-reaction principle. This constraint pre- dard EP), while a local corrective term – proportional to the +cludes the use of EP in a broad class of models characterized antisymmetric part of the Jacobian at the free equilibrium + ∂ +by non-conservative forces. 
This includes the feedforward JF (x0 , θ, u) = ∂x F (x0 , θ, u) – is added to the forces. The +architectures dominant in modern AI, biological circuits, exact gradients of the cost with respect to parameters are +as well as physical systems that reach stationary states far then obtained by contrasting stationary states of the aug- +from thermodynamic equilibrium, such as nonlinear optical mented system. +systems driven by external lasers (Cin et al., 2025), opto- + Second, we introduce Dyadic EP, a ‘variational’ approach +electronic systems (Kalinin et al., 2025), exciton-polariton + to learning in non-conservative systems. This method in- +condensates (Sajnok &Matuszewski, 2025), active meta- + volves doubling the number of variables in the system’s +materials (Brandenbourger et al., 2019) and active colloids + state space and subsequently introducing a new energy func- +(Bishop et al., 2023; Osat &Golestanian, 2023) (see (Bowick + tion in this extended space. This approach takes advantage +et al., 2022) for a review). + of the extended space to execute the positive and negative +Formally, we consider a dynamical system governed by nudging phases in parallel, recovering the same computa- +a non-reciprocal force field F (x, θ, u), which relaxes to a tional cost as AsymEP. Derived from first principles, this +stationary configuration x0 satisfying: approach is inspired by established methods for mapping + dissipative dynamical systems onto conservative ones by + F (x0 , θ, u) = 0, (1) doubling the degrees of freedom (Bateman, 1931; Galley, +where x represents the state variables, θ the learnable param- 2013; Aykroyd et al., 2025). A more comprehensive study +eters and u the static input. Our goal, given a target y(u), is of the theoretical framework and its application to feedfor- +to compute the gradient of the cost function C(x0 , y) at this ward networks can be found in (Scurria, 2026). 
Our method +equilibrium, is related to the Dual Propagation algorithm (Høier et al., + dC 0 2023; Høier &Zach, 2023; 2024) and constitutes an inde- + (x , y), (2) pendent, first-principles generalization of Dyadic Learning + dθ +and update θ to minimize the cost. (Nest &Høier; Høier et al., 2024)—previously limited to + Hopfield networks—to arbitrary force fields. +Previous attempts to extend EP to non-conservative dynam- +ics include the Vector Field (VF) algorithm (Scellier et al., Third, we validate our framework on MNIST (LeCun, 1998), +2018). However, as noted by the authors, this method pro- Fashion-MNIST, and CIFAR-10. In continuous Hopfield +vides an unbiased gradient of the cost Eq. (2) only in the networks initialized with symmetric connection matrices, +conservative case. To mitigate this, (Laborieux &Zenke, AsymEP achieves better accuracy and learns faster than +2024) proposed adding a penalty to keep the Jacobian close EP and VF. Additionally, when we constrain the network +to symmetry, essentially forcing the system to be as con- to have a strong degree of structural asymmetry, in which +servative as possible. Alternative methods related to VF, case EP is inapplicable, AsymEP outperforms VF. Finally, +which similarly do not compute the exact gradient, were when we restrict connections to a feedforward structure, our +proposed in (Farinha et al., 2020; Costa &Santos, 2025) and algorithm effectively trains all parameters; in contrast, VF +for specific systems in simulation (Cin et al., 2025; Sajnok is limited to training the last layer, acting essentially as an +&Matuszewski, 2025). Extreme Learning Machine (Huang et al., 2006; Wang et al., + 2022) with poor performance. +Conversely, generalizations of backpropagation can handle +non-reciprocal forces and compute the exact gradient of In summary, this theoretical work proposes two generaliza- +the cost Eq. 
(2) but inherit the same challenges in physical tions of EP beyond conservative systems to arbitrary differ- +implementations. For instance, Backpropagation Through entiable dynamics that compute in their stationary states. +Time (Werbos, 1990) unfolds the network in time to ap- +ply standard backpropagation, Recurrent Backpropagation 2. Equilibrium Propagation Overview +(Almeida, 1990; Pineda, 1987) avoids this memory require- +ment but still requires a specific circuit to propagate errors, 2.1. Conservative Systems +and the continuous Adjoint Method (Chen et al., 2018) addi- We first review standard Equilibrium Propagation (EP) +tionally requires integrating the dynamics backward in time (Scellier &Bengio, 2017). We consider a network described +which is not physically possible for a dissipative system. by an energy function E(x, θ, u), such that the force field is +In this paper, we first propose Asymmetric EP (AsymEP), + + 2 + Equilibrium Propagation for Non-Conservative Systems + +derived from the potential E: stationary point, i.e., that Eq. (7) holds. Second, EP implic- + ∂ + itly assumes that the Jacobian JE (x0 , u) = ∂x FE (x0 , u) is + ∂ + FE (x, θ, u) = − E(x, θ, u). (3) invertible. In this work, we assume this condition holds and + ∂x will not state it explicitly hereafter. Third, for simplicity, +The objective is to compute the total gradient dC 0 we omit the dependency on the input u and target y in the + dθ (x , y) of a +(quadratic) cost function C(x, y) evaluated at the minimum following equations. +energy configuration of the system. This free equilibrium +denoted x0 (which depend implicitly in θ and u), satisfies 2.2. Vector Field +the stationarity condition: + The Vector Field (VF) algorithm, introduced in (Scellier + ∂ et al., 2018), is an early attempt to adapt EP to non- + − E(x0 , θ, u) = 0. (4) reciprocal forces. 
This method relies on the observation + ∂x + that, for conservative systems, linearizing the right-hand +To compute gradients, we introduce the augmented energy side of Eq. (9) around the equilibrium point x0 yields +functional: + ∂E(xβ , θ) ∂E(x−β , θ) + + ET (x, θ, β, u, y) = E(x, θ, u) + βC(x, y), (5) 1 + lim − + β→0 2β ∂θ ∂θ +where β is a scalar nudging parameter. The stationary config- ⊤ β (10) + x − x−β + + ∂FE 0 +uration of this augmented system is obtained by integrating = lim − (x , θ) , + β→0 ∂θ 2β +the dynamics + dx ∂ET (x, θ, β, u) where FE = −∂x E(x, θ) is the conservative force. It is + =− , (6) therefore tempting to use the right-hand side of Eq. (10) for + dt ∂x + parameter updates of non-conservative systems, for which +until the energy minimum is reached. This new fixed point no energy function E exists. +xβ , called nudged equilibrium, satisfies: + The VF algorithm adopts precisely this approach. It uses + ∂E(xβ , θ, u) ∂C(xβ , y) the nudged counterpart of Eq. (7), + +β = 0. (7) + ∂x ∂x + ∂C β + F (xβ , θ) − β (x ) = 0, (11) +The training procedure, as improved in (Laborieux et al., ∂x +2021), uses two nudged phases with factors ±β (with +β ̸= 0). Starting from x0 , the system relaxes to two in conjunction with the learning rule Eq. (10): +nearby perturbed equilibria, x+β and x−β . The displace- ⊤ β + x − x−β + +ment x+β − x−β is then used to compute the parameter ∂F 0 + ∆θ = ϵ (x , θ) . (12) +update in the learning rule: ∂θ 2β + + 1 ∂E(xβ , θ, u) ∂E(x−β , θ, u) + + ∆θ = −ϵ − , (8) However, as noted in (Scellier et al., 2018), Eq. (12) does + 2β ∂θ ∂θ not align with the true gradient dC 0 + dθ (x ) and is exact only if +where ϵ > 0 is the learning rate. The theoretical foundation the force is conservative. To see this, let JF (x, θ) denote +of EP is the result that, in the limβ→0 of Eq. (8), we get: the Jacobian of the vector field F (x, θ) (in components + (JF (x, θ))ij = ∂F∂xi (x,θ) + j + ). 
Differentiating the equilibrium + dC(x0 , y) d ∂E(xβ , θ, u) 0 + condition F (x , θ) = 0 with respect to θ gives + = , (9) + dθ dβ ∂θ + dx0 ∂F 0 +see Appendix D.1. The error of the above method is O(β 2 ). JF (x0 , θ) + (x , θ) = 0. (13) +This error can be further reduced using holomorphic equi- dθ ∂θ +librium propagation (Laborieux &Zenke, 2022). Consequently, the exact gradient of the cost is +Thus, EP recovers the exact gradient of the cost function ⊤ +using only local computations. In this manner, learning dC 0 dx0 ∂C 0 + (x ) = (x ) +implements gradient descent without an explicit backward dθ dθ ∂x + ⊤ +pass, and credit assignment is realized through the system’s + + ∂F 0 ⊤ 0 + −1 ∂C 0 +intrinsic relaxation dynamics. =− (x , θ) JF (x , θ) (x ) . + ∂θ ∂x + | {z }| {z } +Three remarks can be made at this point. First, EP does not pre-synaptic post-synaptic +require the system to be at an energy minimum, but only at a (14) + + 3 + Equilibrium Propagation for Non-Conservative Systems + +The terms ’pre-synaptic’ and ’post-synaptic’ in Eq. (14) Algorithm 1 Asymmetric EP (AsymEP) +are used by analogy with neuronal transmission: the pre- 1: Inputs: Force field F (x, θ), cost function C(x), nudg- +synaptic factor captures the local influence of θ on the force ing parameter β, learning rate ϵ. +F , while the post-synaptic factor is the sensitivity of the 2: repeat +cost to state perturbations. 3: 1. Free Phase: Evolve to stationary state +If instead we differentiate the nudged equilibrium condition 4: Evolve the system dynamics +in Eq. (11) with respect to β and evaluate at β = 0, we 5: + dx +obtain = F (x, θ), (17) + β + dt + dx ∂C 0 + JF (x0 , θ) − (x ) = 0, (15) 6: until convergence to the stationary state x0 . + dβ β=0 ∂x 7: 2. Jacobian Decomposition +which gives 8: Compute the Jacobian at equilibrium: + 9: + dxβ −1 ∂C 0 ∂F 0 + = JF (x0 , θ) (x , y). (16) JF (x0 , θ) = (x , θ), (18) + dβ β=0 ∂x ∂x +The right-hand side of Eq. 
(16) represents the effective post-synaptic term used by the VF algorithm (Eq. 12). Comparing this with the exact post-synaptic term derived in Eq. (14), we see that they coincide only if $J_F = J_F^\top$, i.e., only if the system is conservative.

Now, let $S_J(x^0,\theta)$ and $A_J(x^0,\theta)$ denote the symmetric and antisymmetric parts of the Jacobian at the free (unnudged) equilibrium, respectively. Then, we show in Appendix A that the gradient error increases with the spectral radius of $S_J(x^0,\theta)^{-1}A_J(x^0,\theta)$. Consequently, large antisymmetric contributions degrade the gradient estimation, confirming empirical observations in the Appendix of (Ernoult et al., 2020). In fact, in the pathological limit where the Jacobian would be purely antisymmetric, $S_J(x^0,\theta) = 0$, the update of VF gives the negative of the true gradient, maximizing the cost rather than minimizing it.

Algorithm 1 (continued, steps 10 to 20)
10:   and decompose it in its antisymmetric part:
11:     $$A_J(x^0,\theta) = \tfrac{1}{2}\left(J_F(x^0,\theta) - J_F(x^0,\theta)^\top\right). \tag{19}$$
12:   3. Nudged Phase: Augmented Dynamics
13:   Integrate the dynamics twice starting from $x^0$
14:     $$\frac{dx}{dt} = F(x,\theta) - \beta\frac{\partial C}{\partial x}(x) - 2A_J(x^0,\theta)\,(x - x^0), \tag{20}$$
15:   until convergence to two new stationary states $x_A^{\pm\beta}$.
16:   4. Parameter Update
17:   Update the parameters according to:
18:     $$\Delta\theta = \epsilon\left(\frac{\partial F}{\partial\theta}(x^0,\theta)\right)^{\!\top}\left(\frac{x_A^{\beta} - x_A^{-\beta}}{2\beta}\right). \tag{21}$$
19: until convergence of $\theta$
20: Output: Optimized parameters $\theta$.

3. Asymmetric EP

Here, we introduce Asymmetric EP (AsymEP), see Algorithm 1, which removes the gradient estimate error inherent to VF by adding a local correction term to the augmented inference dynamics. The new nudged equilibrium $x_A^{\beta}$ satisfies:

$$F(x_A^{\beta},\theta) - \beta\frac{\partial C}{\partial x}(x_A^{\beta}) - 2A_J(x^0,\theta)\,(x_A^{\beta} - x^0) = 0. \tag{22}$$

As in VF, we then obtain two perturbed states $x_A^{\pm\beta}$ for opposite nudging $\pm\beta$ and apply the contrastive learning rule of Eq. (12).

We now show that AsymEP gives rise to the correct learning rule, i.e., that the right-hand side of Eq. (21) is proportional to the gradient of the cost function $\frac{dC}{d\theta}(x^0)$ at the equilibrium point $x^0$ (Eq. 14). To this end, note that the same reasoning leading to Eq. (16) leads to

$$\frac{dx_A^{\beta}}{d\beta}\bigg|_{\beta=0} = J_{F^A}(x^0,\theta)^{-1}\,\frac{\partial C}{\partial x}(x^0), \tag{23}$$

where $J_{F^A}(x,\theta)$ is the Jacobian of the modified dynamical system Eq. (20). At the equilibrium point $x^0$, $J_{F^A}$ is equal to the transpose of the original Jacobian:

$$
\begin{aligned}
J_{F^A}(x^0,\theta) &= J_F(x^0,\theta) - 2A_J(x^0,\theta)\\
&= S_J(x^0,\theta) - A_J(x^0,\theta)\\
&= J_F^\top(x^0,\theta),
\end{aligned} \tag{24}
$$

where we have used the decomposition Eq. (44) of the original Jacobian $J$ into its symmetric and antisymmetric components. Therefore, the left-hand side of Eq. (23) is equal to the true post-synaptic term

$$\frac{dx_A^{\beta}}{d\beta}\bigg|_{\beta=0} = \left(J_F^\top(x^0,\theta)\right)^{-1}\frac{\partial C}{\partial x}(x^0), \tag{25}$$

which, using Eq. (14), proves the result. Additionally, although implied by the equality with the true gradient, we explicitly demonstrate the equivalence of the gradient estimates obtained by AsymEP and Backpropagation Through Time in Appendix B following (Ernoult et al., 2019).

Note that the corrective term $-2A_J(x^0,\theta)(x - x^0)$ in Eq. (20) is spatially local: $A_J$ vanishes for unconnected neurons, and $(x - x^0)$ is available at the synapse given the memory mechanism already required by Eq. (12). This correction can create backward connections (Section 5.3). However, in physical realizations, both feedforward and feedback connections must be physically present, though feedback may be deactivated during inference.

4. Dyadic EP

We now introduce Dyadic EP (Algorithm 2), a variational algorithm that computes the exact cost gradient in the limit of infinitesimal nudging. It maps the original $n$-variable dynamics $F(x,\theta)$ onto a $2n$-variable system $(z, z')$ defined by an energy $H(z,z',\theta)$ and cost $D(z,z')$. We show in Appendix E that AsymEP can be seen as the first-order projection of Dyadic EP onto the original $n$-dimensional state space.

The new system is defined by the energy $H$ and cost function $D$, given in terms of $F$ and $C$ by:

$$
\begin{aligned}
H(z,z',\theta) &= -(z - z')^\top F\!\left(\frac{z+z'}{2},\theta\right),\\
D(z,z') &= C\!\left(\frac{z+z'}{2}\right),
\end{aligned} \tag{26}
$$

where $z, z' \in \mathbb{R}^n$. In order to learn, we introduce the augmented energy

$$H_T(z,z',\theta,\beta) = H(z,z',\theta) + \beta D(z,z'). \tag{27}$$

The equilibrium configuration corresponds to a saddle point of $H_T$, where $z$ minimizes and $z'$ maximizes the energy. This poses no issue for EP, which requires only that the joint state $(z,z')$ reaches a stationary state. Although this min-maximization can be interpreted as $z$ evolving forward and $z'$ backward in time, in practice they evolve forward simultaneously, as we integrate the coupled equations:

$$
\begin{aligned}
\frac{dz}{dt} &= -\frac{\partial H_T}{\partial z} = F\!\left(\frac{z+z'}{2},\theta\right) + \left(\frac{z-z'}{2}\right)^{\!\top}\frac{\partial F}{\partial z}\bigg|_{\frac{z+z'}{2}} - \frac{\beta}{2}\frac{\partial C}{\partial z}\!\left(\frac{z+z'}{2}\right),\\
\frac{dz'}{dt} &= +\frac{\partial H_T}{\partial z'} = F\!\left(\frac{z+z'}{2},\theta\right) - \left(\frac{z-z'}{2}\right)^{\!\top}\frac{\partial F}{\partial z'}\bigg|_{\frac{z+z'}{2}} + \frac{\beta}{2}\frac{\partial C}{\partial z'}\!\left(\frac{z+z'}{2}\right),
\end{aligned} \tag{28}
$$

until a stationary point $(z^{\beta}, z'^{\beta})$ is reached. Upon convergence, we follow the standard EP paradigm in using the difference $z^{\beta} - z'^{\beta}$ to compute the post-synaptic term. Under the change of variables $m = \frac{z+z'}{2}$ and $d = z - z'$, we prove in Appendix D that $m$ follows the original dynamics $F$ (ensuring valid inference), while $d$ relaxes to a "physical" error signal proportional to the cost gradient.

It is important to notice that while Dyadic EP introduces a distinct formulation, it remains consistent with the general theoretical setting of EP and matches the computational cost of AsymEP. Note also that we start the evolution of the free phase ($\beta = 0$) with the identical initial condition for $z$ and $z'$ (i.e., $d = 0$). This guarantees that integrating Eq. (32) leads to a symmetric stationary point where $z^0 = z'^0$. Finally, we underline that the modified variational update rule in Eq. (34) is equivalent to the standard symmetric EP update rule in Eq. (8) (see Appendix D).

Now, to make this concrete, consider a continuous Hopfield network (see also Eq. (35)) with an asymmetric connection matrix $J$. After some calculations (see Appendix F), the augmented energy of the system can be re-expressed as:

$$
\begin{aligned}
H_T ={}& -\frac{1}{2}\rho(z)^\top S\rho(z) + \frac{1}{2}\rho(z')^\top S\rho(z') - \rho(z)^\top A\rho(z')\\
& + \frac{1}{2}\left(\|z\|^2 - \|z'\|^2\right) + \frac{\beta}{2}\left(C(z,y) + C(z',y)\right),
\end{aligned} \tag{29}
$$

where $S$ and $A$ are the symmetric and antisymmetric parts of $J$, respectively, and $\rho$ is an element-wise non-linearity.

An interesting analogy can be drawn with standard learning rules in discrete Hopfield networks (Hopfield, 1982). For a sequence of binary memories $\{\xi^1, \dots, \xi^m\}$ where $\xi^\mu \in \{-1,1\}^n$, $S$ corresponds to the standard autoassociative Hebbian rule $\sum_\mu \xi^\mu(\xi^\mu)^\top$, creating stable attractors at each pattern. In contrast, $A$ corresponds to the heteroassociative rule (e.g., a cycle between $\xi^\mu$ and $\xi^\nu$ given by $\xi^\nu(\xi^\mu)^\top - \xi^\mu(\xi^\nu)^\top$), encoding transitions between patterns.

For this specific energy, the update rule given by Eq. (34) can be re-expressed as:

$$\Delta J \propto -\frac{1}{2\beta}\left(\rho(z'^{\beta}) - \rho(z^{\beta})\right)\left(\rho(z'^{\beta}) + \rho(z^{\beta})\right)^{\top}. \tag{30}$$

In the limit $\beta \to 0$, this gives:

$$\Delta J \propto \left(\frac{d^{\beta}}{\beta} \odot \rho'(m)\right)\rho(m)^\top, \tag{31}$$

matching the learning rule in (Pineda, 1987), with $\lim_{\beta\to 0}\frac{d^{\beta}}{\beta}$ being the error signal.
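The correction term of Eq. (20) and the contrast of Eq. (21) can be checked numerically on a small example. The sketch below is only an illustration under stated assumptions: the toy force field $F(x,\theta) = (J - I)x + \theta$, the non-symmetric matrix `J`, the target `Y`, the quadratic cost and every function name are hypothetical choices made here, not taken from the paper. For this linear system the Jacobian is constant, so $A_J$ is simply the antisymmetric part of `J`. The script compares the AsymEP contrast and the uncorrected (VF-style) contrast with a finite-difference estimate of $-dC/d\theta$.

```python
# Minimal AsymEP sketch on a 2-D linear, non-conservative system (pure Python).
# Assumed toy model (not from the paper): F(x, theta) = (J - I) x + theta,
# with cost C(x) = 0.5 * ||x - y||^2 and a fixed non-symmetric matrix J.

J = [[0.0, 0.8], [-0.3, 0.0]]   # non-symmetric recurrent weights (illustrative)
THETA = [0.5, -0.2]             # trainable bias, so dF/dtheta is the identity
Y = [1.0, 0.0]                  # target


def matvec(m, v):
    """2x2 matrix times 2-vector."""
    return [m[0][0] * v[0] + m[0][1] * v[1], m[1][0] * v[0] + m[1][1] * v[1]]


def force(x):
    """F(x, theta) = (J - I) x + theta."""
    jx = matvec(J, x)
    return [jx[i] - x[i] + THETA[i] for i in range(2)]


def relax(x_init, x_free=None, beta=0.0, correct=False, dt=0.05, steps=4000):
    """Euler-integrate dx/dt = F(x) - beta*dC/dx - 2*A_J*(x - x_free), cf. Eq. (20)."""
    x = list(x_init)
    # Antisymmetric part of the Jacobian J_F = J - I (constant for a linear F).
    a = [[0.5 * (J[i][j] - J[j][i]) for j in range(2)] for i in range(2)]
    for _ in range(steps):
        dx = force(x)
        if beta != 0.0:
            dx = [dx[i] - beta * (x[i] - Y[i]) for i in range(2)]
        if correct:
            corr = matvec(a, [x[i] - x_free[i] for i in range(2)])
            dx = [dx[i] - 2.0 * corr[i] for i in range(2)]
        x = [x[i] + dt * dx[i] for i in range(2)]
    return x


def contrast(x_free, beta, correct):
    """Post-synaptic term (x^{+beta} - x^{-beta}) / (2*beta), cf. Eq. (21)."""
    xp = relax(x_free, x_free, beta, correct)
    xm = relax(x_free, x_free, -beta, correct)
    return [(xp[i] - xm[i]) / (2.0 * beta) for i in range(2)]


def cost_at_equilibrium(theta):
    """C(x0(theta)), using the closed-form 2x2 solve of (I - J) x0 = theta."""
    a, b, c, d = 1.0 - J[0][0], -J[0][1], -J[1][0], 1.0 - J[1][1]
    det = a * d - b * c
    x0 = [(d * theta[0] - b * theta[1]) / det, (a * theta[1] - c * theta[0]) / det]
    return 0.5 * sum((x0[i] - Y[i]) ** 2 for i in range(2))


x0 = relax([0.0, 0.0])                         # free phase
asym = contrast(x0, beta=1e-3, correct=True)   # AsymEP estimate
vf = contrast(x0, beta=1e-3, correct=False)    # uncorrected (VF-style) estimate

# Reference: central finite differences of C(x0(theta)) with respect to theta.
eps = 1e-5
grad = []
for k in range(2):
    theta_plus = list(THETA)
    theta_minus = list(THETA)
    theta_plus[k] += eps
    theta_minus[k] -= eps
    grad.append((cost_at_equilibrium(theta_plus) - cost_at_equilibrium(theta_minus)) / (2 * eps))

print("-dC/dtheta (finite differences):", [-g for g in grad])
print("AsymEP contrast                :", asym)
print("uncorrected contrast           :", vf)
```

Under these assumptions the AsymEP contrast should agree with $-dC/d\theta$ up to $O(\beta^2)$, while the uncorrected contrast should not, because `J` is not symmetric. Since $\partial F/\partial\theta$ is the identity in this toy model, the contrast itself is the parameter update direction of Eq. (21).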
Algorithm 2 Dyadic EP
 1: Inputs: Force field $F(x,\theta)$, cost function $C(x,y)$, nudging parameter $\beta$, learning rate $\epsilon$
 2: repeat
 3:   1. Free Phase: Evolve to stationary state
 4:   Evolve the system dynamics, starting from identical initial conditions $z(0) = z'(0) = z_0$,
 5:     $$\frac{dz}{dt} = -\frac{\partial H}{\partial z}, \qquad \frac{dz'}{dt} = +\frac{\partial H}{\partial z'}, \tag{32}$$
 6:   until stationary states $z^0, z'^0$ are reached.
 7:   2. Nudged Equilibrium
 8:   Evolve the system dynamics, starting from the solution of the free phase $z^0 = z'^0$:
 9:     $$\frac{dz}{dt} = -\frac{\partial H_T}{\partial z}, \qquad \frac{dz'}{dt} = +\frac{\partial H_T}{\partial z'}, \tag{33}$$
10:   until two nudged stationary states $z^{\beta}, z'^{\beta}$ are reached.
11:   3. Parameter Update
12:   Update the parameters according to:
13:     $$\Delta\theta = -\epsilon\,\frac{1}{\beta}\,\frac{\partial H(z^{\beta}, z'^{\beta}, \theta)}{\partial\theta} \tag{34}$$
14: until convergence of $\theta$
15: Output: Optimized parameters $\theta$.

5. Numerical Experiments

In this section, we numerically validate AsymEP (Algorithm 1). The neuronal dynamics follows the one introduced in (Scellier & Bengio, 2017), and is generalized to allow for non-reciprocal forces as in (Scellier et al., 2018). For clarity, we express the forces in a form that explicitly separates the contributions of the external input and the recurrent interactions:

$$F(x) = \rho'(x) \odot \left(J^{\mathrm{in}} u + J^{\mathrm{dyn}}\rho(x)\right) - x, \tag{35}$$

where $u \in \mathbb{R}^{N_{\mathrm{in}}}$ denotes the input and $x \in \mathbb{R}^{N_{\mathrm{dyn}}}$ the neuronal state, comprising both hidden and output units. The matrices $J^{\mathrm{in}} \in \mathbb{R}^{N_{\mathrm{dyn}}\times N_{\mathrm{in}}}$ and $J^{\mathrm{dyn}} \in \mathbb{R}^{N_{\mathrm{dyn}}\times N_{\mathrm{dyn}}}$ define the input and recurrent connectivity, respectively. The activation function $\rho(\cdot)$ is taken to be the hyperbolic tangent, applied element-wise.

If $J^{\mathrm{dyn}}$ is symmetric, we can define the energy:

$$E(x) = \frac{1}{2}\|x\|^2 - \frac{1}{2}\rho(x)^\top J^{\mathrm{dyn}}\rho(x) - \rho(x)^\top J^{\mathrm{in}} u, \tag{36}$$

which is identical to that of (Scellier & Bengio, 2017), provided that the input neurons are activated as $\rho(u)$.

Equation (35) naturally motivates a quantitative measure of structural asymmetry $r_{\mathrm{str}}$, defined as:

$$r_{\mathrm{str}} = \frac{\left\|\left(J^{\mathrm{dyn}} - J^{\mathrm{dyn}\top}\right)/2\right\|_F}{\left\|J^{\mathrm{dyn}}\right\|_F}, \tag{37}$$

where $\|\cdot\|_F$ denotes the Frobenius norm. Note that this metric does not capture the asymmetry of the Jacobian, which depends on the state $x$.

For numerical experiments, we restricted the network to a layered architecture with a single hidden layer to facilitate comparison with prior work. Accordingly, $J^{\mathrm{in}}$ contains only input-to-hidden connections, while $J^{\mathrm{dyn}}$ is block off-diagonal, encoding bidirectional interactions between the hidden and output layers. Both $J^{\mathrm{in}}$ and $J^{\mathrm{dyn}}$ are trained.

We first use MNIST (LeCun, 1998) (60k train, 10k test) followed by Fashion-MNIST to validate AsymEP, and then we further validate AsymEP and Dyadic EP by comparing them to Backpropagation on a convolutional feedforward network on CIFAR-10. Inputs are normalized using min-max to $[-1, 1]$ and targets are one-hot encoded in $\{-1, 1\}$. All hyperparameters are detailed in Appendix G, along with additional details and numerical results.

5.1. Symmetric Initialization

We start by comparing AsymEP with standard EP and VF. All algorithms are initialized with an identical symmetric matrix $J^{\mathrm{dyn}}$. EP maintains this symmetry throughout training, while VF and AsymEP induce asymmetry in $J^{\mathrm{dyn}}$. Since EP and VF already achieve strong performance on MNIST, the purpose of this experiment is to validate AsymEP and compare it against EP and VF rather than outperform the state of the art.

Figure 1 compares the three algorithms as a function of hidden-layer dimension after 1 and 20 epochs. AsymEP consistently outperforms the baselines, suggesting it learns faster and better.

Figure 2 studies the evolution of the asymmetry ratio $r_{\mathrm{str}}$. The results are reported for 50 hidden neurons. As expected, EP preserves the initial weight symmetry. In contrast, VF and AsymEP induce non-trivial evolution of $r_{\mathrm{str}}$ following two distinct patterns, resulting in three distinct network configurations. A complementary figure is available in Appendix G.1.

[Figure 1: (a) Results after one epoch. (b) Results after 20 epochs.]
Figure 1. Comparison of algorithm performance on MNIST using a layered architecture with one hidden layer and symmetric initialization. Squares denote AsymEP, circles EP, and triangles VF. Test accuracy (averaged over 10 runs) is shown after one epoch (Fig. 1a) and 20 epochs (Fig. 1b).

Figure 2. Evolution of the asymmetry ratio $r_{\mathrm{str}}$ (defined in Eq. (37)) during training on MNIST for AsymEP, EP and VF, initialized from a symmetric configuration. The models use 50 hidden neurons.

5.2. Fixed Asymmetry Ratio

While the previous section focused on networks compatible with all three algorithms (EP, VF, AsymEP), we now turn to architectures with strong structural asymmetry. In this regime, EP is inapplicable by construction, and, as we show, VF performs poorly, contrary to AsymEP which remains effective.

To this end, we consider a class of networks where the asymmetry ratio $r_{\mathrm{str}}$ defined in Eq. (37) is kept fixed. Let $\tilde{S}$ and $\tilde{A}$ be arbitrary symmetric and antisymmetric matrices in $\mathbb{R}^{N_{\mathrm{dyn}}\times N_{\mathrm{dyn}}}$ respectively. We enforce a fixed $r_{\mathrm{str}}$ via the following parameterization of the recurrent parameters:

$$J^{\mathrm{dyn}} = \gamma\left[\sqrt{1 - r_{\mathrm{str}}^2}\,\frac{\tilde{S}}{\|\tilde{S}\|_F} + r_{\mathrm{str}}\,\frac{\tilde{A}}{\|\tilde{A}\|_F}\right], \tag{38}$$

where $\gamma \in \mathbb{R}$ is a learnable global scale.

Using VF and AsymEP, we train a layered network with one hidden layer of 50 neurons (in which case $\tilde{S}$ and $\tilde{A}$ are block off-diagonal) for different values of $r_{\mathrm{str}}$ to investigate the impact of structural asymmetry. We compare two training regimes: training only the input weights $J^{\mathrm{in}}$ (and the scale $\gamma$), versus training all parameters including $J^{\mathrm{dyn}}$. The first regime trains only the external forces from the input $\rho'(x) \odot J^{\mathrm{in}} u$ (which correspond to a symmetric contribution in the Jacobian) applied to our non-conservative system, while the second additionally trains $J^{\mathrm{dyn}}$ and therefore the non-symmetric part of the Jacobian directly.

Figure 3 summarizes the results. We find that AsymEP maintains robust performance across all asymmetry levels (e.g., achieving an accuracy of 93.8 ± 0.4% at $r_{\mathrm{str}} = 0$ and 94.9 ± 0.2% at $r_{\mathrm{str}} = 0.875$ when training all parameters) and can even learn when the recurrent connection matrix $J^{\mathrm{dyn}}$ is completely antisymmetric ($r_{\mathrm{str}} = 1$). Additionally, training all parameters shows significant improvement over training only $J^{\mathrm{in}}$.

In contrast, VF performs well at low asymmetry ratios but degrades as asymmetry increases, eventually dropping to chance levels (e.g., accuracies of 5 ± 3% and 8 ± 4% at $r_{\mathrm{str}} = 1$ for input-only and all-parameter training, respectively). When only $J^{\mathrm{in}}$ is trained, VF accuracy collapses around $r_{\mathrm{str}} \approx 0.5$, whereas training all parameters delays this collapse until $r_{\mathrm{str}} \approx 0.8$. Our analysis in Appendix G.2.1 reveals that VF adjusts the dynamics such that the asymmetry of the Jacobian's off-diagonal terms remains strictly lower than the structural asymmetry ratio. The training appears to adjust the neuronal state such that neurons connected by strongly asymmetric weights have low activation. As shown in Appendix G.2.1, AsymEP learns faster than VF across all levels of asymmetry.

Finally, Appendix G.3 opens with a brief theoretical discussion of the stability of these non-conservative dynamics, followed by simulations on all-to-all topologies with constrained $r_{\mathrm{str}}$ and input projections $J^{\mathrm{in}}$. Even in this worst-case setting, AsymEP reduces oscillations and improves stability.

Figure 3. Impact of the structural asymmetry ratio $r_{\mathrm{str}}$ on accuracy (top) and standard deviation over 10 runs (bottom) on MNIST. We compare VF (orange) and AsymEP (blue) under two training regimes: training only $J^{\mathrm{in}}$ (dashed) or all parameters (solid).

5.3. Feedforward Architectures

We now consider a purely feedforward architecture. Here VF trains only the last layer: with no backward connections, the output nudging signal cannot reach earlier layers, so for every layer but the last the nudged stationary states coincide with the free states, giving zero weight updates. As only the output layer is trained, the system essentially becomes an Extreme Learning Machine (Huang et al., 2006; Wang et al., 2022). In contrast, AsymEP introduces a correction that generates effective backward connections, allowing the nudging signal to influence all layers. We make this explicit for a network with one hidden layer.

Let the state $x$ be partitioned in hidden $h$ and output $o$ layers. The recurrent connection matrix is then $J^{\mathrm{dyn}} = \begin{pmatrix} 0 & 0 \\ W_{h\to o} & 0 \end{pmatrix}$. The forces of the system are:

$$
\begin{aligned}
F_h^{\beta} &= \rho'(h) \odot \left(J^{\mathrm{in}} u + \lambda\,(W_{h\to o})^\top (o - o^0)\right) - h,\\
F_o^{\beta} &= \rho'(o) \odot \left(W_{h\to o}\,\rho(h) - \lambda\,W_{h\to o}\,(h - h^0)\right) - o - \lambda\beta\frac{\partial C}{\partial o},
\end{aligned} \tag{39}
$$

where $\lambda$ is 0 during the free inference and 1 during the nudged phase (Eq. 20). The force on the hidden layer $F_h^{\beta}$ now depends on the output layer through the term $\rho'(h) \odot (W_{h\to o})^\top (o - o^0)$, enabling the nudge (the term $\beta\frac{\partial C}{\partial o}$) to influence the hidden layer. This implicitly assumes that the hardware implementation supports the physical activation of these backward connections.

We validate this using a single hidden layer of only 20 neurons on MNIST. After training, VF saturates with 64.3 ± 2.0% accuracy, whereas AsymEP reaches 92.7 ± 0.5% accuracy. We expect this discrepancy to increase with network depth, since this increases the number of layers unable to learn under VF. A figure with the accuracy during training can be found in Appendix G.4.2.

5.4. Advantages of Non-Conservative Dynamics

AsymEP is not tied to a specific neural dynamics. To further assess the benefits of training non-conservative dynamics using AsymEP, we compare several dynamics and connectivity structures inspired by (Millidge et al., 2023), while keeping the number of trainable parameters fixed.

Experiments are conducted on Fashion-MNIST using a two-hidden-layer network with hidden dimensions 500 and 200. Network states are denoted $(x_0, x_1, x_2, x_3)$, where $x_0$ is the input and $x_3 = x_L$ the output. Forward and backward connections are denoted by $W_k$ and $B_k$, respectively, with $W_1 = J^{\mathrm{in}}$.

We consider three classes of dynamics. First, the Continuous Hopfield (CH) dynamics introduced previously:

$$\frac{dx_k}{dt} = -x_k + \rho'(x_k) \odot \left(W_k\,\rho(x_{k-1}) + (1 - \delta_{k,L})\,B_k\,\rho(x_{k+1})\right). \tag{40}$$

Second, Predictive Coding (PC) dynamics, defined through the prediction errors $e_k = x_k - W_k\,\rho(x_{k-1})$, whose fixed point $e_k = 0$ corresponds to a standard feedforward network:

$$\frac{dx_k}{dt} = -e_k + (1 - \delta_{k,L})\left(\rho'(x_k) \odot (B_k\,e_{k+1})\right). \tag{41}$$

Third, a standard dynamics chosen for direct comparison with backpropagation:

$$\frac{dx_k}{dt} = -x_k + W_k\,\rho(x_{k-1}) + (1 - \delta_{k,L})\,B_k\,\rho(x_{k+1}). \tag{42}$$

For each dynamics, we examine three connectivity scenarios.

• In the asymmetric case ($B_k \neq W_{k+1}^\top$), the backward weights $B_k$ are randomly initialized and kept fixed while only the forward weights are trained, ensuring a fair comparison (i.e., identical number of parameters); in PC, the learning rule for $B_k$ is zero when only inputs are clamped.

• In the symmetric / conservative case ($B_k = W_{k+1}^\top$), the CH and PC dynamics derive from an energy functional, while the standard dynamics remains non-conservative due to its non-symmetric Jacobian.

• In the feedforward case ($B_k = 0$), the PC and standard dynamics coincide; for the standard dynamics, the AsymEP learning rule mirrors backpropagation, with $\Delta x_k^{\beta} = \frac{1}{2\beta}\left(x_k^{\beta} - x_k^{-\beta}\right)$ acting as the propagated error signal.

Table 1 shows that AsymEP consistently outperforms VF in both asymmetric and feedforward settings, in final accuracy, learning speed, and stability. After a single epoch it already provides on average a 15% accuracy gain with an order-of-magnitude reduction in variance. Remarkably, AsymEP with asymmetric connectivity also surpasses EP on symmetric networks despite training only the forward weights, suggesting that relaxing symmetry constraints may improve expressivity. Supplementary results are provided in Appendix G.5.

Table 1. Test accuracy on Fashion-MNIST (%) at Epoch 50 (mean ± std over 10 runs). BP on a standard feedforward architecture using MSE and SGD achieves 87.37 ± 0.29%.

Dynamics   Connectivity   EP              AsymEP          VF
CH         Asym           -               86.78 ± 0.14    85.20 ± 0.12
CH         Feedfor        -               86.05 ± 0.12    77.76 ± 0.37
CH         Sym            84.30 ± 0.13    -               -
PC         Asym           -               86.20 ± 0.17    80.71 ± 6.17
PC         Sym            84.78 ± 0.14    -               -
Standard   Asym           -               82.91 ± 0.48    75.52 ± 1.69
Standard   Feedfor        -               86.25 ± 0.16    78.58 ± 0.28

Finally, to investigate how AsymEP scales with depth, we trained deeper fully connected networks with two and three hidden layers of 500 neurons on Fashion-MNIST, reaching 86.41 ± 0.22% and 87.8 ± 0.15% test accuracy respectively.

5.5. Feedforward Training on CIFAR-10: BP vs. Dyadic EP vs. AsymEP

To test whether our framework scales beyond shallow networks, we conclude with a deep, purely feedforward CNN architecture trained on CIFAR-10. We compare backpropagation (BP), VF, AsymEP and Dyadic EP in a controlled setting where the gradient estimator is the only difference between runs: all methods share the same configuration, with the BP gradient replaced by the contrast of stationary states for the EP-based methods (see App. G.6 for details). Each configuration is trained for 40 epochs over 5 seeds.

Table 2 reports the final test accuracy. Both of our algorithms scale to this regime, closely tracking the BP baseline throughout training and matching its final accuracy: a paired t-test finds no significant difference between Dyadic EP and BP (p = 0.75), and only a sub-percent gap for AsymEP. In contrast, VF makes slight initial progress (peaking near 30%) before collapsing to chance level (10%). Additional details can be found in Appendix G.6.

Table 2. Test accuracy on CIFAR-10 (%) at epoch 40 (mean ± std over 5 seeds).

Method            Test Acc. (%)
Backpropagation   90.66 ± 0.25
Dyadic EP         90.69 ± 0.14
AsymEP            89.74 ± 0.14
VF                10.00 ± 0.00

6. Discussion and Conclusion

In this work, we extended Equilibrium Propagation (EP) to non-conservative systems that reach stationary states by deriving two mathematically equivalent algorithms that recover the exact gradient of the cost function in the limit of infinitesimal nudging.

The first approach, Asymmetric EP, preserves the original inference dynamics. It introduces a corrective force during the nudged phase that remains spatially local, as the antisymmetric Jacobian is null for unconnected neurons and the perturbation from equilibrium is available at the synapse level. Unlike standard methods like Recurrent Backpropagation (Almeida, 1990; Pineda, 1987), this avoids explicit digital weight transposition. However, a physical mechanism to obtain the local corrective force at the synapse level remains a subject for future work. We also note that AsymEP shares the temporal non-locality of standard EP.

The second approach, Dyadic EP, doubles the state space to map non-reciprocal dynamics onto an energy landscape, conceptually reminiscent of multi-compartment cortical neurons, where apical dendrites integrate feedback (analogous to $z - z'$) separately from basal feedforward input (analogous to $z + z'$) (Guerguiev et al., 2017). Additionally, this expanded space also enables the positive and negative nudging phases to run in parallel. This offers a pathway to implement a version of EP that is local in time, but would require a doubling of the degrees of freedom on the physical hardware. More fundamentally, the energy defined on the extended state shows that the tools and theoretical guarantees obtained for EP should also apply to the case of non-reciprocal forces, and that the variational principle behind EP is universal in the sense that it can be applied to all networks which operate in a stationary state.

Furthermore, Dyadic EP is not restricted to the EP community and could suggest a more physically plausible alternative to the stationary-state Adjoint Method (for fixed inputs) (Chen et al., 2018): by solving the forward and adjoint equations simultaneously via relaxation, it circumvents a separate backward-in-time pass.

Finally, our experiments on MNIST, Fashion-MNIST, and CIFAR-10 confirm that AsymEP and Dyadic EP consistently outperform EP and VF, and notably enable effective training of feedforward networks.

Our work thus opens new avenues for learning in neuromorphic hardware, dissipative physical systems, and neural architectures where asymmetry is intrinsic rather than incidental.

Impact Statement

This paper presents results that advance the field of machine learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.

Acknowledgments

AES is fully funded by the Horizon Europe Marie Skłodowska-Curie Doctoral Network 'Postdigital Plus' (Grant 101169118). DVA acknowledges the support of the French Community of Belgium through a FRIA fellowship. SM acknowledges financial support by the Fonds de la Recherche Scientifique–FNRS, Belgium under EOS Project No. 40007536. Computational resources have been provided by the Consortium des Équipements de Calcul Intensif (CÉCI), funded by the Fonds de la Recherche Scientifique de Belgique (F.R.S.-FNRS) under Grant No. 2.5020.11 and by the Walloon Region.

“ἁρμονίη ἀφανὴς φανερῆς κρείττων”

References

Almeida, L. B. A learning rule for asynchronous perceptrons with feedback in a combinatorial environment. In Artificial neural networks: concept learning, pp. 102–111. 1990.

Altman, L. E., Stern, M., Liu, A. J., and Durian, D. J. Experimental demonstration of coupled learning in elastic networks. Physical Review Applied, 22(2):024053, 2024.

Aykroyd, C., Bourgoin, A., and Poncin-Lafitte, C. L. Hamiltonian treatment of non-conservative systems. arXiv preprint arXiv:2507.18658, 2025.

Bateman, H. On dissipative systems and related variational principles. Physical Review, 38(4):815, 1931.

Berneman, M. and Hexner, D. Equilibrium propagation for dissipative dynamics. Advanced Intelligent Systems, pp. e202501310, 2025.

Bishop, K. J., Biswal, S. L., and Bharti, B. Active colloids as models, materials, and machines. Annual Review of Chemical and Biomolecular Engineering, 14(1):1–30, 2023.

Bowick, M. J., Fakhri, N., Marchetti, M. C., and Ramaswamy, S. Symmetry, thermodynamics, and topology in active matter. Physical Review X, 12(1):010501, 2022.

Brandenbourger, M., Locsin, X., Lerner, E., and Coulais, C. Non-reciprocal robotic metamaterials. Nature communications, 10(1):4608, 2019.

Cesa-Bianchi, N. and Lugosi, G. Prediction, learning, and games. Cambridge university press, 2006.

Chen, R. T., Rubanova, Y., Bettencourt, J., and Duvenaud, D. K. Neural ordinary differential equations. Advances in neural information processing systems, 31, 2018.

Cin, N. D., Marquardt, F., and Wanjura, C. C. Training nonlinear optical neural networks with scattering backpropagation. arXiv preprint arXiv:2508.11750, 2025.

Costa, P. and Santos, P. A. Directed equilibrium propagation revisited. Mathematics, 13(11), 2025. ISSN 2227-7390.

Crick, F. The recent excitement about neural networks. Nature, 337, 1989.

Dillavou, S., Stern, M., Liu, A. J., and Durian, D. J. Demonstration of decentralized physics-driven learning. Physical Review Applied, 18(1):014040, 2022.

Dillavou, S., Beyer, B. D., Stern, M., Liu, A. J., Miskin, M. Z., and Durian, D. J. Machine learning without a processor: Emergent learning in a nonlinear analog network. Proceedings of the National Academy of Sciences, 121(28):e2319718121, 2024.

Ernoult, M., Grollier, J., Querlioz, D., Bengio, Y., and Scellier, B. Updates of equilibrium prop match gradients of backprop through time in an rnn with static input. Advances in neural information processing systems, 32, 2019.

Ernoult, M., Grollier, J., Querlioz, D., Bengio, Y., and Scellier, B. Equilibrium propagation with continual weight updates. arXiv preprint arXiv:2005.04168, 2020.

Falk, M. J., Strupp, A. T., Scellier, B., and Murugan, A. Temporal contrastive learning through implicit non-equilibrium memory. Nature Communications, (16), 2025.

Farinha, M. T., Pequito, S., Santos, P. A., and Figueiredo, M. A. T. Equilibrium propagation for complete directed neural networks. In Proceedings of the 28th European Symposium on Artificial Neural Networks, Computational Intelligence and Machine Learning (ESANN 2020), 2020.

Galley, C. R. Classical mechanics of nonconservative systems. Physical review letters, 110(17):174301, 2013.

Guerguiev, J., Lillicrap, T. P., and Richards, B. A. Towards deep learning with segregated dendrites. elife, 6:e22901, 2017.

Høier, R. and Zach, C. A lagrangian perspective on dual propagation. In Proceedings of the First Workshop on Machine Learning with New Compute Paradigms at NeurIPS 2023, New Orleans, LA, USA, Dec 2023.

Høier, R. and Zach, C. Two tales of single-phase contrastive hebbian learning. In Salakhutdinov, R., Kolter, Z., Heller, K., Weller, A., Oliver, N., Scarlett, J., and Berkenkamp, F. (eds.), Proceedings of the 41st International Conference on Machine Learning, volume 235 of Proceedings of Machine Learning Research, pp. 18470–18488. PMLR, 21–27 Jul 2024.

Høier, R., Staudt, D., and Zach, C. Dual propagation: accelerating contrastive hebbian learning with dyadic neurons. In Proceedings of the 40th International Conference on Machine Learning, ICML'23. JMLR.org, 2023.

Høier, R., Kalinin, K., Ernoult, M., and Zach, C. Dyadic learning in recurrent and feedforward models. In NeurIPS 2024 Workshop Machine Learning with new Compute Paradigms, 2024.

Hopfield, J. J. Neural networks and physical systems with emergent collective computational abilities. Proceedings of the national academy of sciences, 79(8):2554–2558, 1982.

Huang, G.-B., Zhu, Q.-Y., and Siew, C.-K. Extreme learning machine: theory and applications. Neurocomputing, 70(1-3):489–501, 2006.

Indiveri, G. and Liu, S.-C. Memory and information processing in neuromorphic systems. Proceedings of the IEEE, 103(8):1379–1397, 2015.

Kalinin, K. P., Gladrow, J., Chu, J., Clegg, J. H., Cletheroe, D., Kelly, D. J., Rahmani, B., Brennan, G., Canakci, B., Falck, F., et al. Analog optical computer for ai inference and combinatorial optimization. Nature, 645(8080):354–361, 2025.

Kendall, J., Pantone, R., Manickavasagam, K., Bengio, Y., and Scellier, B. Training end-to-end analog neural networks with equilibrium propagation. arXiv preprint arXiv:2006.01981, 2020.

Laborieux, A. and Zenke, F. Holomorphic equilibrium propagation computes exact gradients through finite size oscillations. Advances in Neural Information Processing Systems, 35:12950–12963, 2022.

Laborieux, A. and Zenke, F. Improving equilibrium propagation without weight symmetry through jacobian homeostasis. In Proceedings of the International Conference on Learning Representations (ICLR) 2024, Virtual (ICLR), May 2024.

Laborieux, A., Ernoult, M., Scellier, B., Bengio, Y., Grollier, J., and Querlioz, D. Scaling equilibrium propagation to deep convnets by drastically reducing its gradient estimator bias. Frontiers in neuroscience, 15:633674, 2021.

Laydevant, J., Marković, D., and Grollier, J. Training an ising machine with equilibrium propagation. Nature Communications, 15(1):3671, 2024.

LeCun, Y. The mnist database of handwritten digits. http://yann.lecun.com/exdb/mnist/, 1998.

Martin, E., Ernoult, M., Laydevant, J., Li, S., Querlioz, D., Petrisor, T., and Grollier, J. Eqspike: spike-driven equilibrium propagation for neuromorphic implementations. Iscience, 24(3), 2021.

Massar, S. Equilibrium propagation for learning in lagrangian dynamical systems. Physical Review E, 112(3):035304, 2025.

Massar, S. and Mognetti, B. M. Equilibrium propagation: the quantum and the thermal cases. Quantum Studies: Mathematics and Foundations, 12(1):6, 2025.

Millidge, B., Song, Y., Salvatori, T., Lukasiewicz, T., and Bogacz, R. Backpropagation at the infinitesimal inference limit of energy-based models: Unifying predictive coding, equilibrium propagation, and contrastive hebbian learning. In International Conference on Learning Representations (ICLR), 2023.

Nest, T. and Høier, R. Dyadic learning in asymmetric convnets. In New Frontiers in Associative Memories-Workshop at ICLR 2026.

Osat, S. and Golestanian, R. Non-reciprocal multifarious self-organization. Nature Nanotechnology, 18(1):79–85, 2023.

O'Connor, P., Gavves, E., and Welling, M. Training a spiking neural network with equilibrium propagation. In The 22nd international conference on artificial intelligence and statistics, pp. 1516–1523. PMLR, 2019.

Pineda, F. Generalization of back propagation to recurrent and higher order neural networks. In Neural information processing systems, 1987.

Pourcel, G., Basu, D., Ernoult, M., and Gilra, A. Lagrangian-based equilibrium propagation: generalisation to arbitrary boundary conditions & equivalence with hamiltonian echo learning. arXiv preprint arXiv:2506.06248, 2025.

Rageau, T. and Grollier, J. Training and synchronizing oscillator networks with equilibrium propagation. Neuromorphic Computing and Engineering, 2025.

Sajnok, K. and Matuszewski, M. Near-equilibrium propagation training in nonlinear wave systems. arXiv preprint arXiv:2510.16084, 2025.

Scellier, B. Quantum equilibrium propagation: Gradient-descent training of quantum systems. arXiv preprint arXiv:2406.00879, 2024.

Scellier, B. and Bengio, Y. Equilibrium propagation: Bridging the gap between energy-based models and backpropagation. Frontiers in computational neuroscience, 11:24, 2017.

Scellier, B., Goyal, A., Binas, J., Mesnard, T., and Bengio, Y. Generalization of equilibrium propagation to vector field dynamics. arXiv preprint arXiv:1808.04873, 2018.

Scellier, B., Mishra, S., Bengio, Y., and Ollivier, Y. Agnostic physics-driven deep learning. arXiv:2205.15021v1, 2022.

Scurria, A. E. A physical theory of backpropagation: Exact gradients from the least-action principle. 2026.

Stern, M., Hexner, D., Rocks, J. W., and Liu, A. J. Supervised learning in physical networks: From machine learning to learning machines. Physical Review X, 11(2):021045, 2021.

Wang, J., Lu, S., Wang, S.-H., and Zhang, Y.-D. A review on extreme learning machine. Multimedia Tools and Applications, 81(29):41611–41660, 2022.

Wang, Q., Wanjura, C. C., and Marquardt, F. Training coupled phase oscillators as a neuromorphic platform using equilibrium propagation. Neuromorphic Computing and Engineering, 4(3):034014, 2024.

Wanjura, C. C. and Marquardt, F. Quantum equilibrium propagation for efficient training of quantum systems based on onsager reciprocity. Nature Communications, 16(1):6595, 2025.

Werbos, P. J. Backpropagation through time: what it does and how to do it. Proceedings of the IEEE, 78(10):1550–1560, 1990.

Yi, S.-i., Kendall, J. D., Williams, R. S., and Kumar, S. Activity-difference training of deep neural networks using memristor crossbars. Nature Electronics, 6(1):45–51, 2023.

A. Gradient Estimation Error in VF

In this appendix, we quantify the gradient estimation error introduced by VF in the limit where the Jacobian asymmetry is small.

Comparing the post-synaptic update terms in Eqs. (12) and (14) gives the following error in the gradient of the cost:

$$\mathrm{Error} = -\left(\frac{\partial F}{\partial\theta}(x^0,\theta)\right)^{\!\top}\left[\left(J_F(x^0,\theta)\right)^{-1} - \left(J_F^\top(x^0,\theta)\right)^{-1}\right]\frac{\partial C}{\partial x}(x^0, y). \tag{43}$$

To quantify this error, we decompose the Jacobian $J_F(x,\theta)$ into its symmetric part $S_J(x,\theta)$ and antisymmetric part $A_J(x,\theta)$:

$$
\begin{aligned}
S_J(x,\theta) &= \tfrac{1}{2}\left(J_F(x,\theta) + J_F^\top(x,\theta)\right),\\
A_J(x,\theta) &= \tfrac{1}{2}\left(J_F(x,\theta) - J_F^\top(x,\theta)\right).
\end{aligned} \tag{44}
$$

Assuming the asymmetry $A_J(x,\theta)$ is small, we can make a series expansion in $S_J^{-1}A_J$ (omitting the dependencies for clarity). Applying the Neumann expansion for small $\|S_J^{-1}A_J\|$ gives

$$(J_F)^{-1} = \left(\sum_{n=0}^{\infty}(-1)^n\left(S_J^{-1}A_J\right)^n\right)S_J^{-1}, \tag{45}$$

$$(J_F^\top)^{-1} = \left(\sum_{n=0}^{\infty}\left(S_J^{-1}A_J\right)^n\right)S_J^{-1}. \tag{46}$$

Subtracting the two series and assuming convergence, we finally obtain

$$(J_F)^{-1} - (J_F^\top)^{-1} = -2\left(\sum_{n=0}^{\infty}\left(S_J^{-1}A_J\right)^{2n+1}\right)S_J^{-1}. \tag{47}$$

B. Equivalence between AsymEP and BPTT

In this appendix, we sketch the equivalence between the gradient estimate computed by AsymEP and Backpropagation Through Time (BPTT) (Werbos, 1990) for a Recurrent Neural Network with fixed inputs. Our derivation relies on the proof provided by Ernoult et al. (2019), which established that standard (conservative) EP computes gradients identical to those of BPTT. To facilitate direct comparison, we adopt their notation for this section.

The proof provided by Ernoult et al. (2019) relies on the assumption that the vector field $F$ (i.e., transition function) is derived from a scalar potential function, which implies that

$$\left(\frac{\partial F}{\partial s}\right)^{\!\top} = \frac{\partial F}{\partial s}, \tag{48}$$

where $s$ denotes the dynamical state of the system. This symmetry is the linchpin of the equivalence proof, as the gradient expressions derived for BPTT and standard EP differ precisely by a transpose operation applied to $\frac{\partial F}{\partial s}$.

This observation aligns with our analysis in the main text: VF fails in non-conservative systems due to the missing transpose in the post-synaptic term (see Eq. (16)). Following the derivation in Ernoult et al. (2019) (viz., Appendix A, Eqs. (31–33)), the recursive relations for the gradients in BPTT are given by:

$$\nabla_s^{\mathrm{BPTT}}(0) = \frac{\partial\ell}{\partial s}(s_\star, y), \tag{49}$$

and for all $t = 1, \dots, K$,

$$\nabla_s^{\mathrm{BPTT}}(t) = \left(\frac{\partial F}{\partial s}(x, s_\star, \theta)\right)^{\!\top}\nabla_s^{\mathrm{BPTT}}(t-1), \tag{50}$$

$$\nabla_\theta^{\mathrm{BPTT}}(t) = \left(\frac{\partial F}{\partial\theta}(x, s_\star, \theta)\right)^{\!\top}\nabla_s^{\mathrm{BPTT}}(t-1), \tag{51}$$

where $\theta$ represents the optimization parameters, $\ell$ is the cost function, $s_\star$ is the free equilibrium state (satisfying $F(s_\star) = 0$), $y$ is the target, and $x$ is the input. The index $t$ denotes the unrolled time steps, initialized at $s(0) = s_\star$.

In contrast, the gradients computed by VF follow the recursion (viz., Ernoult et al. (2019), Appendix A, Eqs. (24–26)):

$$\Delta_s^{\mathrm{EP}}(0) = -\frac{\partial\ell}{\partial s}(s_\star, y), \tag{52}$$

and for all $t \geq 0$,

$$\Delta_s^{\mathrm{EP}}(t+1) = \frac{\partial F}{\partial s}(x, s_\star, \theta)\,\Delta_s^{\mathrm{EP}}(t), \tag{53}$$

$$\Delta_\theta^{\mathrm{EP}}(t+1) = \left(\frac{\partial F}{\partial\theta}(x, s_\star, \theta)\right)^{\!\top}\Delta_s^{\mathrm{EP}}(t). \tag{54}$$

Comparing these two sets of equations confirms that the only difference lies between Eqs. (50) and (53), specifically in the transpose of the Jacobian $\frac{\partial F}{\partial s}$ (ignoring the global sign difference in Eqs. (49) and (52)).

In AsymEP, we modify the dynamics by adding a correction term dependent on the antisymmetric part of the Jacobian. Denoting the force of this augmented system by $F^A$, the Jacobian at the free equilibrium satisfies:

$$\frac{\partial F^A}{\partial s}(x, s_\star, \theta) = \left(\frac{\partial F}{\partial s}(x, s_\star, \theta)\right)^{\!\top}. \tag{55}$$

By substituting this corrected Jacobian into the recursive relations, AsymEP recovers the exact transpose required by BPTT. Consequently, our method extends the equivalence between EP and BPTT to the general case of non-conservative forces.

C. Out-of-Equilibrium Mechanics

Here we sketch the physical picture behind the doubled-energy construction of Eq. (26). The full derivation from Hamilton's least-action principle, together with its connection to the Bateman–Galley formalism for non-conservative classical mechanics (Bateman, 1931; Galley, 2013; Aykroyd et al., 2025), can be found in (Scurria, 2026).

C.1. The Helmholtz Obstruction

The natural physical route to a variational principle for a dynamical system $\dot{x} = F(x,\theta)$ is to seek a scalar potential $E$ such that $F = -\partial_x E$. The classical Helmholtz integrability condition states that such an $E$ exists if and only if the Jacobian $J_F$ is symmetric everywhere.

C.3. Symmetry Breaking as Credit Assignment

On the diagonal manifold $z = z'$ the doubled system enjoys a gauge symmetry: the auxiliary variable $z'$ is redundant and the difference $d$ is identically zero. Credit assignment is implemented by deliberately breaking this symmetry through the task cost. Adding $\beta D(z,z') = \beta\,C(m)$ to $H$ exerts opposite forces on $z$ and $z'$ and drives them apart, so that the difference $d$ ceases to be redundant and begins to carry information about the loss landscape.

D. Proofs for Dyadic EP

We now demonstrate that Dyadic EP correctly trains the parameters $\theta$ of the original force field $F(x,\theta)$, giving the
Whenever the inter- 0 + exact gradient dC(x̄ + dθ + ) + in the limit of infinitesimal nudging. +actions are non-reciprocal — as in feedforward networks, +active matter, or driven optical systems — JF acquires +a non-zero antisymmetric part and the Helmholtz condi- D.1. Proof of EP +tion fails identically. No scalar potential on the original First, recall that standard EP does not strictly require the +n-dimensional state space can then generate the dynamics, system to settle at an energy minimum; it requires only that +and the “energy minimisation” route at the heart of standard the system reaches a stationary state (a fixed point of the +EP is blocked at the structural level. The obstruction is not dynamics). Indeed, using the notation of Section 2.1, EP +a matter of computational convenience: it reflects the fact relies on the key identity: +that the rotational component of F carries information that +no scalar function of x alone can record. d2 d2 + ET (xβ , θ) = ET (xβ , θ). (57) + dθdβ dβdθ +C.2. Variational Reconstruction on a Doubled Space Expanding the total derivative with respect to β gives: +Applying the Bateman–Galley formalism circumvents this ⊤ + ∂ET (xβ , θ) dxβ ∂ET (xβ , θ) + +obstruction by enlarging the configuration space. The single d + ET (xβ , θ) = + +state x ∈ Rn is replaced by a conjugate pair (z, z ′ ) ∈ R2n , dβ ∂x dβ ∂β +and the rotational component of F — which has no scalar = C(xβ ). (58) +generator on the original n-dimensional space — is ab- +sorbed into a bilinear coupling between z and z ′ on the Where the first term vanishes because the system is at a + ∂ +doubled space, where it does admit a variational descrip- stationary state, i.e., ∂x ET (xβ , θ) = 0; this holds even if +tion. The physical motion is recovered on the diagonal the system is not at a minimum of ET . 
Similarly, for the +submanifold z = z ′ (the so called ’physical limit’), while derivative with respect to θ: +the off-diagonal direction d = z − z ′ supplies the additional + d ∂ET (xβ , θ) +degree of freedom needed to encode non-reciprocity. ET (xβ , θ) = , (59) + dθ ∂θ +Specializing this reconstruction to the overdamped (first- + where we additionally assume that the cost function does +order) regime relevant to relaxational neural dynamics yields + not depend explicitly on the parameters θ. Substituting these +the bilinear energy + results into Eq. (57) in the limit of infinitesimal nudging + (β → 0) recovers the fundamental relation given by Eq. (9). + z + z′ + + H(z, z ′ , θ) = −(z − z ′ )⊤ F ,θ , (56) + 2 + D.2. Proof of Dyadic EP +which is precisely Eq. (26). The symmetric midpoint m = We analyze now the stationary states of Dyadic EP by intro- +(z + z ′ )/2 plays the role of the physical coordinate of the ducing the change of variables: +doubled system, while d is the auxiliary direction along +which non-reciprocity is stored. On the submanifold z = z ′ z + z′ + m= , d = z − z′. (60) +the coupling proportional to (z − z ′ ) vanishes identically 2 +and both states evolve under the original field F , so the In these coordinates, the augmented energy HT becomes +doubling leaves the on-shell physics unchanged. We refer +the reader to (Scurria, 2026) for the full construction. HT (m, d, θ, β) = −d⊤ F (m, θ) + βC(m) (61) + + 14 + Equilibrium Propagation for Non-Conservative Systems + +and the dynamics in Eq. (28) can be rewritten as: In Dyadic EP, we instead employ the single-phase update: + + 1 ∂H(z β , z ′β , θ) + + dm ∂HT + =− = F (m, θ), (62) ∆θ ∝ − (70) + dt ∂d β ∂θ + dd ∂HT ∂ + =− = dT JF (m, θ) − β C(m). 
(63) This choice avoids the overhead of evolving two coupled + dt ∂m ∂m + equations in the extended space, which would be computa- + β +The stationary states (mβ , d ) are the solutions to: tionally equivalent to evolving four equations in the original + space (two for +β and two for −β). Using Eq. (70), we + F (mβ , θ) = 0, (64) evolve only one coupled equation for +β in the extended + space; this corresponds to two equations in the original + βT ∂ + d JF (mβ , θ) − β C(mβ ) = 0. (65) space, thereby achieving the same computational complex- + ∂m ity as AsymEP. Furthermore, this single-phase formulation + suggests a pathway toward making the update local in time, +This leads to the following observations: provided appropriate hardware is used to implement the +1) The stationary state of m is independent of β and coin- augmented phase. +cides with the stationary state of the original system: Mathematically, these two approaches yield the same gradi- + ent estimate because the equations for dβ are linear. Explic- + z β + z ′β + = mβ = m0 = x0 . (66) itly we have : + 2 + ∂H(z β , z ′β , θ) ∂F z β + z ′β + + = −(z β − z ′β )⊤ ,θ +2) The Jacobian of the extended system defined in Eq. (26) ∂θ ∂θ 2 +is invertible, provided JF is invertible. This is most evident ⊤ + ∂F 0 −1 +from Eq. (63). = −β z ,θ JF⊤ (z 0 , θ) + ∂θ +3) The stationary state value of d is given by: + ∂C 0 + + × (z ) , (71) + β + + −1 ∂C 0 + ∂x + d = β JF⊤ (m0 , θ) (x ) (67) + ∂x where we have used Eqs. (66) and (67). Inspection of + Eq. (71) confirms that, up to corrections of order β 2 , we + 0 +In particular, when β = 0, we have d = 0, which implies obtain exactly the same gradient as in AsymEP. +that the free stationary states coincide: z 0 = z ′0 . +4) The cost at the stationary state of the extended system E. 
AsymEP versus Dyadic EP +is equal to the cost at the stationary state of the original + In this appendix, we demonstrate that Asymmetric Equilib- +system: + rium Propagation (AsymEP) emerges naturally as the first- + D(m0 ) = C(x0 ). (68) + order projection of the 2N -dimensional Dyadic Equilibrium +Consequently, the gradients of the cost with respect to the Propagation onto a single N -dimensional state space. We +parameters are identical. then formalize the physical trade-offs between the two ar- + chitectures. +Since both the original and extended systems, given respec- +tively in Eq. (28) and Eq. (1-2), share the same cost at their + E.1. AsymEP as the Linear Projection of Dyadic EP +respective stationary states, and because the Jacobians of +both models are invertible, applying EP update rule to the As established in Appendix D.2, transforming the 2N - +extended system give the correct gradient estimate for the dimensional extended space (z, z ′ ) into the mean state + ′ +parameters θ of the original system. m = z+z 2 and the difference state d = z − z ′ exactly +The final step of the proof is to establish the equivalence decouples the stationary dynamics. Because the stationary +between the standard parameter update rule in Eq. (8) and state of m is the free state of the original system (mβ = x0 ), +the modified rule used by Dyadic EP in Eq. (34). Indeed, if the cost function drives the difference variable to a stationary + β +we were to apply the standard update rule in the extended state d satisfying: +space, the update would be: β ∂C 0 + JF⊤ (x0 , θ)d = β (x ) (72) + 1 + β ′β + ∂H(z , z , θ) ∂H(z , z −β ′−β + , θ) + ∂x + ∆θ ∝ − − . + 2β ∂θ ∂θ To recover this exact error signal in an N -dimensional space, + (69) we postulate a modified dynamical system FA (x) compris- + + 15 + Equilibrium Propagation for Non-Conservative Systems + +ing the standard EP dynamics and a spatial correction Γ(x): F. 
Derivation of the Hopfield-like Energy + ∂C In this section, we derive the explicit energy functional for + FA (x) = F (x) − β (x) + Γ(x) (73) the Continuous Asymmetric Hopfield dynamics defined in + ∂x + Eq. (35). The force field is given by: +Let ∆x = xβA − x0 denote the displacement from the +free equilibrium. Expanding the stationarity condition F (x) = ρ′ (x) ⊙ (Jρ(x)) − x. (78) +FA (xβA ) = 0 to first order around x0 yields: + We omit external inputs J in for brevity, as they appear sym- + ∂C 0 metrically in the Jacobian. The variational Hamiltonian is + JF (x0 , θ)∆x − β (x ) + Γ(xβA ) ≈ 0 (74) defined as: + ∂x + z + z′ z + z′ + +To ensure the first-order displacement matches the Dyadic H(z, z ′ ) = −(z − z ′ )⊤ F + βC . + β 2 2 +EP error signal (i.e., ∆x ≈ d ), we substitute Eq. (72) into + (79) +the expansion: + To analyze this expression, we introduce the midpoint m = + z+z ′ + Γ(xβA ) = JF⊤ (x0 , θ) − JF (x0 , θ) ∆x and the difference d = z − z ′ . Since the separation + + (75) 2 + between z and z ′ is induced solely by the nudging parameter + = −2AJ (x0 , θ)(xβA − x0 ) (76) + β, the difference scales as ∥d∥ ∼ O(β). We therefore +This uniquely recovers the AsymEP augmented dynamics. neglect terms of order O(∥d∥3 ) (i.e., or equivalently O(β 3 )) +Finally, to eliminate the O(β 2 ) error, AsymEP evaluates the as they do not contribute to the gradient of the cost. +centered difference of two opposite nudges: The activation at the midpoint can be approximated as: + + dxA ρ(z) + ρ(z ′ ) + x±β 0 + A =x ±β + O(β 2 ) (77) ρ(m) = + O(∥d∥2 ). (80) + dβ β=0 2 + Similarly, the difference in activations is: +Subtracting these states cancels the O(β 2 ) error, yielding +1 +β −β β 3 +2 (xA − xA ) = d + O(β ), successfully recovering the ρ(z) − ρ(z ′ ) = ρ′ (m) ⊙ d + O(∥d∥3 ). (81) +exact post-synaptic update term. + Inverting this relation, we express the state difference as: +E.2. 
Physical Trade-offs and the Extended Space + z − z ′ = (ρ(z) − ρ(z ′ )) ⊙ ρ′ (m) + O(∥d∥3 ). (82) +We can view AsymEP and Dyadic EP as a space-time trade- +off of the same underlying physical optimization problem. + We substitute these expansions into the interaction term +AsymEP preserves the original N -dimensional state space of the Hamiltonian, Hint = −(z − z ′ )⊤ (ρ′ (m) ⊙ Jρ(m)). +of the network at the cost of temporal non-locality. The sys- Applying the identity a⊤ (b ⊙ c) = (a ⊙ b)⊤ c, we obtain: +tem must evolve sequentially, requiring physical memory + ⊤ +not only to store the free equilibrium x0 for the asymmet- Hint = − ((z − z ′ ) ⊙ ρ′ (m)) Jρ(m) +ric correction, but also to store the successive stationary + ρ(z) + ρ(z ′ ) + +states required to evaluate the contrastive gradient update. ≈ −(ρ(z) − ρ(z ′ ))⊤ J . (83) + 2 +AsymEP thus serves as the direct, spatially minimal exten- +sion of EP. Expanding the product gives: +Dyadic EP provide a learning signal that is local in both + 1h +space (where z − z ′ encodes the gradient) and time (allow- Hint = − ρ(z)⊤ Jρ(z) + ρ(z)⊤ Jρ(z ′ ) + 2 +ing the nudged phases to execute in parallel) at the cost i +of doubling the state space. In particular, capturing non- − ρ(z ′ )⊤ Jρ(z) − ρ(z ′ )⊤ Jρ(z ′ ) . (84) +conservative forces in this extended space requires a spe- +cific bilinear coupling, rather than a trivial superposition We decompose the connectivity matrix J into its symmetric +of uncoupled subsystems. It can be seen as a blueprint for part S and antisymmetric part A. The first and last terms +future neuromorphic hardware. simplify to ρ(z)⊤ Sρ(z). The cross terms satisfy: +Ultimately, the reduction of Dyadic EP to AsymEP via the ρ(z)⊤ Jρ(z ′ ) − ρ(z ′ )⊤ Jρ(z) = ρ(z)⊤ (J − J ⊤ )ρ(z ′ ) +variables m and d proves the universality of EP’s variational +principle. = ρ(z)⊤ (2A)ρ(z ′ ). 
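The two facts used to pass from Eq. (84) to Eq. (85) can be checked numerically. The following is a minimal sketch in plain Python; the matrix `J` and the vectors `a`, `b` (standing in for ρ(z) and ρ(z′)) are arbitrary illustrative values, not quantities from the experiments.

```python
# Illustrative check of the J = S + A decomposition used in Eqs. (84)-(85).
# J, a and b are hypothetical example values chosen only for this sketch.

def transpose(M):
    return [list(col) for col in zip(*M)]

def bilinear(u, M, v):
    """Return u^T M v for list-based vectors and matrices."""
    return sum(u[i] * M[i][j] * v[j]
               for i in range(len(u)) for j in range(len(v)))

J = [[0.0, 1.5, -0.3],
     [0.2, 0.0, 0.8],
     [-1.1, 0.4, 0.0]]
Jt = transpose(J)
n = len(J)

# Symmetric and antisymmetric parts: S = (J + J^T)/2, A = (J - J^T)/2.
S = [[0.5 * (J[i][j] + Jt[i][j]) for j in range(n)] for i in range(n)]
A = [[0.5 * (J[i][j] - Jt[i][j]) for j in range(n)] for i in range(n)]

a = [0.3, -0.7, 0.5]   # plays the role of rho(z)
b = [0.1, 0.9, -0.4]   # plays the role of rho(z')

# Quadratic forms only see the symmetric part: a^T J a = a^T S a.
quad_J, quad_S = bilinear(a, J, a), bilinear(a, S, a)

# Cross terms only see the antisymmetric part, as in Eq. (85).
cross_lhs = bilinear(a, J, b) - bilinear(b, J, a)
cross_rhs = 2.0 * bilinear(a, A, b)

print("a^T J a vs a^T S a:", quad_J, quad_S)
print("cross terms vs 2 a^T A b:", cross_lhs, cross_rhs)
```

Both printed pairs should agree up to floating-point rounding, which is why only S survives in the quadratic terms and only A in the coupling between the two copies.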
+Thus, the interaction term reduces to:
+
+$$H_{\mathrm{int}} = -\frac{1}{2}\rho(z)^{\top} S\rho(z) + \frac{1}{2}\rho(z')^{\top} S\rho(z') - \rho(z)^{\top} A\rho(z') + O(\|d\|^3). \tag{86}$$
+
+Finally, for the nudging term, we expand the cost function around the midpoint:
+
+$$C(m) = \frac{1}{2}\left(C(z) + C(z')\right) + O(\|d\|^2). \tag{87}$$
+
+When multiplying by $\beta$, the remainder term becomes $\beta \cdot O(\|d\|^2)$. Since $\|d\| \sim O(\beta)$, this remainder is of order $O(\beta^3)$ and can be consistently discarded alongside the third-order terms from the interaction expansion.
+
+Combining all these components, the final Hamiltonian is:
+
+$$H(z, z') = -\frac{1}{2}\rho(z)^{\top} S\rho(z) + \frac{1}{2}\rho(z')^{\top} S\rho(z') - \rho(z)^{\top} A\rho(z') + \frac{1}{2}\left(\|z\|^2 - \|z'\|^2\right) + \frac{\beta}{2}\left(C(z) + C(z')\right). \tag{88}$$
+
+The saddle-point dynamics, given by Eq. (32), generated by this Hamiltonian are:
+
+$$\frac{dz}{dt} = \rho'(z) \odot \left(S\rho(z) + A\rho(z')\right) - z - \frac{\beta}{2}\frac{\partial C}{\partial z}, \tag{89}$$
+
+$$\frac{dz'}{dt} = \rho'(z') \odot \left(S\rho(z') + A\rho(z)\right) - z' + \frac{\beta}{2}\frac{\partial C}{\partial z'}. \tag{90}$$
+
+This system recovers the original continuous Hopfield dynamics when $z = z'$ (assuming $\beta = 0$).
+
+G. Experimental Details
+
+As in the main text, the neuronal dynamics are governed by the vector field:
+
+$$F_i = \rho'(x_i)\left(\sum_j J^{\mathrm{dyn}}_{ij}\rho(x_j) + b_i(u)\right) - x_i, \tag{91}$$
+
+where the input-dependent bias $b_i(u)$ is precomputed for each MNIST input $u$ as:
+
+$$b_i(u) = \sum_{l \in \mathrm{in}} J^{\mathrm{in}}_{il} u_l. \tag{92}$$
+
+This term projects the input space into the recurrent subspace. The bias yields a diagonal contribution to the Jacobian $J_F = \frac{\partial F}{\partial x}$, and therefore does not contribute to the antisymmetric correction used in the augmented dynamics Eq. (20) of AsymEP.
+
+The input parameters are then updated using the standard learning rule (21). In particular, the presynaptic term associated with the input weights is given by
+
+$$\frac{\partial F_i}{\partial J^{\mathrm{in}}_{kl}} = \delta_{ik}\rho'(x_i)u_l. \tag{93}$$
+
+The presynaptic terms associated with the dynamical parameters $J^{\mathrm{dyn}}_{ij}$ depend on the experiment.
+
+G.1. Symmetric Initialization
+
+G.1.1. Learning Rules
+
+For clarity, we write the learning rules for VF and AsymEP. For the input weights, using (93), we have:
+
+$$\Delta J^{\mathrm{in}}_{ik} \propto \frac{1}{2\beta}\left[(x_i^{+\beta} - x_i^{-\beta})\rho'(x_i^0)u_k\right], \tag{94}$$
+
+while for the recurrent weights, we get:
+
+$$\Delta J^{\mathrm{dyn}}_{ij} \propto \frac{1}{2\beta}\left[(x_i^{+\beta} - x_i^{-\beta})\rho'(x_i^0)\rho(x_j^0)\right]. \tag{95}$$
+
+For EP, we have:
+
+$$\Delta J^{\mathrm{in}}_{ik} \propto \frac{1}{2\beta}\left[\rho(x_i^{+\beta}) - \rho(x_i^{-\beta})\right]u_k, \tag{96}$$
+
+and for the recurrent weights:
+
+$$\Delta J^{\mathrm{dyn}}_{ij} \propto \frac{1}{2\beta}\left[\rho(x_i^{+\beta})\rho(x_j^{+\beta}) - \rho(x_i^{-\beta})\rho(x_j^{-\beta})\right]. \tag{97}$$
+
+G.1.2. Supplementary Numerical Results
+
+To complement Fig. 2, we report the evolution of the accuracy of the three methods in Fig. 4. We consider a layered network with 50 hidden neurons. While this capacity is insufficient for state-of-the-art performance, it amplifies the difference in accuracy between models to aid visualization. Models are trained for 20 epochs starting from a symmetric configuration, the natural setting for both VF and EP. With this initialization, AsymEP consistently outperforms the other methods and learns faster by exploiting the additional degrees of freedom of the asymmetric network.
+
+Figure 4. Evolution of the mean accuracy and standard deviation (over 10 runs) during training on MNIST for AsymEP, EP, and VF. Models use 50 hidden neurons.
+
+  Parameter                               Sym. Init. / Feedforward   Fixed r_str     Fixed r_str & r_in
+                                          (Sec. 5.1 & 5.3)           (Sec. 5.2)      (App. G.3)
+  ------------------------------------------------------------------------------------------------------
+  Learning Rate (Input-Hidden)            0.05                       0.05            0.0125
+  Learning Rate (Hidden-Output)           0.01                       0.01            0.0025
+  Time Step (Dynamics Integration)        0.5                        0.3             0.3
+  Nudging Parameter (β)                   0.5                        0.5             0.5
+  Free-phase Steps (n_free)               20                         30              40
+  Nudged-phase Steps (n_nudge)            10                         10              10
+  Number of Epochs                        40 / 20                    30              40
+  Batch Size                              64                         64              64
+  Scaling Parameter γ                     n.a.                       √60             √60
+  Structure                               784-n.a.-10                784-50-10       all-to-all, 500 hid
+  Activation function ρ                   tanh                       tanh            tanh
+  Initial Recurrent State s               s ∼ U(−1, 1)               s ∼ U(−1, 1)    s ∼ U(−1, 1)
+  Initial Parameters θ                    θ ∼ N(0, 1/N)              θ ∼ N(0, 1/N)   θ ∼ N(0, 1/N)
+  Number of Runs (training + inference)   10                         10              10
+
+Table 3. Trained Model Hyperparameters on MNIST. N is the total number of neurons, U(−1, 1) is a uniform distribution, and N(µ, σ²) is a Gaussian distribution. For the r_str parametrization, we choose more cautious hyperparameters for training and inference compared to the symmetric initialization, due to increasingly non-conservative and potentially oscillatory dynamics.
+
+G.2. Fixed Asymmetry Ratio
+
+This section details the implementation for the fixed asymmetry ratio experiments presented in Section 5.2, followed by complementary numerical results regarding learning speed and induced Jacobian asymmetry.
+
+G.2.1. Learning Rules
+
+Parametrization and notation. To enforce a fixed asymmetry ratio, we explicitly parameterize the independent elements of Eq. (38). We introduce two parameter vectors $\theta^S$ and $\theta^A$ of size $M = N_{\mathrm{dyn}}(N_{\mathrm{dyn}} - 1)/2$, which encode the off-diagonal elements of the symmetric and antisymmetric components $\tilde{S}$ and $\tilde{A}$, respectively. The correspondence between matrix and vector indices is given by:
+
+$$k(i, j) = \frac{(i-1)(i-2)}{2} + j, \qquad (1 \leq j < i \leq N_{\mathrm{dyn}}) \tag{98}$$
+
+where the condition $j < i$ selects the strictly lower triangular elements. Introducing an additional vector $\xi$ for the diagonal elements of $\tilde{S}$, the full matrices are constructed as:
+
+$$\tilde{S}_{ij} = \delta_{ij}\xi_i + (1 - \delta_{ij})\,\theta^S_{k(\max(i,j),\min(i,j))}, \tag{99}$$
+
+$$\tilde{A}_{ij} = \epsilon_{ij}\,\theta^A_{k(\max(i,j),\min(i,j))}, \tag{100}$$
+
+where $\epsilon_{ij}$ is the Levi-Civita symbol. The dynamical parameters are then given by:
+
+$$J^{\mathrm{dyn}}_{ij} = \gamma\left(c_S\tilde{S}_{ij} + c_A\tilde{A}_{ij}\right), \tag{101}$$
+
+with normalization coefficients
+
+$$c_S = \frac{\sqrt{1 - r_{\mathrm{str}}^2}}{F_S}, \qquad c_A = \frac{r_{\mathrm{str}}}{F_A}, \tag{102}$$
+
+defined in terms of the Frobenius norms:
+
+$$F_S = \sqrt{\sum_{i=1}^{N}\xi_i^2 + 2\sum_{k=1}^{M}\left(\theta^S_k\right)^2}, \tag{103}$$
+
+$$F_A = \sqrt{2\sum_{k=1}^{M}\left(\theta^A_k\right)^2}. \tag{104}$$
+
+Presynaptic computation. The dependence of the normalization coefficients on the parameters introduces additional regularization terms in the learning rule compared to the parameterization of (Scellier & Bengio, 2017). The gradients of the normalization coefficients are:
+
+$$\frac{\partial c_S}{\partial \theta^S_k} = -2c_S\frac{\theta^S_k}{(F_S)^2}, \qquad \frac{\partial c_S}{\partial \xi_m} = -c_S\frac{\xi_m}{(F_S)^2}, \tag{105}$$
+
+$$\frac{\partial c_A}{\partial \theta^A_k} = -2c_A\frac{\theta^A_k}{(F_A)^2}. \tag{106}$$
+
+  Parameter                               Comparison Dyn.   2 hidden layers   3 hidden layers
+                                          (Sec. 5.4)        (Sec. 5.4)        (Sec. 5.4)
+  ----------------------------------------------------------------------------------------------
+  Learning Rate (Input-Hidden)            0.0016            0.0013            0.6
+  Learning Rate (Hidden-Hidden)           0.0016            0.0013            0.6
+  Learning Rate (Hidden-Output)           0.0016            0.0013            0.6
+  Time Step (Dynamics Integration)        0.4               0.3               0.0075
+  Nudging Parameter (β)                   0.3               0.5               0.20
+  Free-phase Steps (n_free)               40                40                60
+  Nudged-phase Steps (n_nudge)            20                20                30
+  Number of Epochs                        50                40                40
+  Batch Size                              64                64                64
+  Layer Structure                         784-500-200-10    784-500-500-10    784-500-500-500-10
+  Activation function ρ                   tanh              tanh              tanh
+  Initial Recurrent State s               s ∼ U(−1, 1)      s ∼ U(−1, 1)      s ∼ U(−1, 1)
+  Initial Parameters θ                    θ ∼ N(0, 1/N)     θ ∼ N(0, 1/N)     θ ∼ N(0, 1/N)
+  Number of Runs (training + inference)   10                10                10
+
+Table 4. Trained Model Hyperparameters on Fashion-MNIST. N is the total number of neurons, U(−1, 1) is a uniform distribution, and N(µ, σ²) is a Gaussian distribution. For the r_str parametrization, we choose more cautious hyperparameters for training and inference compared to the symmetric initialization, due to increasingly non-conservative and potentially oscillatory dynamics.
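The fixed-asymmetry parametrization of Eqs. (98)–(104) can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' implementation: the function names, the example parameter values, and the sign convention for the antisymmetric part (positive below the diagonal, read off from Eq. (108)) are assumptions made for this sketch.

```python
import math

def k_index(i, j):
    """Eq. (98): map a strictly lower-triangular pair (1 <= j < i) to a vector index."""
    return (i - 1) * (i - 2) // 2 + j

def build_J(theta_S, theta_A, xi, r_str, gamma):
    """Assemble J^dyn from Eqs. (99)-(104), using 1-based matrix indices."""
    n = len(xi)
    F_S = math.sqrt(sum(x * x for x in xi) + 2.0 * sum(t * t for t in theta_S))
    F_A = math.sqrt(2.0 * sum(t * t for t in theta_A))
    c_S = math.sqrt(1.0 - r_str ** 2) / F_S
    c_A = r_str / F_A
    J = [[0.0] * n for _ in range(n)]
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            if i == j:
                s_tilde, a_tilde = xi[i - 1], 0.0
            else:
                k = k_index(max(i, j), min(i, j)) - 1  # 0-based list position
                s_tilde = theta_S[k]
                # Assumed sign convention: +theta below the diagonal, -theta above.
                a_tilde = theta_A[k] if i > j else -theta_A[k]
            J[i - 1][j - 1] = gamma * (c_S * s_tilde + c_A * a_tilde)
    return J

def frobenius(M):
    return math.sqrt(sum(v * v for row in M for v in row))

# Hypothetical example with N_dyn = 4, hence M = 6 off-diagonal parameters.
theta_S = [0.3, -0.1, 0.5, 0.2, -0.4, 0.7]
theta_A = [0.6, 0.1, -0.2, 0.4, 0.3, -0.5]
xi = [0.2, -0.3, 0.1, 0.4]
r_str, gamma = 0.7, 2.0

J = build_J(theta_S, theta_A, xi, r_str, gamma)
n = len(J)
A_part = [[0.5 * (J[i][j] - J[j][i]) for j in range(n)] for i in range(n)]
ratio = frobenius(A_part) / frobenius(J)

print("||J||_F =", frobenius(J), "(gamma =", gamma, ")")
print("antisymmetric fraction =", ratio, "(r_str =", r_str, ")")
```

Because symmetric and antisymmetric matrices are orthogonal under the Frobenius inner product, the normalization fixes the total norm of J^dyn to γ and the norm of its antisymmetric part to γ·r_str, whatever the values of the raw parameters.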
+
+Combining these with the derivatives of the matrices $\tilde{S}$ and $\tilde{A}$, we have:
+
+$$\frac{\partial \tilde{S}_{ij}}{\partial \theta^S_k} = \delta_{ip}\delta_{jq} + \delta_{iq}\delta_{jp}, \qquad \frac{\partial \tilde{S}_{ij}}{\partial \xi_k} = \delta_{ij}\delta_{kj}, \tag{107}$$
+
+$$\frac{\partial \tilde{A}_{ij}}{\partial \theta^A_k} = \delta_{ip}\delta_{jq} - \delta_{iq}\delta_{jp}, \tag{108}$$
+
+where $k$ corresponds to the index pair $(p, q)$ with $p > q$, as defined in Eq. (98). The full presynaptic terms are then:
+
+  • For the diagonal parameters $\xi_m$:
+
+$$\frac{\partial F_i}{\partial \xi_m} = \gamma c_S\,\rho'(x_i)\left[-\frac{\xi_m}{(F_S)^2}\sum_{j=1}^{N}\tilde{S}_{ij}\rho(x_j) + \delta_{im}\rho(x_m)\right]. \tag{109}$$
+
+  • For the off-diagonal symmetric parameters $\theta^S_k$ (where $p > q$):
+
+$$\frac{\partial F_i}{\partial \theta^S_k} = \gamma c_S\,\rho'(x_i)\left[-2\frac{\theta^S_k}{(F_S)^2}\sum_{j=1}^{N}\tilde{S}_{ij}\rho(x_j) + \delta_{ip}\rho(x_q) + \delta_{iq}\rho(x_p)\right]. \tag{110}$$
+
+  • For the off-diagonal antisymmetric parameters $\theta^A_k$ (where $p > q$):
+
+$$\frac{\partial F_i}{\partial \theta^A_k} = \gamma c_A\,\rho'(x_i)\left[-2\frac{\theta^A_k}{(F_A)^2}\sum_{j=1}^{N}\tilde{A}_{ij}\rho(x_j) + \delta_{ip}\rho(x_q) - \delta_{iq}\rho(x_p)\right]. \tag{111}$$
+
+Initialization. To ensure the stability of the system, we initialize our parameters such that the variance of dynamical parameters scales as $\mathrm{Var}\left[J^{\mathrm{dyn}}_{ij}\right] \propto 1/N_{\mathrm{dyn}}$. This is a conservative choice for the layered architectures used in our experiments, where many entries of $J^{\mathrm{dyn}}_{ij}$ are zero.
+
+In practice, we initialize the parameter vectors $\theta^S$, $\theta^A$, and $\xi$ with identical variances $\sigma^2$. For large $N_{\mathrm{dyn}}$, the expected Frobenius norms approximate to $\mathbb{E}[F_{S,A}] \approx N_{\mathrm{dyn}}\sigma$. Consequently, the normalization coefficients become:
+
+$$c_S \approx \frac{\sqrt{1 - r_{\mathrm{str}}^2}}{N_{\mathrm{dyn}}\sigma}, \qquad c_A \approx \frac{r_{\mathrm{str}}}{N_{\mathrm{dyn}}\sigma}. \tag{112}$$
+
+Since the symmetric and antisymmetric components are statistically independent, the variance of the weights is derived as follows:
+
+  • Diagonal elements ($i = j$):
+
+$$\mathrm{Var}\left[J^{\mathrm{dyn}}_{ii}\right] = \gamma^2 c_S^2\sigma^2 \approx \gamma^2\frac{1 - r_{\mathrm{str}}^2}{N_{\mathrm{dyn}}^2}. \tag{113}$$
+
+  • Off-diagonal elements ($i \neq j$):
+
+$$\mathrm{Var}\left[J^{\mathrm{dyn}}_{ij}\right] = \gamma^2\left(c_S^2 + c_A^2\right)\sigma^2 \approx \frac{\gamma^2}{N_{\mathrm{dyn}}^2}. \tag{114}$$
+
+To satisfy $\mathrm{Var}\left[J^{\mathrm{dyn}}_{ij}\right] \propto 1/N_{\mathrm{dyn}}$, we set:
+
+$$\gamma = \sqrt{N_{\mathrm{dyn}}}. \tag{115}$$
+
+Note that by random matrix theory, diagonal elements do not affect stability in the large $N_{\mathrm{dyn}}$ limit.
+
+Potential Simplification. Although the parameterization above is fully general, a simpler construction is possible by removing self-connections ($\xi = 0$) and enforcing identical parameterization for the symmetric and antisymmetric components, i.e., $\theta^S = \theta^A = \theta$. The matrix elements then become:
+
+$$\tilde{S}_{ij} = (1 - \delta_{ij})\,\theta_{k(\max(i,j),\min(i,j))}, \tag{116}$$
+
+$$\tilde{A}_{ij} = \epsilon_{ij}\,\theta_{k(\max(i,j),\min(i,j))}. \tag{117}$$
+
+In this case, the Frobenius norms are equal ($F_S = F_A$), and we can omit the explicit normalization:
+
+$$J^{\mathrm{dyn}}_{ij} = \sqrt{1 - r_{\mathrm{str}}^2}\,\tilde{S}_{ij} + r_{\mathrm{str}}\tilde{A}_{ij}. \tag{118}$$
+
+For a parameter $\theta_k$ corresponding to indices $(p, q)$ with $p > q$, the presynaptic term is given by:
+
+$$\frac{\partial F_i}{\partial \theta_k} = \rho'(x_i)\left[\left(\sqrt{1 - r_{\mathrm{str}}^2} + r_{\mathrm{str}}\right)\delta_{ip}\rho(x_q) + \left(\sqrt{1 - r_{\mathrm{str}}^2} - r_{\mathrm{str}}\right)\delta_{iq}\rho(x_p)\right]. \tag{119}$$
+
+While this parameterization works in simulations and keeps the number of parameters constant for all $r_{\mathrm{str}}$, it constrains the asymmetry to be “homogeneous”, by which we mean that the asymmetry ratio is identical for every pair of neurons; hence, the network cannot learn to be symmetric in one region and antisymmetric in another. Therefore, we choose to explore the more general case of (38) in our experiments.
+
+G.2.2. Supplementary Numerical Results
+
+To complement the results of Fig. 3, we analyze the training efficiency as a function of the asymmetry ratio $r_{\mathrm{str}}$ and investigate the robustness of VF by monitoring the Jacobian asymmetry.
+
+Training efficiency. We first study the training efficiency of the two algorithms as a function of the asymmetry ratio $r_{\mathrm{str}}$. Inspired by the related concept in (Cesa-Bianchi & Lugosi, 2006), we define the cumulative loss as the accumulated difference between the free equilibrium cost and a zero-cost baseline (perfect prediction) during learning. Specifically, for each method and value of $r_{\mathrm{str}}$, we calculate the cumulative loss by summing the batch-averaged costs of the first 5 epochs (out of 30, to avoid saturation effects), and reporting the mean and standard deviation over 10 independent training runs. Mathematically, for each run:
+
+$$\text{Cumul. Loss} = \sum_{\mathrm{epoch}=1}^{5}\ \sum_{k=1}^{N_{\mathrm{batches}}}\ \sum_{(x^0, u) \in B_k} \frac{C(x^0, u)}{|B_k|}, \tag{120}$$
+
+where $B_k$ represents the $k$-th batch, and $|B_k|$ denotes the number of examples in the batch. The parameters are updated after each batch step; consequently, the free equilibrium $x^0$ is inferred using the updated parameters and the current example $u$.
+
+Figure 5. Cumulative loss as defined by (120) over the first 5 epochs of training, for different asymmetry ratios $r_{\mathrm{str}}$. We compare VF (orange) and AsymEP (blue), under two training regimes: training only $J^{\mathrm{in}}$ (dashed) or all parameters (solid).
+
+In Fig. 5, we observe that learning slows down for both algorithms when $r_{\mathrm{str}} \gtrsim 0.6$. This behavior likely results from the increased difficulty of reaching a stationary state as the dynamics become strongly asymmetric. With a fixed number of inference steps, incomplete convergence degrades the accuracy of the gradient estimates, thereby slowing down the learning. Fig. 5 shows that while VF can eventually achieve competitive accuracy, it is consistently slower than AsymEP as soon as asymmetry is introduced.
+
+Jacobian asymmetry. We next examine how the structural asymmetry $r_{\mathrm{str}}$ is reflected in the Jacobian of the dynamics (35), given by:
+
+$$\frac{\partial F_i}{\partial s_j} = (1 - \delta_{ij})\,\rho'(x_i)J^{\mathrm{dyn}}_{ij}\rho'(x_j) + \delta_{ij}\left[\rho'(x_i)\left(J^{\mathrm{dyn}}_{ii}\rho'(x_i)\right) + \rho''(x_i)b_i - 1\right]. \tag{121}$$
+
+In our layered architecture, the self-connections are zero ($J^{\mathrm{dyn}}_{ii} = 0$).
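The off-diagonal part of Eq. (121) can be checked against a finite-difference Jacobian of the force field in Eq. (91). The sketch below assumes ρ = tanh (the activation listed in the hyperparameter tables); the weights, bias, and state are hypothetical example values, and the function names are introduced only for illustration.

```python
import math

def rho(x):
    return math.tanh(x)

def drho(x):
    return 1.0 - math.tanh(x) ** 2

def force(x, J, b):
    """F_i = rho'(x_i) * (sum_j J_ij rho(x_j) + b_i) - x_i, as in Eq. (91)."""
    n = len(x)
    return [drho(x[i]) * (sum(J[i][j] * rho(x[j]) for j in range(n)) + b[i]) - x[i]
            for i in range(n)]

def offdiag_jacobian(x, J):
    """Off-diagonal part of Eq. (121): rho'(x_i) J_ij rho'(x_j) for i != j."""
    n = len(x)
    return [[drho(x[i]) * J[i][j] * drho(x[j]) if i != j else 0.0
             for j in range(n)] for i in range(n)]

def numerical_jacobian(x, J, b, eps=1e-6):
    """Central finite differences of the force field, column by column."""
    n = len(x)
    jac = [[0.0] * n for _ in range(n)]
    for j in range(n):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        fp, fm = force(xp, J, b), force(xm, J, b)
        for i in range(n):
            jac[i][j] = (fp[i] - fm[i]) / (2.0 * eps)
    return jac

# Hypothetical asymmetric weights with zero self-connections.
J = [[0.0, 0.9, -0.4],
     [-0.2, 0.0, 0.6],
     [0.7, 0.1, 0.0]]
b = [0.3, -0.1, 0.2]
x = [0.25, -0.5, 0.1]

analytic = offdiag_jacobian(x, J)
numeric = numerical_jacobian(x, J, b)
max_err = max(abs(analytic[i][j] - numeric[i][j])
              for i in range(3) for j in range(3) if i != j)
print("max off-diagonal discrepancy:", max_err)
```

Only the off-diagonal entries are compared here; the diagonal entries of the numerical Jacobian also contain input- and leak-dependent terms, which do not enter the antisymmetric part.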
+For the following analysis, we neglect all diagonal terms in the Jacobian (including external inputs and potential), since they do not contribute to the antisymmetric correction (20) and thus to the discrepancy between the performance of VF and AsymEP. Consequently, we define the following asymmetry ratio based solely on the off-diagonal Jacobian $J_{F,\mathrm{off}}$:
+
+$$r_{\mathrm{jac}} = \frac{\left\|J_{F,\mathrm{off}} - J_{F,\mathrm{off}}^{\top}\right\|_F}{\left\|J_{F,\mathrm{off}}\right\|_F}. \tag{122}$$
+
+The results are presented in Fig. 6. For each trained model and ratio $r_{\mathrm{str}}$, we compute $r_{\mathrm{jac}}$ averaged over the stationary states of the first batch (64 images) across 10 independent runs. We observe that when structural asymmetry is strong and all parameters are trained, VF partially compensates for the asymmetry by adjusting the neuronal states. This can be understood by rewriting the ratio as:
+
+$$r_{\mathrm{jac}} = \frac{\left\|\rho'(x_i)\left(J^{\mathrm{dyn}}_{ij} - (J^{\mathrm{dyn}}_{ji})^{\top}\right)\rho'(x_j)\right\|_F}{\left\|\rho'(x_i)J^{\mathrm{dyn}}_{ij}\rho'(x_j)\right\|_F}. \tag{123}$$
+
+Compared to the structural asymmetry ratio in Eq. (37), a value of $r_{\mathrm{jac}} < r_{\mathrm{str}}$ indicates that the neuronal states effectively dampen the structural asymmetry, rendering the dynamics more symmetric. This symmetrization of the Jacobian appears without imposing an additional symmetrization penalty and could be enhanced using the method of (Laborieux & Zenke, 2022). This mechanism likely explains the superior performance of ‘All (VF)’ compared to ‘$J^{\mathrm{in}}$ (VF)’ in Fig. 3, as the former is able to use the additional degrees of freedom to reduce the effective asymmetry at high $r_{\mathrm{str}}$.
+
+Figure 6. Asymmetry ratio of the Jacobian $r_{\mathrm{jac}}$ defined in equation (122) after training for different asymmetry ratios $r_{\mathrm{str}}$. We compare VF (orange) and AsymEP (blue), under two training regimes: training only $J^{\mathrm{in}}$ (dashed) or all parameters (solid).
+
+G.3. Stability analysis with Fixed Asymmetry Ratio & Constrained Inputs Projection
+
+A complete stability analysis of the non-conservative dynamics trainable with AsymEP is beyond the scope of this work. Nevertheless, for the class of continuous Hopfield networks considered here, standard arguments from random matrix theory suggest that asymmetry inherently improves asymptotic stability.
+
+In the dynamics defined by Eq. (91), the linear leak term $-x_i$ shifts the spectrum of the system’s Jacobian by $-1$. Consequently, local stability requires the largest real part among the eigenvalues of the effective weight matrix to be strictly less than 1. Assuming weights are initialized independently with variance $\sigma^2$, Girko’s circular law dictates that the eigenvalues of an asymmetric matrix uniformly populate a disk of radius $\sigma\sqrt{n}$ in the complex plane. In contrast, imposing symmetry forces the eigenvalues onto the real line, broadening the spectral radius to $2\sigma\sqrt{n}$ according to Wigner’s semicircle law. As a result, asymmetric networks can stably accommodate larger variance in the weight initializations than their symmetric counterparts.
+
+Asymmetry nevertheless introduces imaginary eigenvalues and, consequently, damped oscillations. To study this effect experimentally in a controlled setting, we constrain the input projections $J^{\mathrm{in}}$. In the experiments of the main text, fixing the structural asymmetry ratio $r_{\mathrm{str}}$ still allowed AsymEP to reduce oscillations by aligning and increasing the input projections $J^{\mathrm{in}}$, thereby adding stabilizing diagonal contributions to the Jacobian. To isolate the network’s ability to suppress oscillations independently of the magnitude of the input drive, we further constrain the relative scale of $J^{\mathrm{in}}$ and $J^{\mathrm{dyn}}$ by imposing
+
+$$r_{\mathrm{in}} = \frac{\|J^{\mathrm{in}}\|_F}{\|J^{\mathrm{dyn}}\|_F} = \frac{\|J^{\mathrm{in}}\|_F}{\gamma}, \tag{124}$$
+
+where $\|J^{\mathrm{dyn}}\|_F = \gamma$ following Eq. (101). Defining unscaled input projections $\tilde{J}^{\mathrm{in}}$, we set
+
+$$J^{\mathrm{in}} = r_{\mathrm{in}}\,\gamma\,\frac{\tilde{J}^{\mathrm{in}}}{\|\tilde{J}^{\mathrm{in}}\|_F}. \tag{125}$$
+
+G.3.1. Learning Rules
+
+Reusing the notations of the previous section, we write $J^{\mathrm{in}}_{il} = \gamma c_{\mathrm{in}}\tilde{J}^{\mathrm{in}}_{il}$ with the normalization $c_{\mathrm{in}} = r_{\mathrm{in}}/F_{\mathrm{in}}$, where $F_{\mathrm{in}} = \|\tilde{J}^{\mathrm{in}}\|_F$. Applying the chain rule yields:
+
+$$\frac{\partial F_i}{\partial \tilde{J}^{\mathrm{in}}_{kl}} = \gamma c_{\mathrm{in}}\,\rho'(x_i)\left[\delta_{ik}u_l - \frac{\tilde{J}^{\mathrm{in}}_{kl}}{F_{\mathrm{in}}^2}\sum_m \tilde{J}^{\mathrm{in}}_{im}u_m\right]. \tag{126}$$
+
+And for $\gamma$ we have:
+
+$$\frac{\partial F_i}{\partial \gamma} = \frac{1}{\gamma}(F_i + x_i). \tag{127}$$
+
+G.3.2. Supplementary Numerical Results
+
+Table 5 reports a worst-case control experiment where the structural asymmetry is fixed at $r_{\mathrm{str}} = 0.7$ while the input scale ratio $r_{\mathrm{in}}$ is varied. The experiment uses an all-to-all architecture on MNIST (excluding direct input-to-output connections). The output variance during extended inference (steps 30-50) confirms that the system successfully learns to suppress oscillations even when $r_{\mathrm{in}}$ is severely restricted. Any small residual oscillations can be mitigated by time-averaging over the inference steps.
+
+Finally, $r_{\mathrm{in}}$ can be interpreted as a measure of the external signal magnitude relative to the internal recurrent dynamics. These results indicate that the system remains capable of learning and stabilizing even under a low external input drive. Even when the input projection $\|J^{\mathrm{in}}\|_F$ is 100 times smaller than the recurrent connections $\|J^{\mathrm{dyn}}\|_F$, the network still achieves 36.34 ± 6.25% accuracy, which is well above chance level (∼10%).
+
+Table 5. Output variance and final test accuracy on MNIST (%) across different values of r_in with r_str = 0.7 (mean ± std over 10 runs; 500 hiddens, all-to-all, no input-output).
+
+            Output variance                                  Test Acc. (%)
+  r_in      Untrained                Epoch 80                Epoch 80
+  --------------------------------------------------------------------------
+  0.01      (3.38 ± 0.90) × 10^−4    (5.22 ± 2.34) × 10^−5   36.34 ± 6.25
+  0.10      (2.33 ± 0.48) × 10^−4    (1.39 ± 0.17) × 10^−4   90.54 ± 0.19
+  0.50      (1.34 ± 0.32) × 10^−5    (1.06 ± 0.25) × 10^−6   94.96 ± 0.10
+  1.00      (6.27 ± 1.24) × 10^−7    (1.75 ± 0.50) × 10^−8   96.30 ± 0.09
+
+G.4. Feedforward Network
+
+G.4.1. Learning Rules
+
+For clarity, we write the learning rules for VF and AsymEP in a feedforward architecture with one hidden layer using the notation of Section 5.3. For the input weights connecting to the hidden layer, we get the usual formula:
+
+$$\Delta J^{\mathrm{in}}_{ik} \propto \frac{1}{2\beta}\left[(h_i^{+\beta} - h_i^{-\beta})\rho'(h_i^0)u_k\right], \tag{128}$$
+
+while for the feedforward weights connecting the hidden to the output layer, we get:
+
+$$\Delta (W_{h \to o})_{ji} \propto \frac{1}{2\beta}\left[(o_j^{+\beta} - o_j^{-\beta})\rho'(o_j^0)\rho(h_i^0)\right]. \tag{129}$$
+
+Note that EP is not applicable in this case.
+
+G.4.2. Supplementary Numerical Results
+
+In addition to the final accuracy reported in Sec. 5.3, we show in Fig. 7 the evolution of the accuracy over 20 epochs for AsymEP and VF.
+
+Figure 7. Comparison of AsymEP and VF on a feedforward network. Test accuracy on MNIST is shown as a function of training epochs for a single-hidden-layer network with 20 neurons. Curves report the mean and standard deviation over 10 runs. Best accuracies are 92.7% ± 0.5% (AsymEP) and 64.3% ± 2.0% (VF).
+
+G.5. Advantages of Non-Conservative Dynamics
+
+In Section 5.4, we compare three (non-)conservative dynamics under varying constraints. To further evaluate learning speed, Table 6 reports network performance after a single epoch. These results confirm our earlier observation: AsymEP learns faster than VF.
+
+Table 6. Test accuracy on Fashion-MNIST (%) at Epoch 1 (mean ± std over 10 runs). The table compares three classes of network dynamics: Continuous Hopfield (CH), Predictive Coding (PC), and Standard dynamics. Each is evaluated under three connectivity structures: Asymmetric (Asym, $B_k \neq W_{k+1}^{\top}$), Symmetric/conservative (Sym, $B_k = W_{k+1}^{\top}$), and Feedforward (Feedfor, $B_k = 0$).
+
+  Dynamics   Connectivity   EP              AsymEP          VF
+  ------------------------------------------------------------------------
+  CH         Asym           -               74.91 ± 0.45    48.98 ± 4.09
+  CH         Feedfor        -               74.36 ± 0.29    48.84 ± 3.46
+  CH         Sym            74.57 ± 0.43    -               -
+  PC         Asym           -               77.83 ± 0.47    57.75 ± 3.37
+  PC         Sym            76.23 ± 0.39    -               -
+  Standard   Asym           -               76.87 ± 0.51    61.50 ± 4.06
+  Standard   Feedfor        -               77.92 ± 0.51    63.98 ± 0.73
+
+G.6. Feedforward CIFAR-10 Experiments
+
+This appendix details the architecture and hyperparameters of the deep feedforward experiments comparing backpropagation (BP), VF, AsymEP and Dyadic EP on CIFAR-10 (see subsection 5.5).
+
+Architecture. We use a nine-layer convolutional network (denoted CNN9). The first eight layers are convolutional with 3 × 3 kernels and zero-padding; spatial downsampling is performed by strided convolutions (stride 2 on layers 2, 4, 6, 8 and stride 1 otherwise), so no pooling is used. The channel widths follow the sequence 3 → 64 → 64 → 128 → 128 → 256 → 256 → 512 → 512, reducing the spatial resolution from 32 × 32 to 2 × 2. The last layer is a fully connected readout mapping the 512 × 2 × 2 feature map to the 10 class logits. All hidden units use a ReLU non-linearity. Weights are initialized with the Kaiming scheme ($\sigma = \sqrt{2/\text{fan-in}}$) and biases at zero.
+
+Training setup. All methods are trained for 40 epochs with batch size 64 and repeated over 5 seeds. Inputs are normalized per channel and augmented with random 32 × 32 crops (padding 4), random horizontal flips and Cutout (one 8 × 8 patch). Parameters are updated with SGD with momentum 0.9, weight decay $5 \times 10^{-4}$ and gradient-norm clipping at 1, under a cosine learning-rate schedule decaying from $3.5 \times 10^{-2}$ to $2 \times 10^{-4}$. Test accuracy is reported on an exponential moving average of the weights (decay 0.9995). The targets are smoothed ($\varepsilon = 0.1$), which for the EP methods amounts to nudging toward the smoothed one-hot vector.
+
+Relaxation hyperparameters. The four methods differ only in the gradient estimator: BP uses automatic differentiation, while the EP-based methods contrast stationary states of the corresponding relaxation dynamics.
VF uses +an integration step η = 1.0, nudging β = 0.1, and at most +K = 1000 relaxation steps with an early-stopping threshold +of 9 × 10−6 on the mean state update. Dyadic EP uses +the same settings except for a nudging strength β = 0.1. +AsymEP uses a smaller step η = 0.5, nudging β = 0.1, +and up to K = 250 relaxation steps with a threshold of +1 × 10−4 . + + + + + 23 +
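The relaxation settings quoted above (integration step η, at most K steps, early stopping on the mean state update) can be illustrated with a small sketch. This is not the paper's implementation: `relax` and `toy_force` are hypothetical names, the ring-coupled leaky dynamics is an invented stand-in for the network, and "mean state update" is assumed here to mean the mean absolute per-unit change between consecutive states.

```python
import math


def relax(force, x0, eta, max_steps, tol):
    """Euler relaxation x <- x + eta * force(x) with early stopping.

    Stops as soon as the mean absolute state update falls below tol
    (an assumed reading of "threshold on the mean state update").
    Returns (state, steps_used, last_mean_update).
    """
    x = list(x0)
    mean_update = float("inf")
    for step in range(1, max_steps + 1):
        f = force(x)
        new_x = [xi + eta * fi for xi, fi in zip(x, f)]
        mean_update = sum(abs(a - b) for a, b in zip(new_x, x)) / len(x)
        x = new_x
        if mean_update < tol:
            return x, step, mean_update
    return x, max_steps, mean_update


def toy_force(x, w=0.5, u=(1.0, -0.5, 0.25)):
    # Hypothetical leaky dynamics with asymmetric ring coupling:
    # F_i = -x_i + tanh(w * x_{i-1} + u_i). Not the model from the text.
    return [-x[i] + math.tanh(w * x[i - 1] + u[i]) for i in range(len(x))]


# Values borrowed from the AsymEP settings quoted above: eta=0.5, K=250, tol=1e-4.
state, steps, upd = relax(toy_force, [0.0, 0.0, 0.0], eta=0.5, max_steps=250, tol=1e-4)
print(f"stopped after {steps} steps, mean update {upd:.2e}")
```

Because the toy map is contractive, the loop stops well before the step budget. For dynamics that are only marginally contractive, the budget K becomes the binding constraint rather than the threshold.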
\ No newline at end of file diff --git a/ep_run/solver_wall.py b/ep_run/solver_wall.py new file mode 100644 index 0000000..ee9bf73 --- /dev/null +++ b/ep_run/solver_wall.py @@ -0,0 +1,61 @@ +"""Wall-breaking probe. The EP ceiling I measured comes from: rich (thick) block is +non-contractive -> EP needs heavy damping c to converge the free phase -> damping suppresses +the very expressivity that made the block good. ESCAPE ROUTE: get convergence from a SOLVER +(Anderson accel, DEQ-style) instead of from damping. Decisive question: for the THICK block, +at LOW damping (expressivity intact), can Anderson converge where plain relaxation cannot? +If yes -> the wall is a solver problem, not fundamental. If no -> the rich block has no fixed +point to find and the ceiling is intrinsic to the EP/fixed-point requirement.""" +import math, sys, torch +from lt_ep_train import EQBlock, get_batch +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +B, T, C, H = 16, 64, 128, 4 +eps = 0.05 + + +def gmap(blk, z, xin): # relaxation map; fixed point = equilibrium + with torch.no_grad(): + return z + eps * blk.force(z, xin).detach() + + +def plain(blk, z0, xin, steps=200): + z = z0.clone() + for _ in range(steps): + z = gmap(blk, z, xin) + return ((gmap(blk, z, xin) - z).norm() / (z.norm() + 1e-9)).item() + + +def anderson(blk, z0, xin, m=6, max_iter=150, tol=1e-6, lam=1e-4): + Bs, d = z0.shape[0], z0[0].numel() + X = torch.zeros(Bs, m, d, device=dev); Fb = torch.zeros(Bs, m, d, device=dev) + X[:, 0] = z0.reshape(Bs, d); Fb[:, 0] = gmap(blk, z0, xin).reshape(Bs, d) + X[:, 1] = Fb[:, 0]; Fb[:, 1] = gmap(blk, X[:, 1].view_as(z0), xin).reshape(Bs, d) + Hm = torch.zeros(Bs, m + 1, m + 1, device=dev); Hm[:, 0, 1:] = 1; Hm[:, 1:, 0] = 1 + yv = torch.zeros(Bs, m + 1, 1, device=dev); yv[:, 0] = 1 + r, k = 1.0, 2 + for k in range(2, max_iter): + n = min(k, m) + Gm = Fb[:, :n] - X[:, :n] + Hm[:, 1:n + 1, 1:n + 1] = torch.bmm(Gm, Gm.transpose(1, 2)) + lam * torch.eye(n, 
device=dev)[None] + alpha = torch.linalg.solve(Hm[:, :n + 1, :n + 1], yv[:, :n + 1])[:, 1:n + 1, 0] + X[:, k % m] = torch.bmm(alpha[:, None], Fb[:, :n])[:, 0] + Fb[:, k % m] = gmap(blk, X[:, k % m].view_as(z0), xin).reshape(Bs, d) + r = ((Fb[:, k % m] - X[:, k % m]).norm() / (Fb[:, k % m].norm() + 1e-9)).item() + if r < tol or not math.isfinite(r): + break + return r, k + 1 + + +for mode in ['real', 'thick']: + torch.manual_seed(0) + blk = EQBlock(C, H, 256, T, attn_mode=mode) + idx, y = get_batch('train', B, T) + xin = blk.embed(idx).detach() + print(f"\n=== attn_mode={mode} === free-phase convergence: plain relax(200) vs Anderson, eps={eps}") + print(f"{'damp c':>7} {'plain_res':>11} {'anderson_res':>13} {'and_iters':>10}") + for c in [0.0, 0.25, 0.5, 1.0, 2.0]: + blk.c = c + pr = plain(blk, xin.clone(), xin) + ar, ak = anderson(blk, xin.clone(), xin) + flag = ' <- solver converges where plain fails' if (ar < 1e-4 and pr > 1e-2) else '' + print(f"{c:>7.2f} {pr:>11.2e} {ar:>13.2e} {ak:>10d}{flag}") diff --git a/ep_run/spec_bifurcation.py b/ep_run/spec_bifurcation.py new file mode 100644 index 0000000..aca7629 --- /dev/null +++ b/ep_run/spec_bifurcation.py @@ -0,0 +1,38 @@ +import torch, glob, re, math, pickle +from pathlib import Path +import lt_ep_train as L +from lt_ep_train import EQBlock +L.DD = Path('data/tinystories_bpe') +L.vocab = pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; eps=0.1; B=8; T=256; N=800 +torch.manual_seed(1234) # FIXED batch across all ckpts +idx, y = L.get_batch('val', B, T) +idx = idx.to(dev) if hasattr(idx,'to') else idx +def measure(ckpt): + blk = EQBlock(512,16,256,256, s=1.0, c=1.0, attn_mode='thick'); blk.qknorm=True + ck = torch.load(ckpt, map_location=dev) + with torch.no_grad(): + for p,w in zip(blk.allp, ck['allp']): p.copy_(w.to(dev)) + xin = blk.embed(idx).detach(); z = xin.clone(); ress=[] + for t in range(N): + z2 = z + eps*blk.force(z, xin).detach() + r = (z2-z).norm().item()/(z.norm().item()+1e-9); 
ress.append(r); z=z2 + if (not math.isfinite(r)) or r>1e2: break + win=[ress[i] for i in range(len(ress)) if 1e-6<ress[i]<1e-1] or ress[-200:] + rats=[win[i+1]/win[i] for i in range(len(win)-1) if win[i]>0] + rho = math.exp(sum(math.log(x) for x in rats)/len(rats)) if rats else float('nan') + return rho, ress[-1], len(ress) +valmap={} +for l in open('runs/ep_redx.log'): + if l.startswith('step'): + m=re.search(r'step\s+(\d+)/.*val CE ([\d.]+)', l) + if m: valmap[int(m.group(1))]=m.group(2) +print("=== BIFURCATION PROBE — free-phase contraction ratio rho (fixed batch, 800 relax steps) ===") +print("rho<1 contractive | rho->1 marginal | rho>=1 divergent ; final_res = floor reached") +r,fr,n = measure('runs/bptt_final.pt'); print(f"REF BPTT(val~1.83): rho={r:.4f} final_res={fr:.2e} steps={n}") +print("redx approach to blowup:") +for ck in sorted(glob.glob('runs/redx_traj/s*.pt'), key=lambda p:int(re.search(r's(\d+)',p).group(1))): + step=int(re.search(r's(\d+)',ck).group(1)) + try: r,fr,n = measure(ck); print(f" step {step:5d} (val {valmap.get(step,'?')}): rho={r:.4f} final_res={fr:.2e} steps={n}") + except Exception as e: print(f" step {step:5d}: ERR {repr(e)[:80]}") +print("=== DONE ===") diff --git a/ep_run/spec_check.py b/ep_run/spec_check.py new file mode 100644 index 0000000..be633d5 --- /dev/null +++ b/ep_run/spec_check.py @@ -0,0 +1,15 @@ +import torch, pickle +from pathlib import Path +import lt_ep_train as L +L.DD = Path('data/tinystories_bpe'); L.vocab = pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +blk = L.EQBlock(512,16,256,256, s=1.0, c=1.0, attn_mode='thick'); blk.qknorm=True +ck = torch.load('runs/bptt_spec_snapshot.pt', map_location='cpu') +with torch.no_grad(): + for p,w in zip(blk.allp, ck['allp']): p.copy_(w.to('cpu').float()) +print(f"# BPTT ckpt step={ck.get('step')} best_val={ck.get('best')}") +print(f"# specnorm cap = 0.9") +for name in ['WQ','WK','WV','WO','fc','pj']: + W = getattr(blk, name).float() + s = 
torch.linalg.svdvals(W)[0].item() + flag = " <-- EXCEEDS cap" if s > 0.9 else "" + print(f"{name}{tuple(W.shape)}: sigma_max={s:.3f}{flag}") diff --git a/ep_run/spec_rho_vs_c.py b/ep_run/spec_rho_vs_c.py new file mode 100644 index 0000000..47661e0 --- /dev/null +++ b/ep_run/spec_rho_vs_c.py @@ -0,0 +1,29 @@ +import torch, math, pickle +from pathlib import Path +import lt_ep_train as L +from lt_ep_train import EQBlock +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; eps=0.1; B=8; T=256; N=800 +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T) +idx=idx.to(dev) if hasattr(idx,'to') else idx +def measure_c(ckpt,c): + blk=EQBlock(512,16,256,256,s=1.0,c=c,attn_mode='thick'); blk.qknorm=True + ck=torch.load(ckpt,map_location=dev) + with torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + xin=blk.embed(idx).detach(); z=xin.clone(); ress=[] + for t in range(N): + z2=z+eps*blk.force(z,xin).detach() + r=(z2-z).norm().item()/(z.norm().item()+1e-9); ress.append(r); z=z2 + if not math.isfinite(r) or r>1e2: break + win=[ress[i] for i in range(len(ress)) if 1e-6<ress[i]<1e-1] or ress[-200:] + rats=[win[i+1]/win[i] for i in range(len(win)-1) if win[i]>0] + rho=math.exp(sum(math.log(x) for x in rats)/len(rats)) if rats else float('nan') + return rho, ress[-1] +print("=== rho vs damping c — does more c pull the operator off the rho=1 threshold? 
===") +print("(weights trained at c=1; this is eval-time c — a margin indicator, not the trained answer)") +for ck,lab in [('runs/redx_traj/s3200.pt','redx-s3200 (val2.74, marginal)'),('runs/bptt_final.pt','BPTT (1.83)')]: + for c in [1.0,1.5,2.0,3.0,4.0]: + try: rho,fr=measure_c(ck,c); print(f" {lab} c={c}: rho={rho:.4f} final_res={fr:.2e}") + except Exception as e: print(f" {lab} c={c}: ERR {repr(e)[:60]}") +print("=== DONE ===") diff --git a/ep_run/speed_probe.py b/ep_run/speed_probe.py new file mode 100644 index 0000000..0f0e3ba --- /dev/null +++ b/ep_run/speed_probe.py @@ -0,0 +1,63 @@ +"""Speed-package probe for the 50M demo. Run on a free GPU (A6000 preferred). +(1) torch.compile speedup on the relax loop (exact math, free speed). +(2) bf16 force evals at r=0.2 with the TRACKING estimator: does the contrast survive low + precision when the nudge is large and the common mode cancels? (tf32 died at r=0.02+frozen; + this is the missing measurement that decides the 50M/1B cost sheet.) +Outputs: it/s-equivalents + gradient cosine vs fp32 reference. 
+""" +import time, torch +import lt_ep_train as M +from pathlib import Path +import pickle +M.DD = Path('/tmp/lt_ep/data/tinystories') +M.vocab = pickle.load(open(M.DD / 'meta.pkl', 'rb'))['vocab_size'] +from lt_ep_train import EQBlock, get_batch, bptt_step, relax +from holo_ep import holo_a_track, holo_a_select2 + +dev = 'cuda' +torch.manual_seed(0) +B, T, C, H = 8, 256, 256, 8 +blk = EQBlock(C, H, 256, T, attn_mode='thick') +ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt') +for p, w in zip(blk.allp, ck['allp']): + with torch.no_grad(): + p.copy_(w.to(dev)) +idx, y = get_batch('train', B, T) +xin = blk.embed(idx).detach() + +# --- (1) compile speedup on relax --- +t0 = time.time(); zs = relax(blk, xin.clone(), xin, 300, 0.1); torch.cuda.synchronize() +base = time.time() - t0 +cfun = torch.compile(lambda z: z + 0.1 * blk.force(z, xin).detach(), mode='max-autotune-no-cudagraphs') +z = xin.clone() +for _ in range(10): + z = cfun(z) # warmup/compile +torch.cuda.synchronize() +t0 = time.time() +z = xin.clone() +for _ in range(300): + z = cfun(z) +torch.cuda.synchronize() +comp = time.time() - t0 +print(f"[compile] relax300: eager {base:.2f}s -> compiled {comp:.2f}s ({base/comp:.2f}x)", flush=True) + +# --- (2) bf16 @ r=0.2 + tracking --- +aref, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1) +def cos(a, b): + return (a.flatten() @ b.flatten() / (a.norm() * b.norm() + 1e-12)).item() +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +atf, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1) +print(f"[tf32 + track + r=0.2] cos vs fp32 = {cos(atf, aref):.3f}", flush=True) +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False +with torch.autocast('cuda', dtype=torch.bfloat16): + abf, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1) +abf = abf.float() +print(f"[bf16 + track + r=0.2] cos vs fp32 = {cos(abf, aref):.3f}", flush=True) +# also the old failure case for reference +a_old_ref, _ = holo_a_select2(blk, zs, 
xin, y, 0.02, 120, 0.1) +torch.backends.cuda.matmul.allow_tf32 = True +a_old_tf, _ = holo_a_select2(blk, zs, xin, y, 0.02, 120, 0.1) +torch.backends.cuda.matmul.allow_tf32 = False +print(f"[tf32 + frozen + r=0.02 (known-dead control)] cos = {cos(a_old_tf, a_old_ref):.3f}", flush=True) diff --git a/ep_run/stiefel_feedback.py b/ep_run/stiefel_feedback.py new file mode 100644 index 0000000..a5c86ca --- /dev/null +++ b/ep_run/stiefel_feedback.py @@ -0,0 +1,126 @@ +"""Factored feedback subspace: fixed compressor C + per-layer Stiefel expander U_l. + +δ_l = α_l · c_t @ U_l^T, where c_t = e_L @ C^T + +C ∈ R^{r×V}: fixed row-orthonormal compressor (CC^T = I_r) +U_l ∈ St(d_l, r): per-layer learnable orthonormal expander +α_l > 0: per-layer scalar gain + +U_l updated via Riemannian gradient on Stiefel with EMA + QR retraction. +α_l updated via correlation-based least-squares, not norm-ratio. +""" +import torch +import torch.nn as nn + + +@torch.no_grad() +def init_row_orthonormal_C(vocab_size: int, rank: int, device=None, dtype=torch.float32): + """C ∈ R^{r×V} with CC^T = I_r.""" + g = torch.randn(vocab_size, rank, device=device, dtype=dtype) + q, _ = torch.linalg.qr(g, mode="reduced") # (V, r) + return q.T.contiguous() # (r, V) + + +class StiefelFeedbackLayer(nn.Module): + """Per-layer factored feedback: δ = α · c @ U^T where U ∈ St(d, r).""" + + def __init__(self, d: int, r: int): + super().__init__() + # U on Stiefel: (d, r), orthonormal columns + U_init = torch.linalg.qr(torch.randn(d, r), mode="reduced")[0] + self.register_buffer("U", U_init) + self.register_buffer("alpha", torch.tensor(0.1)) + self.register_buffer("ema_G", torch.zeros(d, r)) + + def compute_delta(self, c: torch.Tensor) -> torch.Tensor: + """c: (B, T, r) → δ: (B, T, d)""" + return self.alpha * (c @ self.U.T) + + @torch.no_grad() + def update(self, g_hat: torch.Tensor, c: torch.Tensor, + eta_B: float = 3e-5, tau: float = 0.99, + beta_alpha: float = 0.01, eps: float = 1e-8, + alpha_min: float = 1e-4, 
alpha_max: float = 10.0, + max_step_frob: float = 1.0, frozen: bool = False): + """Update U and alpha from local signal g_hat and compressed error c. + + g_hat: (B, T, d) — local proxy signal (e.g. reconstruction error) + c: (B, T, r) — compressed global error + frozen: if True, only accumulate ema_G, don't update U or alpha + """ + B, T, d = g_hat.shape + r = c.shape[-1] + N = B * T + + g_flat = g_hat.reshape(N, d) + c_flat = c.reshape(N, r) + + # Cross-covariance G = (1/N) g_hat^T @ c + G = (g_flat.T @ c_flat) / max(N, 1) # (d, r) + + # EMA + self.ema_G.mul_(tau).add_(G, alpha=1.0 - tau) + + if frozen: + return {"G": G, "alpha": self.alpha.clone(), "frozen": True} + + # Tangent projection on Stiefel + UtG = self.U.T @ self.ema_G # (r, r) + sym = 0.5 * (UtG + UtG.T) + Delta = self.ema_G - self.U @ sym # (d, r) + + # Step clipping + delta_norm = torch.linalg.norm(Delta, ord="fro") + if max_step_frob is not None and delta_norm > max_step_frob: + Delta = Delta * (max_step_frob / (delta_norm + eps)) + + # Riemannian step + QR retraction + U_tilde = self.U + (eta_B * self.alpha) * Delta + Q, R = torch.linalg.qr(U_tilde, mode="reduced") + # Sign fix: make diag(R) positive + s = torch.sign(torch.diag(R)) + s = torch.where(s == 0, torch.ones_like(s), s) + self.U.copy_(Q * s.unsqueeze(0)) + + # Correlation-based alpha update: α* = <G, U> / (mean ||c||^2 + eps) + c2_mean = c_flat.square().sum() / max(N, 1) + alpha_star = (self.U * G).sum() / (c2_mean + eps) + alpha_star = alpha_star.clamp(min=alpha_min, max=alpha_max) + self.alpha.mul_(1.0 - beta_alpha).add_(alpha_star, alpha=beta_alpha) + + return { + "G": G, + "Delta_frob": delta_norm.item(), + "alpha_star": alpha_star.item(), + "alpha": self.alpha.item(), + "rho": (G * self.U).sum().item() / (torch.linalg.norm(G, ord="fro").item() * (r ** 0.5) + eps), + } + + +class StiefelFeedbackSystem(nn.Module): + """Full feedback system: global C + per-layer StiefelFeedbackLayer.""" + + def __init__(self, vocab_size: int, 
layer_dims: list[int], rank: int = 128): + super().__init__() + self.rank = min(rank, vocab_size) # can't compress to more dims than vocab + self.register_buffer("C", init_row_orthonormal_C(vocab_size, self.rank)) + self.layers = nn.ModuleList([ + StiefelFeedbackLayer(d, self.rank) for d in layer_dims + ]) + + def compress_error(self, e_L: torch.Tensor) -> torch.Tensor: + """e_L: (B, T, V) → c: (B, T, r)""" + return e_L @ self.C.T # (B, T, r) + + def compute_deltas(self, c: torch.Tensor) -> list[torch.Tensor]: + """Compute per-layer feedback δ_l from compressed error c.""" + return [layer.compute_delta(c) for layer in self.layers] + + def update_all(self, g_hats: list[torch.Tensor], c: torch.Tensor, + frozen: bool = False, **kwargs) -> list[dict]: + """Update all layers' U and alpha.""" + diagnostics = [] + for layer, g_hat in zip(self.layers, g_hats): + diag = layer.update(g_hat, c, frozen=frozen, **kwargs) + diagnostics.append(diag) + return diagnostics diff --git a/ep_run/t2fix_freezer.py b/ep_run/t2fix_freezer.py new file mode 100644 index 0000000..d5d090f --- /dev/null +++ b/ep_run/t2fix_freezer.py @@ -0,0 +1,21 @@ +import time, os, re, shutil +os.chdir("/home/yurenh2/ept/ep_run") +LOG, CK = "runs/ep_t2fix.log", "runs/ep_t2fix.pt" +os.makedirs("runs/t2fix_traj", exist_ok=True) +last_mt=0; t0=time.time() +while time.time()-t0 < 24*3600: + time.sleep(30) + try: ls=[l for l in open(LOG) if l.startswith("step")] + except Exception: continue + if not ls: continue + m=re.search(r"step\s+(\d+)/.*res=([\d.eE+-]+)", ls[-1]) + if not m: continue + step=int(m.group(1)); res=float(m.group(2)) + try: mt=os.path.getmtime(CK) + except Exception: continue + if mt!=last_mt and os.path.getsize(CK)>1e6: + last_mt=mt + try: shutil.copy2(CK, f"runs/t2fix_traj/s{step}.pt"); print(f"froze s{step} res={res:.2e}",flush=True) + except Exception: pass + if res>0.25: print(f"ep_t2fix DIVERGED step {step}",flush=True); break +print("t2fix freezer done") diff --git 
a/ep_run/t2fix_rho_prober.py b/ep_run/t2fix_rho_prober.py new file mode 100644 index 0000000..ad9d1e8 --- /dev/null +++ b/ep_run/t2fix_rho_prober.py @@ -0,0 +1,52 @@ +import time, os, re, glob, math, pickle, subprocess +from pathlib import Path +import torch +import lt_ep_train as L +from lt_ep_train import EQBlock +os.chdir("/home/yurenh2/ept/ep_run") +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; eps=0.1; B=8; T=256; N=800; OUT="runs/t2fix_rho.log" +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +def measure(ckpt): + blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True + ck=torch.load(ckpt,map_location=dev) + with torch.no_grad(): + for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev)) + xin=blk.embed(idx).detach(); z=xin.clone(); ress=[] + for t in range(N): + z2=z+eps*blk.force(z,xin).detach() + r=(z2-z).norm().item()/(z.norm().item()+1e-9); ress.append(r); z=z2 + if not math.isfinite(r) or r>1e2: break + win=[ress[i] for i in range(len(ress)) if 1e-6<ress[i]<1e-1] or ress[-200:] + rats=[win[i+1]/win[i] for i in range(len(win)-1) if win[i]>0] + return (math.exp(sum(math.log(x) for x in rats)/len(rats)) if rats else float('nan')), ress[-1] +def alive(): return subprocess.run(["pgrep","-f","ckpt runs/ep_t2fix.pt"],capture_output=True).returncode==0 +def valof(step): + try: + for l in open("runs/ep_t2fix.log"): + if l.startswith("step"): + m=re.search(r"step\s+(\d+)/.*val CE ([\d.]+)", l) + if m and int(m.group(1))==step: return m.group(2) + except Exception: pass + return "?" 
+open(OUT,"a").write("# ep_t2fix(t2sel=160) rho-tracking | refs: BPTT=0.982 ; redx val3.4->0.988, val2.74->0.998(blew)\n") +seen=set(); fired=None; t0=time.time() +while fired is None and time.time()-t0<24*3600: + time.sleep(60) + for ck in sorted(glob.glob("runs/t2fix_traj/s*.pt"), key=lambda p:int(re.search(r's(\d+)',p).group(1))): + step=int(re.search(r's(\d+)',ck).group(1)) + if step in seen: continue + seen.add(step) + try: rho,fr=measure(ck) + except Exception: rho,fr=float('nan'),float('nan') + v=valof(step); line=f"step {step} val {v}: rho={rho:.4f} final_res={fr:.2e}" + open(OUT,"a").write(line+"\n"); print(line,flush=True) + try: + vf=float(v) + if vf<3.3 and isinstance(rho,float) and not math.isnan(rho): + if rho<0.992: fired=f"ep_t2fix val {v}: rho={rho:.4f} STAYS LOW (vs redx ~0.998 here) -> better gradient holds contraction (FIX WORKING)"; break + if rho>0.997: fired=f"ep_t2fix val {v}: rho={rho:.4f} DRIFTED to threshold like redx -> gradient does NOT defend rho (your objection holds)"; break + except Exception: pass + if not alive(): fired=f"ep_t2fix exited; trajectory in {OUT}"; break +print("=== EP_T2FIX RHO VERDICT ==="); print(fired or "24h timeout") +for l in open(OUT): print(l.rstrip()) diff --git a/ep_run/test_aselect_deepdive.py b/ep_run/test_aselect_deepdive.py new file mode 100644 index 0000000..54d2776 --- /dev/null +++ b/ep_run/test_aselect_deepdive.py @@ -0,0 +1,323 @@ +#!/usr/bin/env python3 +"""Standalone EP a-select performance/correctness probes. + +Does not modify trainer files. It can run on CPU if CUDA is unavailable; CUDA timing is only +attempted when torch.cuda.is_available(). 
+""" +import argparse, math, os, time, traceback +import torch +import torch.nn.functional as F +import torch.func as tf +import lt_ep_train as LT +import holo_ep as H + + +def cosine(a, b): + a = a.detach().reshape(-1).float(); b = b.detach().reshape(-1).float() + return float((a @ b) / (a.norm() * b.norm() + 1e-30)) + + +def max_rel(a, b): + return float((a-b).abs().max() / (b.abs().max() + 1e-12)) + + +def ln_jvp(x, dx, gamma, beta=None, eps=1e-5): + # Matches PyTorch layer_norm over the last dim (biased variance, affine gamma). + mu = x.mean(dim=-1, keepdim=True) + xc = x - mu + inv = torch.rsqrt((xc * xc).mean(dim=-1, keepdim=True) + eps) + xhat = xc * inv + dmu = dx.mean(dim=-1, keepdim=True) + # mean(xhat * dx), not mean(xhat * (dx-dmu)); mean(xhat)==0. + proj = (xhat * dx).mean(dim=-1, keepdim=True) + dy = inv * (dx - dmu - xhat * proj) + return dy * gamma + + +def ln_vjp(x, gy, gamma, eps=1e-5): + mu = x.mean(dim=-1, keepdim=True) + xc = x - mu + inv = torch.rsqrt((xc * xc).mean(dim=-1, keepdim=True) + eps) + xhat = xc * inv + g = gy * gamma + return inv * (g - g.mean(dim=-1, keepdim=True) - xhat * (g * xhat).mean(dim=-1, keepdim=True)) + + +def rms_jvp(x, dx, eps=1e-6): + inv = torch.rsqrt((x*x).mean(dim=-1, keepdim=True) + eps) + return inv * dx - x * (inv ** 3) * (x * dx).mean(dim=-1, keepdim=True) + + +def rms_vjp(x, gy, eps=1e-6): + # RMSNorm Jacobian is symmetric. + inv = torch.rsqrt((x*x).mean(dim=-1, keepdim=True) + eps) + return inv * gy - x * (inv ** 3) * (x * gy).mean(dim=-1, keepdim=True) + + +def gelu_tanh_deriv(x): + # derivative of F.gelu(x, approximate='tanh') + k = 0.7978845608028654 + a = 0.044715 + u = k * (x + a * x * x * x) + t = torch.tanh(u) + return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * k * (1.0 + 3.0 * a * x * x) + + +def manual_nc_jvp_vjp_thick(blk, z, vec): + """Explicit fp32 Jv and J^T vec for blk.nc_force(z), thick + qknorm path. + + This is deliberately written as plain ATen ops: no torch.func, no autograd. 
It assumes + blk.fnoise == 0 and attn_mode == 'thick'. It returns (Jv, JTv) for the same input vector. + """ + assert blk.attn_mode == 'thick' + assert getattr(blk, 'fnoise', 0.0) == 0.0 + B, T, C = z.shape + Hh, dh = blk.H, blk.dh + scale = 1.0 / math.sqrt(dh) + + # Base forward intermediates. + h1 = F.layer_norm(z, (C,), blk.ln1g, blk.ln1b) + h2 = F.layer_norm(z, (C,), blk.ln2g, blk.ln2b) + q0 = (h1 @ blk.WQ).view(B, T, Hh, dh).transpose(1, 2) + k0 = (h1 @ blk.WK).view(B, T, Hh, dh).transpose(1, 2) + vv = (h1 @ blk.WV).view(B, T, Hh, dh).transpose(1, 2) + if getattr(blk, 'qknorm', False): + q = q0 * torch.rsqrt(q0.pow(2).mean(-1, keepdim=True) + 1e-6) + k = k0 * torch.rsqrt(k0.pow(2).mean(-1, keepdim=True) + 1e-6) + else: + q, k = q0, k0 + logits = (q @ k.transpose(-2, -1)) * scale + p = torch.softmax(logits.masked_fill(~blk.cmask, float('-inf')), -1) + + u = h2 @ blk.fc + blk.fcb + gp = gelu_tanh_deriv(u) + + # JVP: attention branch. + dh1 = ln_jvp(z, vec, blk.ln1g) + dq0 = (dh1 @ blk.WQ).view(B, T, Hh, dh).transpose(1, 2) + dk0 = (dh1 @ blk.WK).view(B, T, Hh, dh).transpose(1, 2) + dvv = (dh1 @ blk.WV).view(B, T, Hh, dh).transpose(1, 2) + if getattr(blk, 'qknorm', False): + dq = rms_jvp(q0, dq0) + dk = rms_jvp(k0, dk0) + else: + dq, dk = dq0, dk0 + dlogits = (dq @ k.transpose(-2, -1) + q @ dk.transpose(-2, -1)) * scale + dp = p * (dlogits - (p * dlogits).sum(-1, keepdim=True)) + datt_heads = dp @ vv + p @ dvv + Jv_att = datt_heads.transpose(1, 2).reshape(B, T, C) @ blk.WO + + # JVP: FFN branch. + dh2 = ln_jvp(z, vec, blk.ln2g) + du = dh2 @ blk.fc + Jv_ff = (du * gp) @ blk.pj + Jv = Jv_att + Jv_ff + + # VJP: attention branch. 
+ gout = vec + gh_att_heads = (gout @ blk.WO.t()).view(B, T, Hh, dh).transpose(1, 2) + gp_soft = gh_att_heads @ vv.transpose(-2, -1) + gv_heads = p.transpose(-2, -1) @ gh_att_heads + glogits = p * (gp_soft - (gp_soft * p).sum(-1, keepdim=True)) + gq = (glogits @ k) * scale + gk = (glogits.transpose(-2, -1) @ q) * scale + if getattr(blk, 'qknorm', False): + gq0 = rms_vjp(q0, gq) + gk0 = rms_vjp(k0, gk) + else: + gq0, gk0 = gq, gk + gh1 = (gq0.transpose(1, 2).reshape(B, T, C) @ blk.WQ.t() + + gk0.transpose(1, 2).reshape(B, T, C) @ blk.WK.t() + + gv_heads.transpose(1, 2).reshape(B, T, C) @ blk.WV.t()) + JTv_att = ln_vjp(z, gh1, blk.ln1g) + + # VJP: FFN branch. + gg = gout @ blk.pj.t() + gu = gg * gp + gh2 = gu @ blk.fc.t() + JTv_ff = ln_vjp(z, gh2, blk.ln2g) + JTv = JTv_att + JTv_ff + return Jv, JTv + + +def make_block(args, device): + if args.tiny: + # Tiny block for compiler/frontend feasibility tests. + blk = LT.EQBlock(32, 4, 64, 32, c=1.0, attn_mode='thick') + B = args.B or 2 + T = 32 + y_vocab = LT.vocab + idx = torch.randint(0, y_vocab, (B, T), device=device) + y = torch.randint(0, y_vocab, (B, T), device=device) + blk.qknorm = True; blk.track = True; blk.navg = 1; blk.li_avg = 0 + return blk, idx, y + blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick') + blk.qknorm = True; blk.track = True; blk.navg = 1; blk.li_avg = 0 + ck = torch.load(args.ckpt, map_location=device) + with torch.no_grad(): + for p, s in zip(blk.allp, ck['allp']): + p.copy_(s.to(device)) + B = args.B or 1 + idx, y = LT.get_batch('train', B, 256) + return blk, idx, y + + +def sync(device): + if device.type == 'cuda': + torch.cuda.synchronize(device) + + +def time_call(fn, device, repeat=3): + # one warmup + out = fn(); sync(device) + ts=[] + for _ in range(repeat): + t0=time.time(); out=fn(); sync(device); ts.append(time.time()-t0) + return min(ts), out + + +def make_tf_step(blk, zs, xin, y, r, eps): + B = zs.size(0) + X2 = torch.cat([xin, xin], 0) + y2 = torch.cat([y, y], 0) + sg = 
torch.cat([torch.full((B,1,1), r, device=zs.device), torch.full((B,1,1), -r, device=zs.device)], 0) + def step(Z): + zbar = 0.5 * (Z[:B] + Z[B:]) + zb2 = torch.cat([zbar, zbar], 0) + f = H.rforce(blk, Z, X2) - sg * H.rgrad_ce(blk, Z, y2, denom=y.numel()) + v = (Z - zb2).contiguous() + fnc = lambda zz: blk.nc_force(zz) + _, Jv = tf.jvp(fnc, (zb2,), (v,)) + JTv = tf.vjp(fnc, zb2)[1](v)[0] + return Z + eps * (f - (Jv - JTv)) + return step + + +def make_manual_step(blk, zs, xin, y, r, eps): + B = zs.size(0) + X2 = torch.cat([xin, xin], 0) + y2 = torch.cat([y, y], 0) + sg = torch.cat([torch.full((B,1,1), r, device=zs.device), torch.full((B,1,1), -r, device=zs.device)], 0) + def step(Z): + zbar = 0.5 * (Z[:B] + Z[B:]) + zb2 = torch.cat([zbar, zbar], 0) + f = H.rforce(blk, Z, X2) - sg * H.rgrad_ce(blk, Z, y2, denom=y.numel()) + v = (Z - zb2).contiguous() + Jv, JTv = manual_nc_jvp_vjp_thick(blk, zb2, v) + return Z + eps * (f - (Jv - JTv)) + return step + + +def run_loop_from_step(step, zs, r, T2, K=10): + B = zs.size(0) + Z = torch.cat([zs, zs], 0) + a_prev = a_best = None + inc_min = float('inf'); t_best = 0 + for t in range(1, T2+1): + Z = step(Z) + if t % K == 0 or t == T2: + a_t = (Z[B:] - Z[:B]) / (2*r) + if a_prev is not None: + inc = (a_t - a_prev).norm().item() + if inc < inc_min: + inc_min, a_best, t_best = inc, a_t, t + a_prev = a_t + if a_best is None: + a_best = a_prev; t_best = T2 + return a_best.detach(), t_best + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--ckpt', default='/home/yurenh2/ept/ep_run/runs/ep_resreg_warm.pt') + ap.add_argument('--tiny', action='store_true') + ap.add_argument('--B', type=int, default=None) + ap.add_argument('--T1', type=int, default=2) + ap.add_argument('--T2', type=int, default=2) + ap.add_argument('--r', type=float, default=0.02) + ap.add_argument('--eps', type=float, default=0.1) + ap.add_argument('--compile', action='store_true') + ap.add_argument('--cuda-graph', action='store_true') + args = 
ap.parse_args() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print('torch', torch.__version__, 'cuda_runtime', torch.version.cuda, 'cuda_available', torch.cuda.is_available(), 'device', device, flush=True) + if device.type == 'cuda': + print(torch.cuda.get_device_name(device), flush=True) + + torch.manual_seed(0) + blk, idx, y = make_block(args, device) + print('block', blk.C, blk.H, blk.T, 'B', idx.size(0), 'qknorm', getattr(blk,'qknorm',False), flush=True) + xin = blk.embed(idx).detach() + zs = LT.relax(blk, xin.clone(), xin, args.T1, args.eps) + print('zs norm', float(zs.norm()), flush=True) + + B=zs.size(0) + Z0 = torch.cat([zs, zs], 0) + zbar = 0.5*(Z0[:B]+Z0[B:]); zb2 = torch.cat([zbar,zbar],0) + # Need a nonzero v for the one-off J test. + vtest = torch.randn_like(zb2) * 1e-3 + _, Jv_ref = tf.jvp(lambda zz: blk.nc_force(zz), (zb2,), (vtest,)) + JTv_ref = tf.vjp(lambda zz: blk.nc_force(zz), zb2)[1](vtest)[0] + Jv_m, JTv_m = manual_nc_jvp_vjp_thick(blk, zb2, vtest) + print('manual_Jv cos', cosine(Jv_ref, Jv_m), 'maxrel', max_rel(Jv_m, Jv_ref), flush=True) + print('manual_JTv cos', cosine(JTv_ref, JTv_m), 'maxrel', max_rel(JTv_m, JTv_ref), flush=True) + + tf_step = make_tf_step(blk, zs, xin, y, args.r, args.eps) + man_step = make_manual_step(blk, zs, xin, y, args.r, args.eps) + with torch.no_grad(): + Z_tf = tf_step(Z0) + Z_man = man_step(Z0) + print('one_step manual vs tf cos', cosine(Z_tf, Z_man), 'max_abs', float((Z_tf-Z_man).abs().max()), 'maxrel', max_rel(Z_man, Z_tf), flush=True) + + # Compare a-select outputs for small T2. Full checkpoint on CPU is intentionally small T2. 
+ with torch.no_grad(): + t0=time.time(); a_base,tb=H.holo_a_track(blk,zs,xin,y,args.r,args.T2,args.eps,K=max(1,args.T2)); sync(device); dt=time.time()-t0 + t0=time.time(); a_man,tm=run_loop_from_step(man_step,zs,args.r,args.T2,K=max(1,args.T2)); sync(device); dtm=time.time()-t0 + print('baseline_holo_a_track T2', args.T2, 't_best', tb, 'sec', dt, flush=True) + print('manual_loop T2', args.T2, 't_best', tm, 'sec', dtm, 'cos(a)', cosine(a_base,a_man), 'maxrel', max_rel(a_man,a_base), flush=True) + + if args.compile: + for name, step in [('tf_step_body', tf_step), ('manual_step_body', man_step)]: + try: + print('compile start', name, flush=True) + cstep = torch.compile(step, fullgraph=True, mode='reduce-overhead') + # compile on first invocation + with torch.no_grad(): + Zc = cstep(Z0) + sync(device) + print('compile ok', name, 'cos one_step', cosine(Z_tf if name=='tf_step_body' else Z_man, Zc), 'max_abs', float(((Z_tf if name=='tf_step_body' else Z_man)-Zc).abs().max()), flush=True) + if device.type == 'cuda': + t_e,_=time_call(lambda: step(Z0), device) + t_c,_=time_call(lambda: cstep(Z0), device) + print('timing', name, 'eager_ms', t_e*1000, 'compiled_ms', t_c*1000, 'speedup', t_e/t_c, flush=True) + except Exception as e: + print('compile FAIL', name, type(e).__name__, str(e)[:1000], flush=True) + traceback.print_exc(limit=8) + + if args.cuda_graph: + if device.type != 'cuda': + print('cuda graph skipped: torch.cuda.is_available() is False', flush=True) + else: + try: + static_Z = Z0.clone() + # warmup on a side stream to settle allocations + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + static_out = tf_step(static_Z) + torch.cuda.current_stream().wait_stream(s) + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + static_out = tf_step(static_Z) + g.replay(); sync(device) + print('cuda graph capture ok tf_step_body', cosine(tf_step(Z0), static_out), flush=True) + except Exception as e: 
+ print('cuda graph FAIL tf_step_body', type(e).__name__, str(e)[:1000], flush=True) + traceback.print_exc(limit=8) + +if __name__ == '__main__': + main() + +# NOTE: extra decomposed-layernorm helper kept below for reference; not used by main() above. diff --git a/ep_run/test_compile_aselect.py b/ep_run/test_compile_aselect.py new file mode 100644 index 0000000..a2601e4 --- /dev/null +++ b/ep_run/test_compile_aselect.py @@ -0,0 +1,23 @@ +import torch, time, math +import lt_ep_train as LT, holo_ep as H +torch.manual_seed(0) +blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick'); blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0 +ck=torch.load('runs/ep_resreg_warm.pt',map_location='cuda') +with torch.no_grad(): + for p,s in zip(blk.allp,ck['allp']): p.copy_(s.to('cuda')) +idx,y=LT.get_batch('train',24,256); xin=blk.embed(idx).detach() +zs=LT.relax(blk,xin.clone(),xin,150,0.1) +def T(fn): + fn(); torch.cuda.synchronize(); t0=time.time(); a,_=fn(); torch.cuda.synchronize(); return round((time.time()-t0)*1000),a +ms0,a0=T(lambda: H.holo_a_track(blk,zs,xin,y,0.02,80,0.1)) +print(f"uncompiled a-select (t2sel=80): {ms0} ms",flush=True) +orig=blk.nc_force +try: + blk.nc_force=torch.compile(blk.nc_force, mode='reduce-overhead') + ms1,a1=T(lambda: H.holo_a_track(blk,zs,xin,y,0.02,80,0.1)) + cos=float((a0.flatten()@a1.flatten())/(a0.norm()*a1.norm()+1e-20)) + print(f"compiled nc_force: {ms1} ms cos(a vs uncompiled)={cos:.4f}",flush=True) +except Exception as e: + print("compile FAIL:", type(e).__name__, str(e)[:120],flush=True) +blk.nc_force=orig +print("DONE",flush=True) diff --git a/ep_run/track_probe.py b/ep_run/track_probe.py new file mode 100644 index 0000000..2846745 --- /dev/null +++ b/ep_run/track_probe.py @@ -0,0 +1,77 @@ +import torch, math +import lt_ep_train as M +from pathlib import Path +import pickle +M.DD = Path('/tmp/lt_ep/data/tinystories') +M.vocab = pickle.load(open(M.DD/'meta.pkl','rb'))['vocab_size'] +from lt_ep_train import EQBlock, get_batch, 
bptt_step, relax +from holo_ep import holo_a_select2, rforce, rgrad_ce +import torch.func as tf + +def holo_a_track(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0): + """Common-mode-tracking AEP: linearize the antisymmetric correction at the instantaneous + common mode of the two phases — exact transposed differential dynamics, loose-tolerant, + no compounding linearization error.""" + B = zs.size(0) + Z = torch.cat([zs, zs], 0) + X2 = torch.cat([xin, xin], 0) + y2 = torch.cat([y, y], 0) + sg = torch.cat([torch.full((B,1,1), r, device=zs.device), torch.full((B,1,1), -r, device=zs.device)], 0) + fnc = lambda zz: blk.nc_force(zz) + a_prev = a_best = None + inc_min, t_best = float('inf'), 0 + for t in range(1, T2max + 1): + with torch.no_grad(): + zbar = 0.5 * (Z[:B] + Z[B:]) + zb2 = torch.cat([zbar, zbar], 0) + f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel()) + v = (Z - zb2).contiguous() + _, Jv = tf.jvp(fnc, (zb2,), (v,)) + JTv = tf.vjp(fnc, zb2)[1](v)[0] + Z = Z + eps * (f - (Jv - JTv)) + if t % K == 0 or t == T2max: + a_t = (Z[B:] - Z[:B]) / (2 * r) + if not torch.isfinite(a_t).all(): + break + if a_prev is not None: + inc = (a_t - a_prev).norm().item() + if inc < inc_min: + inc_min, a_best, t_best = inc, a_t, t + elif inc > exit_mult * inc_min and t >= 3 * K: + break + a_prev = a_t + if a_best is None: + a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r) + t_best = T2max + return a_best.detach(), t_best + +if __name__ == '__main__': + torch.manual_seed(0) + B, T, C, H = 8, 256, 256, 8 + blk = EQBlock(C, H, 256, T, attn_mode='thick') + ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt') + for p, w in zip(blk.allp, ck['allp']): + with torch.no_grad(): + p.copy_(w.to('cuda')) + idx, y = get_batch('train', B, T) + xin = blk.embed(idx).detach() + ref150 = bptt_step(blk, idx, y, 150, 0.1) + def flat(g): + keep = [p for p in blk.block if g.get(id(p)) is not None] + return torch.cat([g[id(p)].reshape(-1) for p in keep]) + v150 = 
flat(ref150) + def gfrom(zs, a_): + with torch.enable_grad(): + x2 = blk.embed(idx) + f = blk.force(zs.detach(), x2, cg=True) + return {id(p): g for p, g in zip(blk.block, torch.autograd.grad((a_*f).sum(), blk.block, allow_unused=True))} + z = xin.clone(); prev = 0 + for T1 in (75, 150, 600): + z = relax(blk, z, xin, T1 - prev, 0.1); prev = T1 + res = (relax(blk, z, xin, 1, 0.1) - z).norm().item() / z.norm().item() + for name, fn in (('frozen', holo_a_select2), ('track', holo_a_track)): + for T2m in (120, 300): + a, tb = fn(blk, z, xin, y, 0.02, T2m, 0.1) + va = flat(gfrom(z, a)) + c = (va @ v150 / (va.norm() * v150.norm() + 1e-12)).item() + print(f"T1={T1:>4} res={res:.1e} {name:>6} T2max={T2m:>3}: t_best={tb:>3} cos_vs150={c:.3f}", flush=True) diff --git a/ep_run/train.py b/ep_run/train.py new file mode 100644 index 0000000..90c5361 --- /dev/null +++ b/ep_run/train.py @@ -0,0 +1,183 @@ +"""Train a tiny char-level GPT on tinyshakespeare with softmax or sigmoid attn. + +Logs train/val loss to runs/<run_name>/log.jsonl (one JSON per line). +No checkpoints are saved (disk-conscious). 
+""" +import argparse +import json +import math +import os +import pickle +import time +from pathlib import Path + +import numpy as np +import torch + +from model import GPT, GPTConfig + + +def get_batch(split: str, data_dir: Path, block_size: int, batch_size: int, device: str): + fn = "train.bin" if split == "train" else "val.bin" + data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r") + ix = torch.randint(len(data) - block_size - 1, (batch_size,)) + x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix]) + return x.to(device, non_blocking=True), y.to(device, non_blocking=True) + + +@torch.no_grad() +def estimate_loss(model, data_dir, block_size, batch_size, device, eval_iters): + out = {} + model.eval() + for split in ("train", "val"): + losses = torch.zeros(eval_iters) + for k in range(eval_iters): + X, Y = get_batch(split, data_dir, block_size, batch_size, device) + _, loss = model(X, Y) + losses[k] = loss.item() + out[split] = losses.mean().item() + model.train() + return out + + +def lr_schedule(it, warmup_iters, lr_decay_iters, max_lr, min_lr): + if it < warmup_iters: + return max_lr * (it + 1) / (warmup_iters + 1) + if it > lr_decay_iters: + return min_lr + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) + return min_lr + coeff * (max_lr - min_lr) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--run_name", type=str, required=True) + p.add_argument("--attn_mode", type=str, default="softmax", choices=["softmax", "sigmoid"]) + p.add_argument("--sigmoid_bias_mode", type=str, default="neg_log_n", + choices=["zero", "neg_log_n", "learned"]) + p.add_argument("--seed", type=int, default=1337) + p.add_argument("--data_dir", type=str, default="data/shakespeare_char") + p.add_argument("--out_dir", type=str, default="runs") + 
p.add_argument("--block_size", type=int, default=256) + p.add_argument("--batch_size", type=int, default=64) + p.add_argument("--n_layer", type=int, default=6) + p.add_argument("--n_head", type=int, default=6) + p.add_argument("--n_embd", type=int, default=384) + p.add_argument("--dropout", type=float, default=0.2) + p.add_argument("--max_iters", type=int, default=5000) + p.add_argument("--warmup_iters", type=int, default=100) + p.add_argument("--lr_decay_iters", type=int, default=5000) + p.add_argument("--max_lr", type=float, default=1e-3) + p.add_argument("--min_lr", type=float, default=1e-4) + p.add_argument("--weight_decay", type=float, default=0.1) + p.add_argument("--beta1", type=float, default=0.9) + p.add_argument("--beta2", type=float, default=0.99) + p.add_argument("--grad_clip", type=float, default=1.0) + p.add_argument("--eval_interval", type=int, default=250) + p.add_argument("--eval_iters", type=int, default=100) + p.add_argument("--log_interval", type=int, default=50) + p.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"]) + args = p.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + device = "cuda" if torch.cuda.is_available() else "cpu" + data_dir = Path(args.data_dir) + with open(data_dir / "meta.pkl", "rb") as f: + meta = pickle.load(f) + vocab_size = meta["vocab_size"] + + run_dir = Path(args.out_dir) / args.run_name + run_dir.mkdir(parents=True, exist_ok=True) + log_path = run_dir / "log.jsonl" + cfg_path = run_dir / "config.json" + with open(cfg_path, "w") as f: + json.dump(vars(args) | {"vocab_size": vocab_size}, f, indent=2) + + cfg = GPTConfig( + block_size=args.block_size, + vocab_size=vocab_size, + n_layer=args.n_layer, + n_head=args.n_head, + n_embd=args.n_embd, + dropout=args.dropout, + attn_mode=args.attn_mode, + sigmoid_bias_mode=args.sigmoid_bias_mode, + ) + model = GPT(cfg).to(device) + n_params = model.num_params() + + # AdamW with 
weight-decay-free for 1D params (ln, embeddings, biases, sig_bias) + decay_params, nodecay_params = [], [] + for n, pr in model.named_parameters(): + if not pr.requires_grad: + continue + if pr.dim() >= 2: + decay_params.append(pr) + else: + nodecay_params.append(pr) + optimizer = torch.optim.AdamW( + [ + {"params": decay_params, "weight_decay": args.weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, + ], + lr=args.max_lr, + betas=(args.beta1, args.beta2), + fused=(device == "cuda"), + ) + + dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} + amp_dtype = dtype_map[args.dtype] + scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == "float16")) + + t0 = time.time() + + def log(record: dict): + record["t"] = time.time() - t0 + with open(log_path, "a") as f: + f.write(json.dumps(record) + "\n") + + log({"event": "start", "params": n_params, "config": vars(args) | {"vocab_size": vocab_size}}) + print(f"[{args.run_name}] params={n_params/1e6:.2f}M device={device} dtype={args.dtype}") + + model.train() + for it in range(args.max_iters + 1): + lr = lr_schedule(it, args.warmup_iters, args.lr_decay_iters, args.max_lr, args.min_lr) + for g in optimizer.param_groups: + g["lr"] = lr + + if it % args.eval_interval == 0 or it == args.max_iters: + losses = estimate_loss(model, data_dir, args.block_size, args.batch_size, device, args.eval_iters) + log({"event": "eval", "iter": it, "train_loss": losses["train"], "val_loss": losses["val"], "lr": lr}) + print(f"[{args.run_name}] iter {it:5d} train {losses['train']:.4f} val {losses['val']:.4f} lr {lr:.4g}") + + if it == args.max_iters: + break + + X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device) + with torch.amp.autocast(device_type="cuda", dtype=amp_dtype, enabled=(device == "cuda")): + _, loss = model(X, Y) + optimizer.zero_grad(set_to_none=True) + if args.dtype == "float16": + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + 
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + optimizer.step() + + if it % args.log_interval == 0: + log({"event": "step", "iter": it, "train_loss": loss.item(), "lr": lr}) + + log({"event": "done", "iter": args.max_iters}) + print(f"[{args.run_name}] done in {time.time()-t0:.1f}s") + + +if __name__ == "__main__": + main() diff --git a/ep_run/train_local.py b/ep_run/train_local.py new file mode 100644 index 0000000..2519cf9 --- /dev/null +++ b/ep_run/train_local.py @@ -0,0 +1,300 @@ +"""Local-learning sweep training on Shakespeare char LM with sigmoid transformer. + +Supported methods (via --method): + bp standard backprop (reference baseline) + fa Feedback Alignment: per-LocalLinear random fixed B replaces W.T in backward + sign_sym Sign-symmetric: per-LocalLinear sign(W) replaces W.T in backward + dfa Direct Feedback Alignment: each LocalLinear's .grad is overwritten with + (B_dfa @ e_L) outer (cached input). Embeddings/LN retain BP gradients. + +Reuses data/shakespeare_char/*.bin from Phase 1. 
+""" +import argparse +import json +import math +import os +import pickle +import time +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F + +from local_layers import LocalLinear, apply_dfa_update, initialize_dfa_targets +from model_local import LocalGPT, LocalGPTConfig + + +def get_batch(split, data_dir, block_size, batch_size, device): + fn = "train.bin" if split == "train" else "val.bin" + data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r") + ix = torch.randint(len(data) - block_size - 1, (batch_size,)) + x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix]) + return x.to(device, non_blocking=True), y.to(device, non_blocking=True) + + +@torch.no_grad() +def estimate_loss(model, data_dir, block_size, batch_size, device, eval_iters): + out = {} + model.eval() + for split in ("train", "val"): + losses = torch.zeros(eval_iters) + for k in range(eval_iters): + X, Y = get_batch(split, data_dir, block_size, batch_size, device) + _, loss = model(X, Y) + losses[k] = loss.item() + out[split] = losses.mean().item() + model.train() + return out + + +def lr_schedule(it, warmup, decay_iters, max_lr, min_lr): + if it < warmup: + return max_lr * (it + 1) / (warmup + 1) + if it > decay_iters: + return min_lr + coeff = 0.5 * (1.0 + math.cos(math.pi * (it - warmup) / max(1, decay_iters - warmup))) + return min_lr + coeff * (max_lr - min_lr) + + +def compute_analytical_e_L(logits, targets, vocab_size): + """Closed-form gradient of mean cross-entropy w.r.t. logits. + + logits: (B, T, V), targets: (B, T). Returns e_L shape (B, T, V). 
+ For CE with reduction='mean' over (B*T): dL/dlogits = (softmax(logits) - onehot(y)) / (B*T) + """ + probs = F.softmax(logits, dim=-1) + onehot = F.one_hot(targets, num_classes=vocab_size).float() + N = targets.numel() + return (probs - onehot) / N + + +def compute_alignment_diagnostics(model, batch_x, batch_y, method, vocab_size): + """Per-LocalLinear gradient cosine to BP + (FA only) ‖B − W‖_F / ‖W‖_F. + + Cosine signs the "functional alignment": 1 = method grad matches BP direction, + 0 = orthogonal, negative = wrong direction. Per-layer dict keyed by module name. + + Two forward-backward passes: first in the method's own mode (grabs method grads), + second with all LocalLinear temporarily switched to 'bp' (grabs BP grads). + Restores method before returning. Runs in eval mode to disable dropout so both + passes see identical activations (otherwise BP would show cosine < 1 vs itself). + """ + out = {"grad_cos": {}, "fa_offset": {}} + was_training = model.training + model.eval() + + # --- Pass 1: method's backward --- + model.zero_grad(set_to_none=True) + logits, loss = model(batch_x, batch_y) + loss.backward() + if method == "dfa": + with torch.no_grad(): + e_L = compute_analytical_e_L(logits.detach(), batch_y, vocab_size) + apply_dfa_update(model, e_L) + + method_grads = {} + for name, m in model.named_modules(): + if isinstance(m, LocalLinear) and m.weight.grad is not None: + method_grads[name] = m.weight.grad.detach().clone() + + # FA-specific metric: distance between fixed B and current W + if method == "fa": + for name, m in model.named_modules(): + if isinstance(m, LocalLinear) and m.method == "fa": + diff = m.B - m.weight + out["fa_offset"][name] = (diff.norm() / (m.weight.norm() + 1e-9)).item() + + # --- Pass 2: BP backward via temporary method switch --- + method_backup = {} + for m in model.modules(): + if isinstance(m, LocalLinear): + method_backup[id(m)] = m.method + m.method = "bp" + + model.zero_grad(set_to_none=True) + _, loss_bp = model(batch_x, 
batch_y) + loss_bp.backward() + + for name, m in model.named_modules(): + if isinstance(m, LocalLinear) and name in method_grads: + g_bp = m.weight.grad.detach() + g_method = method_grads[name] + cos = F.cosine_similarity( + g_bp.flatten().unsqueeze(0), + g_method.flatten().unsqueeze(0), + ).item() + out["grad_cos"][name] = cos + + # Restore method + for m in model.modules(): + if isinstance(m, LocalLinear) and id(m) in method_backup: + m.method = method_backup[id(m)] + + model.zero_grad(set_to_none=True) + if was_training: + model.train() + return out + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--method", required=True, choices=["bp", "fa", "dfa", "sign_sym"]) + p.add_argument("--run_name", type=str, required=True) + p.add_argument("--seed", type=int, default=1337) + p.add_argument("--data_dir", type=str, default="data/shakespeare_char") + p.add_argument("--out_dir", type=str, default="runs_local") + p.add_argument("--block_size", type=int, default=256) + p.add_argument("--batch_size", type=int, default=64) + p.add_argument("--n_layer", type=int, default=6) + p.add_argument("--n_head", type=int, default=6) + p.add_argument("--n_embd", type=int, default=384) + p.add_argument("--dropout", type=float, default=0.2) + p.add_argument("--max_iters", type=int, default=5000) + p.add_argument("--warmup_iters", type=int, default=100) + p.add_argument("--lr_decay_iters", type=int, default=5000) + p.add_argument("--max_lr", type=float, default=1e-3) + p.add_argument("--min_lr", type=float, default=1e-4) + p.add_argument("--weight_decay", type=float, default=0.1) + p.add_argument("--beta1", type=float, default=0.9) + p.add_argument("--beta2", type=float, default=0.99) + p.add_argument("--grad_clip", type=float, default=1.0) + p.add_argument("--eval_interval", type=int, default=250) + p.add_argument("--eval_iters", type=int, default=100) + p.add_argument("--log_interval", type=int, default=50) + p.add_argument("--attn_mode", type=str, default="sigmoid", 
choices=["softmax", "sigmoid"]) + p.add_argument("--sigmoid_bias_mode", type=str, default="neg_log_n") + p.add_argument("--ste_sigmoid", action="store_true", help="STE on sigmoid attention (skip A(1-A) derivative)") + p.add_argument("--ste_gelu", action="store_true", help="STE on GELU (skip gelu' derivative)") + p.add_argument("--ln_mode", type=str, default="bp", choices=["bp", "ste", "center_scale", "projected"], + help="LN backward: bp=standard, ste=identity, center_scale=mean-center+1/σ, projected=full surrogate") + p.add_argument("--freeze_emb", action="store_true", help="Freeze token + position embeddings") + p.add_argument("--fuse_attn_local", action="store_true", help="Fuse softmax+A@V with local backward (no lateral sum)") + args = p.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + data_dir = Path(args.data_dir) + with open(data_dir / "meta.pkl", "rb") as f: + meta = pickle.load(f) + vocab_size = meta["vocab_size"] + + run_dir = Path(args.out_dir) / args.run_name + run_dir.mkdir(parents=True, exist_ok=True) + log_path = run_dir / "log.jsonl" + log_path.write_text("") + with open(run_dir / "config.json", "w") as f: + json.dump(vars(args) | {"vocab_size": vocab_size}, f, indent=2) + + cfg = LocalGPTConfig( + block_size=args.block_size, + vocab_size=vocab_size, + n_layer=args.n_layer, + n_head=args.n_head, + n_embd=args.n_embd, + dropout=args.dropout, + attn_mode=args.attn_mode, + sigmoid_bias_mode=args.sigmoid_bias_mode, + method=args.method, + ste_sigmoid=args.ste_sigmoid, + ste_gelu=args.ste_gelu, + ln_mode=args.ln_mode, + freeze_emb=args.freeze_emb, + fuse_attn_local=args.fuse_attn_local, + ) + model = LocalGPT(cfg).to(device) + n_params = model.num_params() + + if args.method == "dfa": + initialize_dfa_targets(model, vocab_size) + + # Build optimizer. For all methods, gather params with weight decay convention. 
+ decay_params, nodecay_params = [], [] + for n, pr in model.named_parameters(): + if not pr.requires_grad: + continue + if pr.dim() >= 2: + decay_params.append(pr) + else: + nodecay_params.append(pr) + optimizer = torch.optim.AdamW( + [ + {"params": decay_params, "weight_decay": args.weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, + ], + lr=args.max_lr, + betas=(args.beta1, args.beta2), + fused=(device == "cuda"), + ) + + t0 = time.time() + + def log(rec): + rec["t"] = time.time() - t0 + with open(log_path, "a") as f: + f.write(json.dumps(rec) + "\n") + + n_localinear = sum(1 for m in model.modules() if isinstance(m, LocalLinear)) + log({ + "event": "start", "method": args.method, "params": n_params, + "n_localinear": n_localinear, "vocab_size": vocab_size, + "config": vars(args), + }) + print(f"[{args.run_name}] method={args.method} params={n_params/1e6:.2f}M LocalLinear={n_localinear}") + + model.train() + for it in range(args.max_iters + 1): + lr = lr_schedule(it, args.warmup_iters, args.lr_decay_iters, args.max_lr, args.min_lr) + for g in optimizer.param_groups: + g["lr"] = lr + + if it % args.eval_interval == 0 or it == args.max_iters: + losses = estimate_loss(model, data_dir, args.block_size, args.batch_size, device, args.eval_iters) + # Alignment diagnostic on a fresh training batch + X_diag, Y_diag = get_batch("train", data_dir, args.block_size, args.batch_size, device) + align = compute_alignment_diagnostics(model, X_diag, Y_diag, args.method, vocab_size) + log({ + "event": "eval", "iter": it, + "train_loss": losses["train"], "val_loss": losses["val"], "lr": lr, + "grad_cos": align["grad_cos"], "fa_offset": align["fa_offset"], + }) + # Summary for print + if align["grad_cos"]: + cos_vals = list(align["grad_cos"].values()) + cos_mean = sum(cos_vals) / len(cos_vals) + cos_min = min(cos_vals) + print(f"[{args.run_name}] iter {it:5d} train {losses['train']:.4f} val {losses['val']:.4f} " + f"grad_cos μ={cos_mean:.3f} min={cos_min:.3f} lr 
{lr:.4g}") + else: + print(f"[{args.run_name}] iter {it:5d} train {losses['train']:.4f} val {losses['val']:.4f} lr {lr:.4g}") + + if it == args.max_iters: + break + + X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device) + logits, loss = model(X, Y) + + optimizer.zero_grad(set_to_none=True) + loss.backward() + + if args.method == "dfa": + # Overwrite LocalLinear .grad with DFA-computed updates (using cached inputs from forward) + with torch.no_grad(): + e_L = compute_analytical_e_L(logits.detach(), Y, vocab_size) + apply_dfa_update(model, e_L) + + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + optimizer.step() + + if it % args.log_interval == 0: + log({"event": "step", "iter": it, "train_loss": loss.item(), "lr": lr}) + + log({"event": "done", "iter": args.max_iters}) + print(f"[{args.run_name}] done in {time.time()-t0:.1f}s") + + +if __name__ == "__main__": + main() diff --git a/ep_run/train_local_ce.py b/ep_run/train_local_ce.py new file mode 100644 index 0000000..b8f790a --- /dev/null +++ b/ep_run/train_local_ce.py @@ -0,0 +1,580 @@ +"""Local CE exit training — each block gets a vocab-space CE loss via shared unembedding. + +Each block l computes: + z_l = W_U @ T_l(h_l) (local logits via shared unembedding + optional translator) + L_l = λ_gt * CE(z_l, y) + λ_kd * τ² * KL(sg(p_L^τ) || p_l^τ) + +Forward weights updated per-block via local CE gradient (intra-block only). +No inter-block chain rule. Fused attention backward within each block. + +This replaces the hidden-space MSE target-matching that failed at scale. 
+""" +import argparse +import json +import math +import pickle +import time +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model_local import LocalGPTConfig, LocalBlock, LocalLinear, _make_ln +from factorized_exit import FactorizedExitHead, ExactParallelExitHead +from local_layers import initialize_dfa_block_targets, apply_dfa_block_update + + +def get_batch(split, data_dir, block_size, batch_size, device, n_pred=1): + fn = "train.bin" if split == "train" else "val.bin" + data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r") + ix = torch.randint(len(data) - block_size - n_pred, (batch_size,)) + x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]) + if n_pred == 1: + y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix]) + return x.to(device, non_blocking=True), y.to(device, non_blocking=True) + # n_pred > 1: targets shape (B, T, n_pred). Y[..., k-1] = next-k target. + y_multi = torch.stack([ + torch.stack([ + torch.from_numpy(data[i + k : i + k + block_size].astype(np.int64)) + for k in range(1, n_pred + 1) + ], dim=-1) + for i in ix + ]) + return x.to(device, non_blocking=True), y_multi.to(device, non_blocking=True) + + +class LowRankTranslator(nn.Module): + """T_l(h) = h + A @ B @ h + b. 
Low-rank affine residual translator.""" + def __init__(self, d_model, rank=32): + super().__init__() + self.A = nn.Parameter(torch.zeros(d_model, rank)) + self.B = nn.Parameter(torch.zeros(rank, d_model)) + self.bias = nn.Parameter(torch.zeros(d_model)) + nn.init.normal_(self.A, std=0.01) + nn.init.normal_(self.B, std=0.01) + + def forward(self, h): + return h + h @ self.B.T @ self.A.T + self.bias + + +class LocalCETransformer(nn.Module): + """Transformer with per-block local CE exits via shared unembedding.""" + + def __init__(self, config: LocalGPTConfig, translator_rank: int = 0, n_pred_tokens: int = 1, + shared_blocks: bool = False): + super().__init__() + self.config = config + self.n_pred_tokens = n_pred_tokens + self.shared_blocks = shared_blocks + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + self.pos_emb = nn.Embedding(config.block_size, config.n_embd) + self.drop = nn.Dropout(config.dropout) + if shared_blocks: + # Universal Transformer: one block applied n_layer times. + # All entries point to the SAME module — gradient accumulates from all "depths". + shared = LocalBlock(config) + self.blocks = nn.ModuleList([shared for _ in range(config.n_layer)]) + else: + self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)]) + self.ln_f = _make_ln(config) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Auxiliary unembedding heads for next-2..next-N prediction (multi-token training). + # Used only as gradient-source heads at training time; inference still uses self.head. 
+ if n_pred_tokens > 1: + self.aux_heads = nn.ModuleList([ + nn.Linear(config.n_embd, config.vocab_size, bias=False) + for _ in range(n_pred_tokens - 1) + ]) + else: + self.aux_heads = None + + # Per-block translators (logit lens = rank 0 = identity) + if translator_rank > 0: + self.translators = nn.ModuleList([ + LowRankTranslator(config.n_embd, translator_rank) + for _ in range(config.n_layer) + ]) + else: + self.translators = None + + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith("o_proj.weight") or pn.endswith("mlp.proj.weight"): + nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) + + def _init_weights(self, m): + if isinstance(m, (nn.Linear, LocalLinear)): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + if getattr(m, "bias", None) is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + + def forward_activations(self, idx): + B, T = idx.shape + pos = torch.arange(T, device=idx.device) + h = self.drop(self.tok_emb(idx) + self.pos_emb(pos)) + activations = [h] + for block in self.blocks: + h = block(h) + activations.append(h) + return activations + + def local_logits(self, h, layer_idx): + """h → local logits via optional translator + shared unembedding.""" + if self.translators is not None: + h = self.translators[layer_idx](h) + return F.linear(h, self.head.weight) # shared W_U, no separate head + + def final_logits(self, h): + return self.head(self.ln_f(h)) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--run_name", type=str, required=True) + p.add_argument("--seed", type=int, default=1337) + p.add_argument("--data_dir", type=str, default="data/shakespeare_char") + p.add_argument("--out_dir", type=str, default="runs_local") + p.add_argument("--block_size", type=int, default=256) + p.add_argument("--batch_size", type=int, default=64) + p.add_argument("--n_layer", type=int, default=6) + p.add_argument("--n_head", 
type=int, default=6) + p.add_argument("--n_embd", type=int, default=384) + p.add_argument("--dropout", type=float, default=0.2) + p.add_argument("--max_iters", type=int, default=5000) + p.add_argument("--warmup_iters", type=int, default=100) + p.add_argument("--max_lr", type=float, default=1e-3) + p.add_argument("--min_lr", type=float, default=1e-4) + p.add_argument("--eval_interval", type=int, default=250) + p.add_argument("--eval_iters", type=int, default=100) + p.add_argument("--log_interval", type=int, default=50) + p.add_argument("--attn_mode", type=str, default="softmax") + p.add_argument("--translator_rank", type=int, default=0, help="0=identity (logit lens), >0=low-rank affine") + p.add_argument("--kd_weight", type=float, default=1.0, help="weight for KL distillation from final layer") + p.add_argument("--kd_temp", type=float, default=2.0, help="temperature for KD") + p.add_argument("--gt_weight", type=float, default=1.0, help="weight for ground-truth CE") + p.add_argument("--nbr_weight", type=float, default=0.0, help="weight for neighbor KL (sg(p_{l+1}) || p_l)") + p.add_argument("--layer_weighting", type=str, default="uniform", choices=["uniform", "linear"], + help="per-layer loss weight: uniform=all 1.0, linear=l/L") + p.add_argument("--bp_free_exit", type=str, default="none", + choices=["none", "dense", "hybrid", "parallel_only", "parallel_gold", "parallel_topmass"], + help="BP-free exit: none=W_U^T, dense/hybrid=compressor, parallel_*=exact parallel term") + p.add_argument("--exit_rank", type=int, default=128, help="rank for BP-free exit compressor") + p.add_argument("--exit_rank_exact", type=int, default=32, help="exact rank for hybrid compressor") + p.add_argument("--exit_topk", type=int, default=8, help="top-k for hybrid compressor") + p.add_argument("--exit_residual_rank", type=int, default=32, + help="residual_rank for ExactParallelExitHead (parallel_gold/topmass): code dim for h-perp residual") + p.add_argument("--intra_block_method", type=str, 
default="bp", choices=["bp", "fa", "sign_sym", "dfa_block"], + help="intra-block: bp=W^T, fa=seq random B, sign_sym=sign(W)·rescale, dfa_block=direct from block-output-error") + p.add_argument("--mlp_topk", type=int, default=0, + help="if >0, apply hard k-WTA to MLP hidden activation (4*n_embd dim)") + p.add_argument("--resid_topk", type=int, default=0, + help="if >0, apply hard k-WTA to residual stream output of each block (n_embd dim)") + p.add_argument("--vq_codes", type=int, default=0, + help="if >0, apply directional VQ to residual stream at each block (K codebook entries, frozen)") + p.add_argument("--subspace_rank", type=int, default=0, + help="if >0, project residual stream to fixed r-dim orthonormal subspace at each block") + p.add_argument("--subspace_per_layer", action="store_true", + help="use DIFFERENT random Q per layer (ablation: tests if shared Q is necessary)") + p.add_argument("--fa_init_sign", action="store_true", + help="init FA's fixed B as sign(W_init)*rescale instead of random (frozen sign_sym)") + p.add_argument("--shared_blocks", action="store_true", + help="Universal Transformer: all blocks share the same parameters (single block applied n_layer times)") + p.add_argument("--fa_init", type=str, default="gaussian", + choices=["gaussian", "orthogonal", "ortho_he", "sparse"], + help="FA's fixed B init mode (gaussian=Lillicrap, orthogonal=JL-isometric, ortho_he=He-init backward, sparse=structured)") + p.add_argument("--fa_sparse_k", type=int, default=0, + help="for fa_init=sparse: non-zero entries per row (0 = auto = in_features/16)") + p.add_argument("--gated_blocks", action="store_true", + help="Path IV: learned per-block residual gates (α_attn, α_mlp). 
Lets useless layers self-deactivate.") + p.add_argument("--progression_targets", action="store_true", + help="Path I: each block l predicts next-(l+1) token (progressive prediction horizons per layer)") + p.add_argument("--weight_normalize", action="store_true", + help="Meta-PCN style WN: after each optimizer step, normalize LocalLinear's W by (sqrt(m)+sqrt(n))*std(W) to keep ||W||_2 ~= 1") + p.add_argument("--pc_inference", type=int, default=0, + help="Predictive coding inference steps T (T=0 disables PC mode, uses standard local CE)") + p.add_argument("--pc_inference_lr", type=float, default=0.1, + help="Inference step size η for PC z updates") + p.add_argument("--pc_top_weight", type=float, default=1.0, + help="Weight of top-down CE term in PC energy F") + p.add_argument("--fa_grape", action="store_true", + help="GrAPE: per-step JVP-based cosine alignment of FA's B toward true Jacobian (Caillon et al. 2026)") + p.add_argument("--fa_grape_lr", type=float, default=0.01, + help="Learning rate for GrAPE B alignment update") + p.add_argument("--fa_grape_n_probe", type=int, default=32, + help="Number of probe samples for JVP rank-1 Jacobian estimate") + p.add_argument("--save_ckpt", action="store_true", + help="save final model state to run_dir/ckpt.pt for downstream probing") + p.add_argument("--n_pred_tokens", type=int, default=1, + help="multi-token prediction: predict next-1..next-N (N=1 disables, default)") + p.add_argument("--aux_weight", type=float, default=0.3, + help="weight for aux next-k losses (k=2..N). 
Primary next-1 always weight 1.0.") + args = p.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + data_dir = Path(args.data_dir) + with open(data_dir / "meta.pkl", "rb") as f: + meta = pickle.load(f) + vocab_size = meta["vocab_size"] + + run_dir = Path(args.out_dir) / args.run_name + run_dir.mkdir(parents=True, exist_ok=True) + log_path = run_dir / "log.jsonl" + log_path.write_text("") + + cfg = LocalGPTConfig( + block_size=args.block_size, vocab_size=vocab_size, + n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, + dropout=args.dropout, attn_mode=args.attn_mode, + method=args.intra_block_method, + fuse_attn_local=True, ste_gelu=True, ln_mode="center_scale", + mlp_topk=args.mlp_topk, resid_topk=args.resid_topk, + vq_codes=args.vq_codes, subspace_rank=args.subspace_rank, + fa_init_mode=args.fa_init, fa_sparse_k=args.fa_sparse_k, + gated_blocks=args.gated_blocks, + fa_grape=args.fa_grape, fa_grape_n_probe=args.fa_grape_n_probe, + ) + model = LocalCETransformer(cfg, translator_rank=args.translator_rank, + n_pred_tokens=args.n_pred_tokens, + shared_blocks=args.shared_blocks).to(device) + + # Frozen sign_sym: replace FA's random B with sign(W_init)*rescale, then freeze. + # B is still a fixed buffer (BP-free by definition B), just structured init. 
+ if args.fa_init_sign and args.intra_block_method == "fa": + with torch.no_grad(): + for module in model.modules(): + if isinstance(module, LocalLinear) and module.method == "fa": + scale = module.weight.norm() / (module.weight.numel() ** 0.5 + 1e-8) + module.B.copy_(torch.sign(module.weight) * scale) + + # Per-layer different Q ablation: replace each block's shared-seed subspace + # with independently-seeded subspace (tests if shared Q is the mechanism) + if args.subspace_per_layer and args.subspace_rank > 0: + from model_local import FrozenSubspace + for i, block in enumerate(model.blocks): + block.subspace = FrozenSubspace(args.n_embd, args.subspace_rank, seed=1000 + i).to(device) + + # Initialize DFA-block targets if needed + if args.intra_block_method == "dfa_block": + initialize_dfa_block_targets(model, args.n_embd) + + # BP-free exit heads (one per block) + exit_heads = None + if args.bp_free_exit in ("dense", "hybrid"): + exit_heads = nn.ModuleList([ + FactorizedExitHead( + args.n_embd, vocab_size, mode=args.bp_free_exit, + rank=args.exit_rank, rank_exact=args.exit_rank_exact, topk=args.exit_topk, + ) for _ in range(cfg.n_layer) + ]).to(device) + elif args.bp_free_exit.startswith("parallel"): + exit_heads = nn.ModuleList([ + ExactParallelExitHead( + args.n_embd, vocab_size, mode=args.bp_free_exit, + residual_rank=args.exit_residual_rank, + ) for _ in range(cfg.n_layer) + ]).to(device) + + n_params = sum(p.numel() for p in model.parameters()) + optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr, weight_decay=0.1) + + t0 = time.time() + + def log(rec): + rec["t"] = time.time() - t0 + with open(log_path, "a") as f: + f.write(json.dumps(rec) + "\n") + + log({"event": "start", "method": "local_ce", "params": n_params, + "translator_rank": args.translator_rank, "config": vars(args)}) + print(f"[{args.run_name}] local_ce, params={n_params/1e6:.2f}M, translator_rank={args.translator_rank}") + + def lr_schedule(it): + if it < args.warmup_iters: + 
return args.max_lr * (it + 1) / (args.warmup_iters + 1) + decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) / + max(1, args.max_iters - args.warmup_iters))) + return args.min_lr + decay * (args.max_lr - args.min_lr) + + @torch.no_grad() + def eval_loss(): + model.eval() + losses = torch.zeros(args.eval_iters) + for k in range(args.eval_iters): + X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device) + acts = model.forward_activations(X) + logits = model.final_logits(acts[-1]) + loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1)) + losses[k] = loss.item() + model.train() + return losses.mean().item() + + model.train() + for it in range(args.max_iters + 1): + lr = lr_schedule(it) + for g in optimizer.param_groups: + g["lr"] = lr + + if it % args.eval_interval == 0 or it == args.max_iters: + val = eval_loss() + log({"event": "eval", "iter": it, "val_loss": val, "lr": lr}) + print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}") + + if it == args.max_iters: + break + + # Determine n_pred for batch fetch + # - n_pred_tokens > 1: multi-token MTP aux losses (each block predicts N targets via N heads) + # - progression_targets: each block l predicts next-(l+1) (so need n_pred = n_layer) + if args.n_pred_tokens > 1: + n_pred = args.n_pred_tokens + elif args.progression_targets: + n_pred = cfg.n_layer + else: + n_pred = 1 + + if n_pred > 1: + X, Y_multi = get_batch("train", data_dir, args.block_size, args.batch_size, device, + n_pred=n_pred) + Y = Y_multi[..., 0] # (B, T) — next-1 target for default + else: + X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device) + Y_multi = None + + # ============================================================ + # PC mode: predictive coding inference + Hebbian-style updates + # (when --pc_inference T > 0, replaces standard local CE) + # ============================================================ + if args.pc_inference > 0: + optimizer.zero_grad() + Y_flat = 
Y.view(-1) + + # 1. Forward init (no autograd graph during init) + with torch.no_grad(): + init_acts = model.forward_activations(X) + # z[0] = embedding, clamped (no grad). z[1..L] = block outputs, evolve. + z = [init_acts[0].detach()] + for l in range(1, len(init_acts)): + z.append(init_acts[l].detach().clone().requires_grad_(True)) + + # 2. PC inference: T iterations of z updates via ∂F/∂z + # F = Σ_{l<L} (1/2) ||z_l - block_{l-1}(z_{l-1})||² / d + λ·CE(W_U @ z_L, y) + # Skip PE_L (Meta-PCN trick: CE replaces last-layer squared error) + # Use mean over hidden dim (per-token PE²) for scale-invariance. + for t in range(args.pc_inference): + F_energy = 0.0 + for l in range(1, len(z) - 1): # l = 1..L-1, skip PE_L + z_hat = model.blocks[l - 1](z[l - 1]) + pe = z[l] - z_hat + F_energy = F_energy + 0.5 * (pe ** 2).mean() # mean over (B,T,d) → scale-invariant + # Top-down: CE at z_L (replaces PE_L per Meta-PCN convention) + logits_top = F.linear(z[-1], model.head.weight) + CE_top = F.cross_entropy(logits_top.view(-1, vocab_size), Y_flat) + F_total = F_energy + args.pc_top_weight * CE_top + # Compute ∂F/∂z[1..L] (FA-flavored due to LocalLinear FA backward inside blocks) + grads = torch.autograd.grad(F_total, z[1:], create_graph=False, retain_graph=False) + # SGD update on z's + with torch.no_grad(): + new_z = [z[0]] + for i, g in enumerate(grads): + new_z.append((z[i + 1] - args.pc_inference_lr * g).detach().requires_grad_(True)) + z = new_z + + # 3. 
Weight update via per-block PE loss using converged z's + # For block l-1: minimize ||sg(z_l) - block_{l-1}(sg(z_{l-1}))||² + # backward gives FA-flavored W gradients (Hebbian-equivalent at equilibrium) + total_loss = 0.0 + for l in range(1, len(z)): + z_hat = model.blocks[l - 1](z[l - 1].detach()) + target = z[l].detach() + pe_loss = 0.5 * ((target - z_hat) ** 2).mean() + pe_loss.backward() + total_loss += pe_loss.item() + + # Final head + ln_f via CE on converged z[-1] + final_z = model.final_logits(z[-1].detach()) + head_loss = F.cross_entropy(final_z.view(-1, vocab_size), Y_flat) + head_loss.backward() + total_loss += head_loss.item() + + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + + # Optional WN + if args.weight_normalize: + with torch.no_grad(): + for module in model.modules(): + if isinstance(module, LocalLinear): + m, n = module.weight.shape + sigma_w = module.weight.std() + scale = (m ** 0.5 + n ** 0.5) * sigma_w + if scale > 1e-8: + module.weight.div_(scale) + + if it % args.log_interval == 0: + log({"event": "step", "iter": it, + "total_loss": total_loss / (cfg.n_layer + 1), + "head_loss": head_loss.item(), "lr": lr}) + continue + # ============================================================ + # End PC mode; standard local CE follows + # ============================================================ + + # Forward: compute all activations + activations = model.forward_activations(X) + + # Final logits (for KD teacher + eval) + with torch.no_grad(): + final_logits = model.final_logits(activations[-1].detach()) + teacher_probs = F.softmax(final_logits / args.kd_temp, dim=-1) + + # Per-block local CE losses (no inter-block gradient) + optimizer.zero_grad() + total_loss = 0.0 + Y_flat = Y.view(-1) + + # Neighbor KL teachers are computed on-the-fly inside the per-block loop + # (avoid pre-computing all 6 × (B,T,V) tensors which OOM at V=50k) + + for l in range(cfg.n_layer): + # Block l: h_l → block → h_{l+1} + h_l = 
activations[l] if l == 0 else activations[l].detach() + # For dfa_block mode, need h_lp1 with retain_grad to capture block-output-error + if args.intra_block_method == "dfa_block": + h_lp1 = model.blocks[l](h_l) + h_lp1.retain_grad() + else: + h_lp1 = model.blocks[l](h_l) + + # Path I: progression targets — block l predicts next-(l+1) instead of next-1 + if args.progression_targets and Y_multi is not None: + Y_block = Y_multi[..., l] # (B, T) — block l's specific target + else: + Y_block = Y # default: all blocks predict next-1 + Y_block_flat = Y_block.reshape(-1) + + # Local logits via shared unembedding (exact or BP-free) + if exit_heads is not None: + local_z = exit_heads[l](h_lp1, model.head.weight, Y_block) + else: + local_z = model.local_logits(h_lp1, l) + local_z_flat = local_z.view(-1, vocab_size) + + # Per-layer weight + if args.layer_weighting == "linear": + layer_w = (l + 1) / cfg.n_layer + else: + layer_w = 1.0 + + # Ground-truth CE (uses Y_block_flat which respects progression mode) + loss_gt = F.cross_entropy(local_z_flat, Y_block_flat) + + # KD from final layer (skip when both kd_weight and nbr_weight are 0 to save 3.3GB/block) + loss_kd = 0.0 + loss_nbr = 0.0 + if args.kd_weight > 0 or args.nbr_weight > 0: + local_log_probs = F.log_softmax(local_z / args.kd_temp, dim=-1) + if args.kd_weight > 0: + loss_kd = F.kl_div( + local_log_probs.view(-1, vocab_size), + teacher_probs.view(-1, vocab_size), + reduction="batchmean", + ) * (args.kd_temp ** 2) + # Neighbor KL: match next block's prediction (stop-grad), computed on-the-fly + if args.nbr_weight > 0 and l < cfg.n_layer - 1: + with torch.no_grad(): + nbr_z = model.local_logits(activations[l + 2].detach(), l + 1) + nbr_probs = F.softmax(nbr_z / args.kd_temp, dim=-1) + del nbr_z + loss_nbr = F.kl_div( + local_log_probs.view(-1, vocab_size), + nbr_probs.view(-1, vocab_size), + reduction="batchmean", + ) * (args.kd_temp ** 2) + del nbr_probs + del local_log_probs + + # Multi-token aux losses: predict 
next-2..next-N via aux_heads + # Each aux head provides an independent gradient direction (different W_k column space). + # Reuses the same exit_heads[l] (shared codebook) but with different shared_weight + targets. + loss_aux = 0.0 + if args.n_pred_tokens > 1 and args.aux_weight > 0 and model.aux_heads is not None: + for k_idx, aux_head in enumerate(model.aux_heads): + Y_k = Y_multi[..., k_idx + 1] # next-(k_idx+2) target + if exit_heads is not None: + z_k = exit_heads[l](h_lp1, aux_head.weight, Y_k) + else: + z_k = F.linear(h_lp1, aux_head.weight) + loss_k = F.cross_entropy(z_k.view(-1, vocab_size), Y_k.reshape(-1)) + loss_aux = loss_aux + loss_k + loss_aux = loss_aux * args.aux_weight / (args.n_pred_tokens - 1) + + block_loss = layer_w * ( + args.gt_weight * loss_gt + + args.kd_weight * loss_kd + + args.nbr_weight * loss_nbr + + loss_aux + ) + block_loss.backward() + + # For dfa_block: overwrite intra-block linears' .grad using block-output-error + if args.intra_block_method == "dfa_block" and h_lp1.grad is not None: + with torch.no_grad(): + apply_dfa_block_update(model.blocks[l], h_lp1.grad) + + total_loss += block_loss.item() + + # Also train head + ln_f via final CE + h_L_det = activations[-1].detach() + final_z = model.final_logits(h_L_det) + head_loss = F.cross_entropy(final_z.view(-1, vocab_size), Y_flat) + head_loss.backward() + total_loss += head_loss.item() + + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + + # GrAPE: per-step alignment of FA's B toward Jacobian via JVP probes (forward-only) + if args.fa_grape: + for module in model.modules(): + if isinstance(module, LocalLinear) and getattr(module, "_fa_grape", False): + module.grape_align_step(lr_b=args.fa_grape_lr) + + # Meta-PCN style weight normalization: rescale each LocalLinear's W to have ||W||_2 ~= 1 + # via random matrix theory bound ||W||_2 ~= (sqrt(m) + sqrt(n)) * std(W). + # Only normalizes LocalLinear W (the trained weight); leaves B (fixed buffer) untouched. 
+ if args.weight_normalize: + with torch.no_grad(): + for module in model.modules(): + if isinstance(module, LocalLinear): + m, n = module.weight.shape + sigma_w = module.weight.std() + scale = (m ** 0.5 + n ** 0.5) * sigma_w + if scale > 1e-8: + module.weight.div_(scale) + + if it % args.log_interval == 0: + log({"event": "step", "iter": it, "total_loss": total_loss / (cfg.n_layer + 1), + "head_loss": head_loss.item(), "lr": lr}) + + if args.save_ckpt: + ckpt_path = run_dir / "ckpt.pt" + torch.save({ + "model_state": model.state_dict(), + "config": vars(cfg), + "args": vars(args), + "vocab_size": vocab_size, + }, ckpt_path) + log({"event": "save_ckpt", "path": str(ckpt_path)}) + print(f"[{args.run_name}] saved ckpt to {ckpt_path}") + + +if __name__ == "__main__": + main() diff --git a/ep_run/train_recon.py b/ep_run/train_recon.py new file mode 100644 index 0000000..d180cb2 --- /dev/null +++ b/ep_run/train_recon.py @@ -0,0 +1,322 @@ +"""Reconstruction-based (DTP-style) training for local transformer. + +Each transformer block l has: + - Forward function f_l: h_l → h_{l+1} (standard transformer block) + - Feedback module g_l: h_{l+1} → ĥ_l (learned reconstruction, linear) + +Training loop per step: + 1. Forward pass: compute h_0, h_1, ..., h_L + 2. Top target: target_L = h_L - η_target * ∂L/∂h_L + 3. Propagate targets backward via g_l: + target_l = h_l + g_l(target_{l+1}) - g_l(h_{l+1}) (difference target prop) + 4. Train feedback g_l: minimize reconstruction loss (DRL-style with noise) + 5. Train forward f_l: minimize ||f_l(h_l) - target_{l+1}||² (local loss) + Within each block, attention uses fused backward, LN uses center_scale, GELU uses STE. + +No random matrices. No weight transport. No inter-block chain rule. 
+""" +import argparse +import json +import math +import pickle + +import time +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model_local import LocalGPT, LocalGPTConfig, SoftmaxValueMixLocalFn + + +def get_batch(split, data_dir, block_size, batch_size, device): + fn = "train.bin" if split == "train" else "val.bin" + data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r") + ix = torch.randint(len(data) - block_size - 1, (batch_size,)) + x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix]) + return x.to(device, non_blocking=True), y.to(device, non_blocking=True) + + +class FeedbackModule(nn.Module): + """g_l: h_{l+1} → ĥ_l. Linear reconstruction module.""" + def __init__(self, d_model): + super().__init__() + self.linear = nn.Linear(d_model, d_model, bias=False) + nn.init.eye_(self.linear.weight) # init as identity (good starting point) + + def forward(self, h): + return self.linear(h) + + +class ReconTransformer(nn.Module): + """Transformer with per-block feedback modules for reconstruction-based training.""" + + def __init__(self, config: LocalGPTConfig): + super().__init__() + self.config = config + # Forward model (standard transformer) + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + self.pos_emb = nn.Embedding(config.block_size, config.n_embd) + self.drop = nn.Dropout(config.dropout) + + # Import block class from model_local + from model_local import LocalBlock + self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Feedback modules: one per block + self.feedbacks = nn.ModuleList([ + FeedbackModule(config.n_embd) for _ in range(config.n_layer) + ]) + + self.apply(self._init_weights) + 
+ # Match LocalGPT: scale down o_proj and mlp.proj for residual stream stability + for pn, p in self.named_parameters(): + if pn.endswith("o_proj.weight") or pn.endswith("mlp.proj.weight"): + nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) + # self.apply() above re-initialized every nn.Linear, including the feedback modules; + # restore the identity init that FeedbackModule intends + for fb in self.feedbacks: + nn.init.eye_(fb.linear.weight) + + def _init_weights(self, m): + from model_local import LocalLinear # not imported at module level; assumed to live in model_local like LocalBlock + if isinstance(m, (nn.Linear, LocalLinear)): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + if getattr(m, "bias", None) is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + + def forward_activations(self, idx): + """Forward pass, returning per-block activations h_0 ... h_L.""" + B, T = idx.shape + pos = torch.arange(T, device=idx.device) + h = self.drop(self.tok_emb(idx) + self.pos_emb(pos)) + activations = [h] + for block in self.blocks: + h = block(h) + activations.append(h) + return activations # len = n_layer + 1 + + def logits_from_h(self, h_final): + """h_L → logits.""" + return self.head(self.ln_f(h_final)) + + def compute_targets(self, activations, logits, targets_y, eta_target=0.1): + """Compute per-block targets via difference target propagation. 
+ + target_L = h_L - η * ∂L/∂h_L + target_l = h_l + g_l(target_{l+1}) - g_l(h_{l+1}) + """ + h_L = activations[-1] + # Compute ∂L/∂h_L (only need grad at the top, not full BP) + h_L_for_grad = h_L.detach().requires_grad_(True) + logits_local = self.head(self.ln_f(h_L_for_grad)) + loss = F.cross_entropy(logits_local.view(-1, logits_local.size(-1)), targets_y.view(-1)) + loss.backward() + grad_h_L = h_L_for_grad.grad.detach() + + # Top target + target = h_L.detach() - eta_target * grad_h_L + targets_list = [None] * (self.config.n_layer + 1) + targets_list[-1] = target + + # Propagate backward via feedback modules + for l in range(self.config.n_layer - 1, -1, -1): + h_l = activations[l].detach() + h_lp1 = activations[l + 1].detach() + target_lp1 = targets_list[l + 1] + # Difference target propagation + targets_list[l] = h_l + self.feedbacks[l](target_lp1) - self.feedbacks[l](h_lp1) + + return targets_list + + def reconstruction_loss(self, activations, sigma=0.1): + """Train feedback modules via reconstruction loss (DRL-style with noise). + + For each block l: corrupt h_l, forward through block, reconstruct via g_l. + """ + total_loss = 0.0 + for l in range(self.config.n_layer): + h_l = activations[l].detach() + h_lp1 = activations[l + 1].detach() + # Add noise to h_l + noise = torch.randn_like(h_l) * sigma + h_l_noisy = h_l + noise + # Forward through block (detached, just computing) + with torch.no_grad(): + h_lp1_noisy = self.blocks[l](h_l_noisy) + # Reconstruct via feedback + h_l_recon = self.feedbacks[l](h_lp1_noisy) + # Reconstruction target: the noisy input h_l_noisy itself (absolute position), not the noise alone + recon_target = h_l_noisy + total_loss = total_loss + F.mse_loss(h_l_recon, recon_target) + return total_loss / self.config.n_layer + + def local_forward_loss(self, activations, targets_list): + """Per-block local loss: ||f_l(h_l) - target_{l+1}||². + + Gradients flow within each block (using fused attention backward etc.) + but NOT across blocks (targets are detached). 
+ """ + total_loss = 0.0 + for l in range(self.config.n_layer): + h_l = activations[l].detach() # detach: no inter-block gradient + target_lp1 = targets_list[l + 1].detach() + # Forward through block (WITH gradient for intra-block params) + h_lp1_pred = self.blocks[l](h_l) + # Local loss + total_loss = total_loss + F.mse_loss(h_lp1_pred, target_lp1) + return total_loss / self.config.n_layer + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--run_name", type=str, required=True) + p.add_argument("--seed", type=int, default=1337) + p.add_argument("--data_dir", type=str, default="data/shakespeare_char") + p.add_argument("--out_dir", type=str, default="runs_local") + p.add_argument("--block_size", type=int, default=256) + p.add_argument("--batch_size", type=int, default=64) + p.add_argument("--n_layer", type=int, default=6) + p.add_argument("--n_head", type=int, default=6) + p.add_argument("--n_embd", type=int, default=384) + p.add_argument("--dropout", type=float, default=0.2) + p.add_argument("--max_iters", type=int, default=5000) + p.add_argument("--warmup_iters", type=int, default=100) + p.add_argument("--max_lr", type=float, default=1e-3) + p.add_argument("--min_lr", type=float, default=1e-4) + p.add_argument("--eta_target", type=float, default=0.1, help="target stepsize for top-layer target") + p.add_argument("--sigma_recon", type=float, default=0.1, help="noise std for reconstruction loss") + p.add_argument("--lr_feedback", type=float, default=1e-3, help="LR for feedback modules") + p.add_argument("--eval_interval", type=int, default=250) + p.add_argument("--eval_iters", type=int, default=100) + p.add_argument("--log_interval", type=int, default=50) + p.add_argument("--attn_mode", type=str, default="softmax") + args = p.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + data_dir = Path(args.data_dir) + with open(data_dir / "meta.pkl", "rb") as f: + 
meta = pickle.load(f) + vocab_size = meta["vocab_size"] + + run_dir = Path(args.out_dir) / args.run_name + run_dir.mkdir(parents=True, exist_ok=True) + log_path = run_dir / "log.jsonl" + log_path.write_text("") + + cfg = LocalGPTConfig( + block_size=args.block_size, vocab_size=vocab_size, + n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, + dropout=args.dropout, attn_mode=args.attn_mode, + method="bp", # intra-block uses standard autograd (with fused attention) + fuse_attn_local=True, + ste_gelu=True, + ln_mode="center_scale", + ) + model = ReconTransformer(cfg).to(device) + n_params = sum(p.numel() for p in model.parameters()) + + # Separate optimizers for forward and feedback + forward_params = list(model.tok_emb.parameters()) + list(model.pos_emb.parameters()) + \ + list(model.head.parameters()) + list(model.ln_f.parameters()) + for block in model.blocks: + forward_params.extend(block.parameters()) + + feedback_params = list(model.feedbacks.parameters()) + + opt_fwd = torch.optim.AdamW(forward_params, lr=args.max_lr, weight_decay=0.1) + opt_fb = torch.optim.AdamW(feedback_params, lr=args.lr_feedback, weight_decay=0.01) + + t0 = time.time() + + def log(rec): + rec["t"] = time.time() - t0 + with open(log_path, "a") as f: + f.write(json.dumps(rec) + "\n") + + log({"event": "start", "method": "reconstruction", "params": n_params, "config": vars(args)}) + print(f"[{args.run_name}] recon transformer, params={n_params/1e6:.2f}M") + + def lr_schedule(it): + if it < args.warmup_iters: + return args.max_lr * (it + 1) / (args.warmup_iters + 1) + decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) / + max(1, args.max_iters - args.warmup_iters))) + return args.min_lr + decay * (args.max_lr - args.min_lr) + + @torch.no_grad() + def eval_loss(): + model.eval() + losses = torch.zeros(args.eval_iters) + for k in range(args.eval_iters): + X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device) + acts = model.forward_activations(X) + 
logits = model.logits_from_h(acts[-1]) + loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1)) + losses[k] = loss.item() + model.train() + return losses.mean().item() + + model.train() + for it in range(args.max_iters + 1): + lr = lr_schedule(it) + for g in opt_fwd.param_groups: + g["lr"] = lr + + if it % args.eval_interval == 0 or it == args.max_iters: + val = eval_loss() + log({"event": "eval", "iter": it, "val_loss": val, "lr": lr}) + print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}") + + if it == args.max_iters: + break + + X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device) + + # Step 1: Forward pass (compute activations) + activations = model.forward_activations(X) + logits = model.logits_from_h(activations[-1]) + ce_loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1)) + + # Step 2-3: Compute targets via DTP + targets = model.compute_targets(activations, logits, Y, eta_target=args.eta_target) + + # Step 4: Train feedback modules (reconstruction loss) + opt_fb.zero_grad() + recon_loss = model.reconstruction_loss(activations, sigma=args.sigma_recon) + recon_loss.backward() + opt_fb.step() + + # Step 5: Train forward weights (no inter-block BP) + opt_fwd.zero_grad() + + # 5a: Head + ln_f via CE loss on DETACHED h_L (gradient stays at top, no BP into blocks) + h_L_det = activations[-1].detach() + logits_head = model.logits_from_h(h_L_det) + head_loss = F.cross_entropy(logits_head.view(-1, vocab_size), Y.view(-1)) + head_loss.backward() + + # 5b: Block-local target-matching losses + # Block 0: DON'T detach h_0 so embedding gets gradient from block 0's local loss + for l in range(cfg.n_layer): + h_l = activations[l] if l == 0 else activations[l].detach() + target_lp1 = targets[l + 1].detach() + h_lp1_pred = model.blocks[l](h_l) + block_loss = F.mse_loss(h_lp1_pred, target_lp1) + block_loss.backward() + + torch.nn.utils.clip_grad_norm_(forward_params, 1.0) + opt_fwd.step() + + if it % args.log_interval 
== 0: + log({"event": "step", "iter": it, "ce_loss": ce_loss.item(), + "recon_loss": recon_loss.item(), "head_loss": head_loss.item(), "lr": lr}) + + +if __name__ == "__main__": + main() diff --git a/ep_run/train_stiefel.py b/ep_run/train_stiefel.py new file mode 100644 index 0000000..0b218ff --- /dev/null +++ b/ep_run/train_stiefel.py @@ -0,0 +1,211 @@ +"""Stiefel factored feedback training for local transformer. + +Replaces FA's random B with: δ_l = α_l · (e_L @ C^T) @ U_l^T +where C is fixed row-orthonormal, U_l is per-layer learnable on Stiefel. + +Each block uses fused attention backward + GELU STE + center_scale LN. +Head trained via detached CE loss. Embedding frozen. +g_l reconstruction modules provide local proxy signal for U_l updates. +""" +import argparse +import json +import math +import pickle +import time +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model_local import LocalGPTConfig +from train_recon import ReconTransformer, get_batch, FeedbackModule +from stiefel_feedback import StiefelFeedbackSystem + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--run_name", type=str, required=True) + p.add_argument("--seed", type=int, default=1337) + p.add_argument("--data_dir", type=str, default="data/shakespeare_char") + p.add_argument("--out_dir", type=str, default="runs_local") + p.add_argument("--block_size", type=int, default=256) + p.add_argument("--batch_size", type=int, default=64) + p.add_argument("--n_layer", type=int, default=6) + p.add_argument("--n_head", type=int, default=6) + p.add_argument("--n_embd", type=int, default=384) + p.add_argument("--dropout", type=float, default=0.2) + p.add_argument("--max_iters", type=int, default=5000) + p.add_argument("--warmup_iters", type=int, default=100) + p.add_argument("--max_lr", type=float, default=1e-3) + p.add_argument("--min_lr", type=float, default=1e-4) + p.add_argument("--rank", type=int, default=128) + 
p.add_argument("--eta_B", type=float, default=3e-5) + p.add_argument("--freeze_fb_steps", type=int, default=200) + p.add_argument("--sigma_recon", type=float, default=0.1) + p.add_argument("--eta_target", type=float, default=0.1) + p.add_argument("--eval_interval", type=int, default=250) + p.add_argument("--eval_iters", type=int, default=100) + p.add_argument("--log_interval", type=int, default=50) + p.add_argument("--attn_mode", type=str, default="softmax") + args = p.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + data_dir = Path(args.data_dir) + with open(data_dir / "meta.pkl", "rb") as f: + meta = pickle.load(f) + vocab_size = meta["vocab_size"] + + run_dir = Path(args.out_dir) / args.run_name + run_dir.mkdir(parents=True, exist_ok=True) + log_path = run_dir / "log.jsonl" + log_path.write_text("") + + cfg = LocalGPTConfig( + block_size=args.block_size, vocab_size=vocab_size, + n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, + dropout=args.dropout, attn_mode=args.attn_mode, + method="bp", # intra-block standard autograd with fused attention + fuse_attn_local=True, ste_gelu=True, ln_mode="center_scale", + ) + model = ReconTransformer(cfg).to(device) + + # Stiefel feedback system + layer_dims = [args.n_embd] * args.n_layer # each block output is d_model + fb_system = StiefelFeedbackSystem(vocab_size, layer_dims, rank=args.rank).to(device) + + n_params = sum(p.numel() for p in model.parameters()) + n_fb = sum(p.numel() for p in fb_system.parameters()) + + # Optimizers: forward (blocks + head), feedback g_l (reconstruction) + forward_params = list(model.head.parameters()) + list(model.ln_f.parameters()) + for block in model.blocks: + forward_params.extend(block.parameters()) + feedback_g_params = list(model.feedbacks.parameters()) + + opt_fwd = torch.optim.AdamW(forward_params, lr=args.max_lr, weight_decay=0.1) + opt_fb_g = 
torch.optim.AdamW(feedback_g_params, lr=args.max_lr, weight_decay=0.01)
+    # U_l and α_l are updated manually via Stiefel retraction, not via optimizer
+
+    t0 = time.time()
+
+    def log(rec):
+        rec["t"] = time.time() - t0
+        with open(log_path, "a") as f:
+            f.write(json.dumps(rec) + "\n")
+
+    log({"event": "start", "method": "stiefel_factored", "params": n_params,
+         "fb_params": n_fb, "rank": args.rank, "config": vars(args)})
+    print(f"[{args.run_name}] stiefel factored, params={n_params/1e6:.2f}M, fb={n_fb/1e3:.1f}K, rank={args.rank}")
+
+    def lr_schedule(it):
+        if it < args.warmup_iters:
+            return args.max_lr * (it + 1) / (args.warmup_iters + 1)
+        decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) /
+                                    max(1, args.max_iters - args.warmup_iters)))
+        return args.min_lr + decay * (args.max_lr - args.min_lr)
+
+    @torch.no_grad()
+    def eval_loss():
+        model.eval()
+        losses = torch.zeros(args.eval_iters)
+        for k in range(args.eval_iters):
+            X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device)
+            acts = model.forward_activations(X)
+            logits = model.logits_from_h(acts[-1])
+            loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
+            losses[k] = loss.item()
+        model.train()
+        return losses.mean().item()
+
+    model.train()
+    for it in range(args.max_iters + 1):
+        lr = lr_schedule(it)
+        for g in opt_fwd.param_groups:
+            g["lr"] = lr
+
+        if it % args.eval_interval == 0 or it == args.max_iters:
+            val = eval_loss()
+            log({"event": "eval", "iter": it, "val_loss": val, "lr": lr})
+            print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}")
+
+        if it == args.max_iters:
+            break
+
+        X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device)
+
+        # 1. Forward pass
+        activations = model.forward_activations(X)
+        logits = model.logits_from_h(activations[-1])
+        ce_loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
+
+        # 2. 
Compute e_L and compress
+        with torch.no_grad():
+            probs = F.softmax(logits.detach(), dim=-1)
+            onehot = F.one_hot(Y, num_classes=vocab_size).float()
+            e_L = (probs - onehot) / Y.numel()
+            c = fb_system.compress_error(e_L)
+
+        # 3. Compute per-layer δ via Stiefel feedback
+        deltas = fb_system.compute_deltas(c)
+
+        # 4. Train g_l (reconstruction feedback modules)
+        opt_fb_g.zero_grad()
+        recon_loss = model.reconstruction_loss(activations, sigma=args.sigma_recon)
+        recon_loss.backward()
+        opt_fb_g.step()
+
+        # 5. Get local proxy signals g_hat_l from reconstruction modules
+        g_hats = []
+        for l in range(cfg.n_layer):
+            with torch.no_grad():
+                h_l = activations[l].detach()
+                h_lp1 = activations[l + 1].detach()
+                g_hat_l = model.feedbacks[l](h_lp1) - h_l  # reconstruction error
+            g_hats.append(g_hat_l)
+
+        # 6. Update Stiefel feedback (U_l, α_l)
+        frozen = (it < args.freeze_fb_steps)
+        fb_diags = fb_system.update_all(g_hats, c, frozen=frozen, eta_B=args.eta_B)
+
+        # 7. Train forward weights via block-local loss using Stiefel δ as targets
+        opt_fwd.zero_grad()
+
+        # 7a. Head via detached CE
+        h_L_det = activations[-1].detach()
+        logits_head = model.logits_from_h(h_L_det)
+        head_loss = F.cross_entropy(logits_head.view(-1, vocab_size), Y.view(-1))
+        head_loss.backward()
+
+        # 7b. 
Each block: local target = h_{l+1} - δ_l (feedback signal as target displacement)
+        for l in range(cfg.n_layer):
+            h_l = activations[l] if l == 0 else activations[l].detach()
+            h_lp1 = activations[l + 1].detach()
+            # Target for block l's output: current output - δ_l displacement
+            target_lp1 = h_lp1 - deltas[l].detach()  # push toward lower loss
+            h_lp1_pred = model.blocks[l](h_l)
+            block_loss = F.mse_loss(h_lp1_pred, target_lp1)
+            block_loss.backward()
+
+        torch.nn.utils.clip_grad_norm_(forward_params, 1.0)
+        opt_fwd.step()
+
+        if it % args.log_interval == 0:
+            fb_info = {}
+            if not frozen and fb_diags:
+                fb_info = {
+                    "alpha_mean": sum(d.get("alpha", 0) for d in fb_diags) / len(fb_diags),
+                    "rho_mean": sum(d.get("rho", 0) for d in fb_diags) / len(fb_diags),
+                    "Delta_frob_mean": sum(d.get("Delta_frob", 0) for d in fb_diags) / len(fb_diags),
+                }
+            log({"event": "step", "iter": it, "ce_loss": ce_loss.item(),
+                 "recon_loss": recon_loss.item(), "head_loss": head_loss.item(),
+                 "frozen": frozen, **fb_info, "lr": lr})
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ep_run/verify_aep_manual.py b/ep_run/verify_aep_manual.py
new file mode 100644
index 0000000..6c4b403
--- /dev/null
+++ b/ep_run/verify_aep_manual.py
@@ -0,0 +1,62 @@
+import math, time, torch
+import lt_ep_train as LT, holo_ep as H
+from test_aselect_deepdive import (manual_nc_jvp_vjp_thick, make_manual_step,
+                                   make_tf_step, run_loop_from_step, cosine)
+import torch.func as tf
+
+def cosd(ga, gb, ps):
+    num=da=db=0.0
+    for p in ps:
+        a=ga.get(id(p)); b=gb.get(id(p))
+        if a is None or b is None: continue
+        num+=float((a*b).sum()); da+=float((a*a).sum()); db+=float((b*b).sum())
+    return num/(math.sqrt(da*db)+1e-30)
+
+def grad_from_a(blk, zs, idx, a):
+    with torch.enable_grad():
+        xin=blk.embed(idx)
+        f=blk.force(zs.detach(), xin, cg=True)
+        gs=torch.autograd.grad((a.detach()*f).sum(), blk.block, allow_unused=True)
+    return {id(p):g for p,g in zip(blk.block,gs) if g is not None}
+
+dev='cpu'
+blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick')
+blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0
+ck=torch.load('runs/ep_resreg_warm.pt', map_location=dev)
+with torch.no_grad():
+    for p,s in zip(blk.allp, ck['allp']): p.copy_(s.to(dev))
+print(f"loaded step={ck.get('step')} best={ck.get('best')}", flush=True)
+
+# 1) Verify direct compile(step over torch.func) FAILS (independent confirmation)
+torch.manual_seed(0)
+idx,y=LT.get_batch('train',1,256); xin=blk.embed(idx).detach()
+zs=LT.relax(blk,xin.clone(),xin,1,0.1); B=zs.size(0); Z0=torch.cat([zs,zs],0)
+tf_step=make_tf_step(blk,zs,xin,y,0.02,0.1)
+try:
+    c=torch.compile(tf_step, fullgraph=True); c(Z0)
+    print("COMPILE compile∘func: UNEXPECTEDLY OK", flush=True)
+except Exception as e:
+    print("COMPILE compile∘func: FAILS ->", type(e).__name__, str(e)[:90], flush=True)
+
+# 2) one-step manual vs torch.func identity + J accuracy
+v=torch.randn_like(Z0)*1e-3
+zbar=0.5*(Z0[:B]+Z0[B:]); zb2=torch.cat([zbar,zbar],0)
+_,Jv_ref=tf.jvp(lambda zz: blk.nc_force(zz),(zb2,),(v,))
+JTv_ref=tf.vjp(lambda zz: blk.nc_force(zz),zb2)[1](v)[0]
+Jv_m,JTv_m=manual_nc_jvp_vjp_thick(blk,zb2,v)
+print(f"manual Jv cos={cosine(Jv_ref,Jv_m):.8f} JTv cos={cosine(JTv_ref,JTv_m):.8f}", flush=True)
+with torch.no_grad():
+    Zt=tf_step(Z0); Zm=make_manual_step(blk,zs,xin,y,0.02,0.1)(Z0)
+print(f"one-step manual vs tf max_abs={float((Zt-Zm).abs().max()):.2e}", flush=True)
+
+# 3) full a-select + gradient cosine, several configs
+for seed,T2,K in [(1,80,10),(2,80,10),(3,160,10)]:
+    torch.manual_seed(seed)
+    idx,y=LT.get_batch('train',1,256); xin=blk.embed(idx).detach()
+    zs=LT.relax(blk,xin.clone(),xin,1,0.1)
+    t=time.time(); a0,tb=H.holo_a_track(blk,zs,xin,y,0.02,T2,0.1,K=K); sb=time.time()-t
+    t=time.time(); a1,tm=run_loop_from_step(make_manual_step(blk,zs,xin,y,0.02,0.1),zs,0.02,T2,K=K); sm=time.time()-t
+    g0=grad_from_a(blk,zs,idx,a0); g1=grad_from_a(blk,zs,idx,a1)
+    print(f"seed{seed} T2={T2} K={K}: a_cos={cosine(a0,a1):.7f} 
grad_cos={cosd(g0,g1,blk.block):.7f} "
+          f"tbest={tb}/{tm} base={sb:.2f}s man={sm:.2f}s spd={sb/sm:.2f}x", flush=True)
+print("DONE", flush=True)
diff --git a/ep_run/watch_all.sh b/ep_run/watch_all.sh
new file mode 100755
index 0000000..a1baec7
--- /dev/null
+++ b/ep_run/watch_all.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+# Consolidated watcher for the 3 live C512 EP runs. nohup'd so it survives session events.
+# Logs each run's latest val-CE line + alive/dead every 10 min to runs/watch_all.log.
+R=/home/yurenh2/ept/ep_run/runs
+LOG=$R/watch_all.log
+while true; do
+  TS=$(date '+%m-%d %H:%M')
+  for f in ep_resreg_warm ep_jacreg ep_rr_ajr; do
+    if pgrep -f "ckpt runs/$f.pt" >/dev/null; then a=ALIVE; else a=DEAD; fi
+    line=$(grep -iE "val CE" "$R/$f.log" 2>/dev/null | tail -1)
+    echo "$TS [$a] $f | $line" >> "$LOG"
+  done
+  echo "----" >> "$LOG"
+  sleep 600
+done
diff --git a/ep_run/watch_clean.py b/ep_run/watch_clean.py
new file mode 100644
index 0000000..3f9f06c
--- /dev/null
+++ b/ep_run/watch_clean.py
@@ -0,0 +1,36 @@
+"""Watch the clean-code re-run: pure EP (ep_clean) + BPTT control (bptt_clean). 
Fire on a decisive event."""
+import time, os, re
+RUNS = [
+    ("ep", "/home/yurenh2/ept/ep_run/runs/ep_clean.log", 1646260),
+    ("bptt", "/home/yurenh2/ept/ep_run/runs/bptt_clean.log", 1646261),
+]
+def alive(pid):
+    try: os.kill(pid, 0); return True
+    except Exception: return False
+def latest(log):
+    try: ls = [l for l in open(log) if l.startswith("step")]
+    except FileNotFoundError: return None
+    if not ls: return None
+    m = re.search(r"step (\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)", ls[-1])
+    return (int(m.group(1)), float(m.group(2)), float(m.group(3)), ls[-1].strip()) if m else None
+def status():
+    return "\n".join(f"[{t}] {'ALIVE' if alive(p) else 'DEAD'} | {(latest(l) or [None,None,None,'no steps'])[3]}" for t, l, p in RUNS)
+t0 = time.time(); fired = None
+while fired is None and time.time() - t0 < 16 * 3600:
+    for t, l, p in RUNS:
+        d = latest(l)
+        if d:
+            step, val, res, _ = d
+            if t == "ep":
+                if res > 0.2 or val > 15: fired = f"EP DIVERGED res={res:.2e} val={val:.2f} step {step} (clean-code baseline of the problem)"; break
+                if val < 2.20: fired = f"EP reached val {val:.4f} (res {res:.2e}) step {step} — good converged ckpt to probe"; break
+            if t == "bptt":
+                if val < 1.95: fired = f"bptt reached GOOD loss val {val:.4f} (res {res:.2e}) step {step} -> probe rho/g_transpose here"; break
+                if res > 0.25: fired = f"bptt res HIGH res={res:.2e} val={val:.2f} step {step} (BPTT riding NON-converged state?)"; break
+        if not alive(p):
+            fired = f"{t} process EXITED; last: {(latest(l) or [None,None,None,'no steps (early crash/OOM?)'])[3]}"; break
+    if fired: break
+    time.sleep(180)
+print("=== CLEAN-RERUN WATCHER FIRED ===")
+print("trigger:", fired or "16h timeout, no decisive event")
+print(status())
diff --git a/ep_run/watch_contraction.py b/ep_run/watch_contraction.py
new file mode 100644
index 0000000..64f6a9f
--- /dev/null
+++ b/ep_run/watch_contraction.py
@@ -0,0 +1,41 @@
+"""Watcher for the two contraction experiments (c3 + specnorm). 
Fires (exits) when either run
+hits a decisive state: DIVERGED (res>0.2 or val>15), CLEARED the danger zone (step>=10200, res<0.06,
+val<2.5 -> survived past the step ~9400 where the unconstrained run blew), or process EXITED."""
+import time, os, re
+RUNS = [
+    ("c3", "/home/yurenh2/ept/ep_run/runs/ep_c3_scratch.log", 1429784),
+    ("specnorm", "/home/yurenh2/ept/ep_run/runs/ep_specnorm09_scratch.log",1435898),
+]
+def alive(pid):
+    try: os.kill(pid, 0); return True
+    except Exception: return False
+def latest(log):
+    try: lines = [l for l in open(log) if l.startswith("step")]
+    except FileNotFoundError: return None
+    if not lines: return None
+    m = re.search(r"step (\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)", lines[-1])
+    if not m: return None
+    return int(m.group(1)), float(m.group(2)), float(m.group(3)), lines[-1].strip()
+def status_all():
+    out = []
+    for tag, log, pid in RUNS:
+        d = latest(log)
+        out.append(f"[{tag}] {'ALIVE' if alive(pid) else 'DEAD'} | {d[3] if d else 'no steps yet'}")
+    return "\n".join(out)
+t0 = time.time(); fired = None
+while fired is None and time.time() - t0 < 15 * 3600:
+    for tag, log, pid in RUNS:
+        d = latest(log)
+        if d:
+            step, val, res, _ = d
+            if res > 0.2 or val > 15:
+                fired = f"{tag} DIVERGED (res={res:.2e}, val={val:.2f}) at step {step}"; break
+            if step >= 10200 and res < 0.06 and val < 2.5:
+                fired = f"{tag} CLEARED danger zone: step {step}, val {val:.4f}, res {res:.2e} (survived past ~9400)"; break
+        if not alive(pid):
+            fired = f"{tag} process EXITED (abort_res / crash / done); last: {d[3] if d else 'no steps'}"; break
+    if fired: break
+    time.sleep(300)
+print("=== CONTRACTION WATCHER FIRED ===")
+print("trigger:", fired if fired else "max wall-time (15h) reached, no decisive event")
+print(status_all())
diff --git a/ep_run/watch_hr.py b/ep_run/watch_hr.py
new file mode 100644
index 0000000..33a44bb
--- /dev/null
+++ b/ep_run/watch_hr.py
@@ -0,0 +1,33 @@
+"""Watch EP(hr=0.2) [the fix] + BPTT. 
Fire on: EP diverges (fix failed) / EP reaches good loss (fix worked) / exit."""
+import time, os, re
+RUNS = [
+    ("ep_hr02", "/home/yurenh2/ept/ep_run/runs/ep_hr02.log", 1684249),
+    ("bptt", "/home/yurenh2/ept/ep_run/runs/bptt_clean.log", 1646261),
+]
+def alive(pid):
+    try: os.kill(pid, 0); return True
+    except Exception: return False
+def latest(log):
+    try: ls = [l for l in open(log) if l.startswith("step")]
+    except FileNotFoundError: return None
+    if not ls: return None
+    m = re.search(r"step (\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)", ls[-1])
+    return (int(m.group(1)), float(m.group(2)), float(m.group(3)), ls[-1].strip()) if m else None
+def status():
+    return "\n".join(f"[{t}] {'ALIVE' if alive(p) else 'DEAD'} | {(latest(l) or [None,None,None,'no steps'])[3]}" for t, l, p in RUNS)
+t0 = time.time(); fired = None
+while fired is None and time.time() - t0 < 18 * 3600:
+    for t, l, p in RUNS:
+        d = latest(l)
+        if d:
+            step, val, res, _ = d
+            if t == "ep_hr02":
+                if res > 0.2 or val > 15: fired = f"EP(hr=0.2) STILL DIVERGED res={res:.2e} val={val:.2f} step {step} -> hr was not the (only) cause"; break
+                if val < 2.00: fired = f"EP(hr=0.2) reached GOOD loss val {val:.4f} (res {res:.2e}) step {step} -> the hr fix WORKED (past the old 2.09 wall region)"; break
+            if t == "bptt":
+                if val < 1.95: fired = f"bptt reached GOOD loss val {val:.4f} (res {res:.2e}) step {step}"; break
+        if not alive(p):
+            fired = f"{t} EXITED; last: {(latest(l) or [None,None,None,'no steps'])[3]}"; break
+    if fired: break
+    time.sleep(180)
+print("=== HR-FIX WATCHER FIRED ==="); print("trigger:", fired or "18h timeout"); print(status())
diff --git a/ep_run/watch_runs.py b/ep_run/watch_runs.py
new file mode 100644
index 0000000..92acf16
--- /dev/null
+++ b/ep_run/watch_runs.py
@@ -0,0 +1,39 @@
+"""Watch specnorm (contraction test) + bptt_ctrl (premise test). 
Fires on a decisive event."""
+import time, os, re
+RUNS = [
+    ("specnorm", "/home/yurenh2/ept/ep_run/runs/ep_specnorm09_scratch.log", 1435898),
+    ("bptt", "/home/yurenh2/ept/ep_run/runs/bptt_ctrl.log", 1511172),
+]
+def alive(pid):
+    try: os.kill(pid, 0); return True
+    except Exception: return False
+def latest(log):
+    try: ls = [l for l in open(log) if l.startswith("step")]
+    except FileNotFoundError: return None
+    if not ls: return None
+    m = re.search(r"step (\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)", ls[-1])
+    return (int(m.group(1)), float(m.group(2)), float(m.group(3)), ls[-1].strip()) if m else None
+def status():
+    o = []
+    for t, l, p in RUNS:
+        d = latest(l); o.append(f"[{t}] {'ALIVE' if alive(p) else 'DEAD'} | {d[3] if d else 'no steps'}")
+    return "\n".join(o)
+t0 = time.time(); fired = None
+while fired is None and time.time() - t0 < 15 * 3600:
+    for t, l, p in RUNS:
+        d = latest(l)
+        if d:
+            step, val, res, _ = d
+            if t == "specnorm":
+                if res > 0.2 or val > 15: fired = f"specnorm DIVERGED res={res:.2e} val={val:.2f} step {step}"; break
+                if step >= 10200 and res < 0.06 and val < 2.5: fired = f"specnorm CLEARED step {step} val {val:.4f} res {res:.2e}"; break
+            if t == "bptt":
+                if val < 1.95: fired = f"bptt reached GOOD loss val {val:.4f} (res {res:.2e}) step {step} -> READY to probe rho(S^-1 A) at the good solution"; break
+                if res > 0.25: fired = f"bptt res HIGH res={res:.2e} val={val:.2f} step {step} (BPTT riding a NON-converged state?)"; break
+        if not alive(p):
+            fired = f"{t} process EXITED; last: {d[3] if d else 'no steps (early crash/OOM?)'}"; break
+    if fired: break
+    time.sleep(180)
+print("=== WATCHER FIRED ===")
+print("trigger:", fired or "15h timeout, no decisive event")
+print(status())
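The four Python watchers in this diff each redefine the same `latest()` log parser. As a minimal sketch that is not part of the commit, the shared parsing step could live in one helper. The name `parse_step_line` and the sample log line are hypothetical; only the regular expression is copied from the scripts, and the real training-log format is inferred from that regex alone.

```python
import re

# Regex copied from latest() in the watcher scripts; the rest is illustrative.
STEP_RE = re.compile(r"step (\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)")


def parse_step_line(line):
    """Return (step, val_ce, res) for one 'step ...' log line, or None if unusable."""
    m = STEP_RE.search(line)
    if m is None:
        return None
    try:
        return int(m.group(1)), float(m.group(2)), float(m.group(3))
    except ValueError:
        # The character class is permissive: a token such as "1.2." is
        # captured by the regex but is not a valid float.
        return None


# Hypothetical log line (assumed format, not taken from a real run).
sample = "step 9400/20000 | val CE 2.0934 | res=3.10e-02"
print(parse_step_line(sample))
```

The `ValueError` guard is an addition of this sketch: the scripts call `float()` directly on the captured groups, so a malformed number in the last `step` line would raise inside the polling loop rather than being skipped.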
