# Architecture: EP/AEP-trained Equilibrium Transformer (implementation details)

Based on the actual implementation in `/tmp/lt_ep/lt_ep_train.py` (`EQBlock` + `ep_step`).

## 0. Overview: one unified force, one relaxation

The whole block is **one dynamical system**: the token state `z ∈ R^{B,T,C}` relaxes to a fixed point under a **single force** that bundles input clamp + FFN + attention:

```
F(z) = − ∂/∂z [ ½‖z − x_in‖²   +   E_mem(z) ]             +  s·( Attn(z) − c·z )
       └──── conservative part (symmetric Jacobian) ────┘    └── non-conservative (non-reciprocal) ──┘
          input clamp        Hopfield memory = FFN              causal self-attention + damping
```

- `x_in = tok[idx] + pos`: input embedding, clamped as a boundary condition.
- **Forward = relaxation**: `z ← z + ε·F(z)` for T1 steps → fixed point `z*`; read out `logits = z*·W_h`.
- The conservative terms (FFN/clamp) and the non-conservative term (attention) live in **one force, one relaxation**, which is the basis of "unified" training.

---

## 1. FFN: standard EP (the clean part)

The FFN is rewritten as a **modern Hopfield memory energy**:

```
E_mem(z) = − Σ_{token, m} relu( z · W_m )²_m        # W_m ∈ R^{C×M}, M memories
```

Its force `−∂E_mem/∂z = 2·W_m·[relu(zW_m) ⊙ 1_{zW_m>0}]` is exactly a **tied-weight 2-layer MLP** (`W_in = W_out = W_m`, up to a transpose) generated by the squared-ReLU energy, i.e. the FFN. The indicator is redundant after the ReLU, since `d/dh relu(h)² = 2·relu(h)`, so the activation inside the force is a plain ReLU scaled by 2.

- **Conservative** (scalar energy, symmetric Jacobian) → **standard EP is exact, no correction**.
- Verified: FFN-param gradient cosine vs backprop = **1.000** (`lt_ep_ffn.py`).
- This is textbook EP / Hopfield, already demonstrated on memristor crossbars in the literature.

---

## 2. Attention: how it is "EP-ified" (the novel part)

**Step A: rewrite attention as a FORCE** (not a feedforward layer): tokens relax under it.

```
Attn(z) = [ softmax( Q(z)K(z)ᵀ/√d , causal ) · V(z) ] · W_O
Q=zW_Q, K=zW_K, V=zW_V           (independent projections ⇒ NON-reciprocal: i→j ≠ j→i)
force term = s·( Attn(z) − c·z )   # −c·z damping ⇒ contraction ⇒ a stable fixed point exists
```

**Step B: fix the non-reciprocity bias (AEP correction).** Because Q≠K and V is independent, the attention Jacobian `J_attn` is **asymmetric**: it is not the gradient of any scalar. Plain EP's nudged phase relaxes under `J`, but the correct adjoint needs `Jᵀ`, so plain EP gives a **biased** gradient. AEP adds, in the nudged phase:

```
v         = z − z*                         # deviation from the free equilibrium
corr      = s·( J_attn·v − J_attnᵀ·v )     # = 2·A_J·v ,  A_J = (J − Jᵀ)/2  (antisymmetric part)
J_attn·v  = jvp(Attn, z*, v)               # forward-mode (one forward probe)
J_attnᵀ·v = vjp(Attn, z*, v)               # reverse-mode (one backward probe)
f_nudged  = F(z) − sign·β·∂C/∂z − clip(corr)
```

Effect: the attention part of the nudged linearization becomes `s·J·v − s·(J−Jᵀ)v = s·Jᵀ·v`, i.e. **J is turned into Jᵀ**, giving the correct adjoint.

- The damping `−c·z` is **symmetric** (Jacobian −cI) ⇒ cancels in `J−Jᵀ` ⇒ the correction only sees attention's **non-reciprocal** part.
- Verified: attention-param gradient cosine vs backprop = **0.99–1.0** (plain EP: 0.25–0.78).
- Hardware note: `jvp/vjp` = two probe directions; **non-reciprocal coupling is exactly what real analog devices have**, so AEP removes the usual "symmetric weights" requirement of EP hardware.

---

## 3. End-to-end unified training

**One relaxation, one estimator, trains the whole block.** Key fact:

> The antisymmetric Jacobian of the **full** force, `A_J`, equals the antisymmetric part of
> (conservative + attention) = **attention's antisymmetric part alone** (the conservative FFN/clamp
> have symmetric Jacobians → contribute 0 to `A_J`).

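
The key fact can be checked numerically on a toy force. The sketch below is an illustration under stated assumptions, not code from `lt_ep_train.py`: the sizes, the seed, and the names `conservative_force`, `attention_force`, `jacobian`, and `antisym` are hypothetical, `W_O` is omitted, and the Jacobians come from central finite differences rather than `jvp`/`vjp`.

```python
import numpy as np

rng = np.random.default_rng(0)
T, C, M = 3, 4, 6                      # toy sizes (hypothetical): tokens, channels, memories
W_m = 0.5 * rng.normal(size=(C, M))
W_q, W_k, W_v = (0.5 * rng.normal(size=(C, C)) for _ in range(3))
x_in = rng.normal(size=(T, C))
s, c = 0.5, 1.0
causal = np.tril(np.ones((T, T), dtype=bool))

def conservative_force(z):
    # -d/dz [ 0.5*||z - x_in||^2 + E_mem(z) ]  with  E_mem = -sum relu(z W_m)^2
    return -(z - x_in) + 2.0 * np.maximum(z @ W_m, 0.0) @ W_m.T

def attention_force(z):
    # s * (Attn(z) - c*z) with causal softmax attention (W_O omitted for brevity)
    scores = (z @ W_q) @ (z @ W_k).T / np.sqrt(C)
    scores = np.where(causal, scores, -np.inf)
    p = np.exp(scores - scores.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    return s * (p @ (z @ W_v) - c * z)

def jacobian(f, z, h=1e-5):
    # Central finite-difference Jacobian of f with respect to the flattened state.
    J = np.zeros((z.size, z.size))
    for i in range(z.size):
        d = np.zeros(z.size)
        d[i] = h
        d = d.reshape(z.shape)
        J[:, i] = (f(z + d) - f(z - d)).ravel() / (2.0 * h)
    return J

def antisym(J):
    # Antisymmetric part A_J = (J - J^T) / 2.
    return 0.5 * (J - J.T)

z = rng.normal(size=(T, C))
J_cons = jacobian(conservative_force, z)
J_attn = jacobian(attention_force, z)
J_full = jacobian(lambda u: conservative_force(u) + attention_force(u), z)

print("max |antisym(J_cons)|                 =", np.abs(antisym(J_cons)).max())
print("max |antisym(J_attn)|                 =", np.abs(antisym(J_attn)).max())
print("max |antisym(J_full) - antisym(J_attn)| =",
      np.abs(antisym(J_full) - antisym(J_attn)).max())
```

If the first value is near zero and the last one is too, the antisymmetric part of the full Jacobian comes from the attention term alone, which is what the key fact states.
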
So **the AEP correction needs to act on the attention term only**; the FFN/clamp ride along in the conservative part and are trained correctly by standard EP: **one relaxation, one correction, trains everything.**

**Training step (`ep_step`) = 3 phases (free, +β, −β), an adjoint readout, and 1 local update:**

```
① free phase:    z* = relax(F, x_in, T1)                            # to the fixed point
② nudged ±:      z₊ = relax( F − β·∂C/∂z − corr , from z*, T2 )     # +β
                 z₋ = relax( F + β·∂C/∂z − corr , from z*, T2 )     # −β  (centered EP)
③ adjoint:       a  = (z₋ − z₊) / (2β)                              # read from the two nudged equilibria
④ local update:
   • equilibrium params (W_Q,W_K,W_V,W_O, W_m, embeddings tok/pos):
         ∇_θ L = ⟨ a , ∂F/∂θ(z*) ⟩      # vector-field EP estimator; one formula for attn+FFN+embed
         (code: autograd.grad( (a·F(z*,θ)).sum(), θ ) ,  θ live, z* fixed)
   • readout W_h (outside the equilibrium): direct local gradient  ∂C/∂W_h |_{z*}
```

- Why attention (non-conservative) and FFN (conservative) train under the *same* estimator: `⟨a, ∂F/∂θ⟩` is uniform over all equilibrium params; the AEP correction only modifies `a` (making it the correct adjoint for the non-reciprocal system); the FFN's `∂F/∂θ` is already correct.
- Embeddings `tok/pos` enter the force through the clamp `½‖z−x_in‖²` (`∂F/∂x_in = +I`), so the same `⟨a, ∂F/∂θ⟩` yields their gradient.
- **Stability (feedback regulation, from FRE-RNN 2508.11659):** at each step, measure the free-phase residual `res = ‖relax(z*,1)−z*‖/‖z*‖`. If `res` is too large, **increase the damping `c`** (this lowers the spectral radius and keeps the relaxation converging); if it is very small, relax `c`. This maintains the "free phase has converged" condition (Ernoult 2019: EP ≡ BPTT in the β→0, converged limit) throughout training.

**Measured (this implementation):** end-to-end EP trains a char-LM to **val CE 2.95** (random 4.17, backprop on the same architecture 2.10), with **zero non-finite steps** under feedback regulation.

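
The four steps of `ep_step` can be sketched on a toy system small enough to compare against an exact reference. This is a minimal illustration under assumptions, not the code in `lt_ep_train.py`: the non-reciprocal term is a stand-in `s·(tanh(A z) − c·z)` rather than attention, the cost is `C(z) = ½‖z − y‖²`, the Jacobian of the stand-in is written analytically instead of using `jvp`/`vjp`, `clip` is omitted, and all names and constants are hypothetical. The reference gradient uses implicit differentiation at the fixed point, `a = −J⁻ᵀ ∂C/∂z`.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5                                    # toy state size (hypothetical)
A = 0.4 * rng.normal(size=(n, n))        # asymmetric coupling: the trained, non-reciprocal parameter
x_in = rng.normal(size=n)                # clamped input
y = rng.normal(size=n)                   # target of the cost C(z) = 0.5*||z - y||^2
s, c, eps, beta = 0.5, 1.0, 0.3, 1e-2

def force(z):
    # Conservative clamp -(z - x_in) plus the non-reciprocal stand-in s*(tanh(A z) - c*z).
    return -(z - x_in) + s * (np.tanh(A @ z) - c * z)

def relax(f, z, steps=500):
    # Explicit Euler relaxation z <- z + eps * f(z).
    for _ in range(steps):
        z = z + eps * f(z)
    return z

def ep_grad(use_aep):
    z_star = relax(force, np.zeros(n))                  # (1) free phase
    gain = 1.0 - np.tanh(A @ z_star) ** 2
    J_nr = s * gain[:, None] * A                        # Jacobian of s*tanh(A z) at z*
    def nudged(sign):
        def f(z):
            v = z - z_star
            corr = (J_nr - J_nr.T) @ v if use_aep else 0.0
            return force(z) - sign * beta * (z - y) - corr
        return relax(f, z_star)
    z_plus, z_minus = nudged(+1.0), nudged(-1.0)        # (2) nudged phases, +beta and -beta
    a = (z_minus - z_plus) / (2.0 * beta)               # (3) adjoint estimate
    return s * np.outer(a * gain, z_star)               # (4) <a, dF/dA> evaluated at z*

def reference_grad():
    # Exact gradient of C(z*(A)) via implicit differentiation of F(z*, A) = 0.
    z_star = relax(force, np.zeros(n))
    gain = 1.0 - np.tanh(A @ z_star) ** 2
    J = -(1.0 + s * c) * np.eye(n) + s * gain[:, None] * A   # full Jacobian dF/dz at z*
    a = -np.linalg.solve(J.T, z_star - y)
    return s * np.outer(a * gain, z_star)

def cosine(u, w):
    return float(np.sum(u * w) / (np.linalg.norm(u) * np.linalg.norm(w)))

g_ref = reference_grad()
cos_aep = cosine(ep_grad(use_aep=True), g_ref)
cos_plain = cosine(ep_grad(use_aep=False), g_ref)
print(f"cosine vs reference: AEP = {cos_aep:.6f}, plain EP = {cos_plain:.6f}")
```

In this toy, the correction changes only the adjoint `a`; the factor `∂F/∂A` is the same in both variants, which mirrors the argument for why one estimator covers all equilibrium parameters. Because `∂F/∂x_in = +I`, the vector `a` is also the gradient with respect to the clamped input.
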
---

## One-line summary

> **One energy/force, one relaxation, one local estimator.** FFN = conservative Hopfield energy →
> *standard EP* (exact). Attention = a *non-reciprocal force*; AEP turns the nudged-phase `J` into `Jᵀ`
> via two probes (jvp/vjp) → exact gradient. Since the full force's antisymmetric part comes only from
> attention, **one AEP correction + standard EP train the whole block end-to-end**; the readout trains
> directly on the cost; damping + feedback-regulation keep the system convergent.

## Hardware-relevant primitives

- **local, no weight transport**: every weight updates from locally-available equilibrium states.
- **compute = relaxation to a fixed point**: maps to oscillators / memristor crossbars / optics / Ising.
- **two phases, same circuit**: free + nudged differ only by a small output nudge β.
- **non-reciprocal coupling OK (a feature)**: AEP handles asymmetric `J`; `jvp`/`vjp` = two probe directions.
- **dissipation `c` is a physical knob**: feedback-regulated to keep the system in the convergent regime.
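
The feedback regulation of `c` mentioned in the last bullet can be sketched as a small controller around the free phase. This is a hypothetical illustration, not the regulator in `lt_ep_train.py`: the thresholds `hi`/`lo`, the factors `up`/`down`, the bounds `c_min`/`c_max`, and the toy force (a `tanh` stand-in for attention with deliberately strong coupling) are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6                                   # toy state size (hypothetical)
A = 1.5 * rng.normal(size=(n, n))       # deliberately strong non-reciprocal coupling
x_in = rng.normal(size=n)
s, eps = 0.5, 0.1

def force(z, c):
    # Clamp plus a non-reciprocal stand-in with damping c.
    return -(z - x_in) + s * (np.tanh(A @ z) - c * z)

def free_phase(c, steps=200):
    # Relax from zero for a fixed number of explicit Euler steps.
    z = np.zeros(n)
    for _ in range(steps):
        z = z + eps * force(z, c)
    return z

def residual(z, c):
    # res = ||relax(z, 1) - z|| / ||z||: how far z is from being a fixed point.
    z_next = z + eps * force(z, c)
    return float(np.linalg.norm(z_next - z) / (np.linalg.norm(z) + 1e-12))

def regulate(c, res, hi=1e-3, lo=1e-8, up=1.5, down=0.95, c_min=0.1, c_max=20.0):
    # Raise damping when the free phase has not converged, relax it when it
    # converges with a large margin, and keep c inside [c_min, c_max].
    if res > hi:
        c *= up
    elif res < lo:
        c *= down
    return float(np.clip(c, c_min, c_max))

c = 0.2
for step in range(25):
    z_star = free_phase(c)
    res = residual(z_star, c)
    print(f"step {step:2d}  c = {c:6.3f}  res = {res:.2e}")
    c = regulate(c, res)
```

The upper bound `c_max` matters for an explicit Euler relaxation: with step size `eps`, the damping must satisfy `eps·(1 + s·c) < 2`, otherwise the discrete update itself becomes unstable even though the continuous-time system is strongly contracting.
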