# Architecture: EP/AEP-trained Equilibrium Transformer — implementation details
Based on the actual implementation in `/tmp/lt_ep/lt_ep_train.py` (`EQBlock` + `ep_step`).
## 0. Overview — one unified force, one relaxation
The whole block is **one dynamical system**: token state `z ∈ R^{B,T,C}` relaxes to a fixed
point under a **single force** that bundles input-clamp + FFN + attention:
```
F(z) = − ∂/∂z [ ½‖z − x_in‖² + E_mem(z) ]  +  s·( Attn(z) − c·z )
       conservative, symmetric Jacobian       non-conservative, non-reciprocal
       input clamp + Hopfield memory (FFN)    causal self-attention + damping
```
- `x_in = tok[idx] + pos` — input embedding, clamped as a boundary condition.
- **Forward = relaxation**: `z ← z + ε·F(z)` for T1 steps → fixed point `z*`; read out `logits = z*·W_h`.
- Conservative (FFN/clamp) and non-conservative (attention) live in **one force, one relaxation** — the basis of "unified" training.
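
A minimal NumPy sketch of the forward pass described above: one force that bundles the input clamp, the Hopfield memory (FFN) and damped causal attention, relaxed by explicit Euler steps to `z*` and then read out. The toy sizes, weight scales, `ε`, `T1`, `s`, `c` and the helper names `force`/`relax` are illustrative assumptions, not values or names taken from `lt_ep_train.py`, and the batch dimension B is dropped. Small weight scales are chosen so the Euler map is a contraction.

```python
import numpy as np

rng = np.random.default_rng(0)
T, C, M, d, vocab = 5, 8, 16, 8, 11      # toy sizes (assumed, batch dim dropped)

# Hypothetical toy parameters; small scales keep the relaxation contractive.
tok = rng.normal(size=(vocab, C))
pos = 0.1 * rng.normal(size=(T, C))
Wm = 0.05 * rng.normal(size=(C, M))
WQ, WK, WV = (0.05 * rng.normal(size=(C, d)) for _ in range(3))
WO = 0.05 * rng.normal(size=(d, C))
Wh = rng.normal(size=(C, vocab))
s, c = 1.0, 1.0                           # attention strength and damping (assumed)


def attn(z):
    """Causal softmax self-attention applied to token states z of shape (T, C)."""
    Q, K, V = z @ WQ, z @ WK, z @ WV
    scores = Q @ K.T / np.sqrt(d)
    scores = np.where(np.tril(np.ones((T, T), dtype=bool)), scores, -np.inf)
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return (w @ V) @ WO


def force(z, x_in):
    clamp = -(z - x_in)                              # -d/dz 1/2 ||z - x_in||^2
    memory = 2.0 * np.maximum(z @ Wm, 0.0) @ Wm.T    # -dE_mem/dz (the FFN)
    return clamp + memory + s * (attn(z) - c * z)    # + non-conservative part


def relax(x_in, steps=300, eps=0.2):
    """Explicit Euler relaxation z <- z + eps * F(z), started at x_in."""
    z = x_in.copy()
    for _ in range(steps):
        z = z + eps * force(z, x_in)
    return z


idx = rng.integers(0, vocab, size=T)      # token ids
x_in = tok[idx] + pos                     # clamped input embedding
z_star = relax(x_in)                      # free equilibrium
residual = np.linalg.norm(force(z_star, x_in)) / np.linalg.norm(z_star)
logits = z_star @ Wh                      # readout
print(f"relative residual {residual:.1e}, logits shape {logits.shape}")
```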
---
## 1. FFN — standard EP (the clean part)
The FFN is rewritten as a **modern Hopfield memory energy**:
```
E_mem(z) = − Σ_{token, m} relu( z · W_m )²_m # W_m ∈ R^{C×M}, M memories
```
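Its force `−∂E_mem/∂z = 2·relu(z·W_m)·W_mᵀ` (the chain-rule indicator `1_{zW_m>0}` drops out because `relu(u)·1_{u>0} = relu(u)`) is exactly a **tied-weight 2-layer MLP
(W_in=W_out=W_m) with a ReLU activation** = the FFN; the squared ReLU sits in the energy, and differentiating it yields the ReLU.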
- **Conservative** (scalar energy, symmetric Jacobian) → **standard EP is exact, no correction**.
- Verified: FFN-param gradient cosine vs backprop = **1.000** (`lt_ep_ffn.py`).
- This is textbook EP / Hopfield — already demonstrated on memristor crossbars in the literature.
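
The conservative structure can be checked numerically. This NumPy sketch (toy sizes and a random `Wm` are assumptions) compares the FFN force with a finite-difference gradient of `E_mem` and confirms that the force's Jacobian is symmetric, which is the property standard EP relies on.

```python
import numpy as np

rng = np.random.default_rng(1)
T, C, M = 3, 4, 6                          # toy sizes (assumed)
Wm = rng.normal(size=(C, M))               # hypothetical memory matrix
z = rng.normal(size=(T, C))


def e_mem(z):
    """E_mem(z) = -sum over tokens t and memories m of relu(z W_m)[t, m]^2."""
    return -np.sum(np.maximum(z @ Wm, 0.0) ** 2)


def ffn_force(z):
    """-dE_mem/dz = 2 relu(z W_m) W_m^T: a tied-weight ReLU MLP."""
    return 2.0 * np.maximum(z @ Wm, 0.0) @ Wm.T


def numeric_grad(f, z, h=1e-6):
    """Central-difference gradient of a scalar function f at z."""
    g = np.zeros_like(z)
    for i in np.ndindex(z.shape):
        e = np.zeros_like(z)
        e[i] = h
        g[i] = (f(z + e) - f(z - e)) / (2 * h)
    return g


def numeric_jacobian(f, z, h=1e-6):
    """Central-difference Jacobian of f at z, in flattened coordinates."""
    J = np.zeros((z.size, z.size))
    for k in range(z.size):
        e = np.zeros(z.size)
        e[k] = h
        e = e.reshape(z.shape)
        J[:, k] = ((f(z + e) - f(z - e)) / (2 * h)).ravel()
    return J


force_err = np.abs(ffn_force(z) + numeric_grad(e_mem, z)).max()
J = numeric_jacobian(ffn_force, z)
asym = np.abs(J - J.T).max()
print(f"max |force + dE_mem/dz| = {force_err:.1e}, max |J - J^T| = {asym:.1e}")
```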
---
## 2. Attention — how it is "EP-ified" (the novel part)
**Step A — rewrite attention as a FORCE** (not a feedforward layer): tokens relax under it.
```
Attn(z) = [ softmax( Q(z)K(z)ᵀ/√d , causal ) · V(z) ] · W_O
Q=zW_Q, K=zW_K, V=zW_V (independent projections ⇒ NON-reciprocal: i→j ≠ j→i)
force term = s·( Attn(z) − c·z ) # −c·z damping: large enough c ⇒ contraction ⇒ a unique stable fixed point
```
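The non-reciprocity can be seen numerically. The sketch below (toy sizes, random weights, `s` and `c` are assumptions) builds the attention force `s·(Attn(z) − c·z)` and its finite-difference Jacobian. Causality alone already makes it asymmetric: token 1 responds to token 0, while token 0 ignores token 1.

```python
import numpy as np

rng = np.random.default_rng(2)
T, C, d = 4, 3, 3                          # toy sizes (assumed)
WQ, WK, WV = (rng.normal(size=(C, d)) for _ in range(3))
WO = rng.normal(size=(d, C))
s, c = 1.0, 0.5                            # strength and damping (assumed)
causal = np.tril(np.ones((T, T), dtype=bool))


def attn_force(z):
    """s * (Attn(z) - c z) with causal softmax attention; z has shape (T, C)."""
    Q, K, V = z @ WQ, z @ WK, z @ WV
    scores = np.where(causal, Q @ K.T / np.sqrt(d), -np.inf)
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return s * ((w @ V) @ WO - c * z)


def numeric_jacobian(f, z, h=1e-5):
    """Central-difference Jacobian of f at z, in flattened (t*C + j) coordinates."""
    J = np.zeros((z.size, z.size))
    for k in range(z.size):
        e = np.zeros(z.size)
        e[k] = h
        e = e.reshape(z.shape)
        J[:, k] = ((f(z + e) - f(z - e)) / (2 * h)).ravel()
    return J


z = rng.normal(size=(T, C))
J = numeric_jacobian(attn_force, z)
out0_wrt_z1 = J[0:C, C:2 * C]              # zero: token 0 cannot attend to token 1
out1_wrt_z0 = J[C:2 * C, 0:C]              # nonzero: token 1 attends to token 0
rel_asym = np.linalg.norm(J - J.T) / np.linalg.norm(J)
print(f"max|dOut0/dz1| = {np.abs(out0_wrt_z1).max():.1e}, "
      f"max|dOut1/dz0| = {np.abs(out1_wrt_z0).max():.2f}, "
      f"||J - J^T|| / ||J|| = {rel_asym:.2f}")
```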
**Step B — fix the non-reciprocity bias (AEP correction).** Because Q≠K and V is independent,
the attention Jacobian `J_attn` is **asymmetric** — it is not the gradient of any scalar. Plain EP's
nudged phase relaxes under `J`, but the correct adjoint needs `Jᵀ`, so plain EP gives a **biased**
gradient. AEP adds, in the nudged phase:
```
v = z − z* # deviation from the free equilibrium
corr = s·( J_attn·v − J_attnᵀ·v ) # = 2·A_J·v , A_J = s·(J_attn − J_attnᵀ)/2 (antisymmetric part)
J_attn·v = jvp(Attn, z*, v) # forward-mode (one forward probe)
J_attnᵀ·v = vjp(Attn, z*, v) # reverse-mode (one backward probe)
f_nudged = F(z) − sign·β·∂C/∂z − clip(corr)
```
Effect: the attention part of the nudged linearization becomes `s·J·v − s·(J−Jᵀ)v = s·Jᵀ·v`
— i.e. **J is turned into Jᵀ**, giving the correct adjoint.
- The damping `−c·z` is **symmetric** (Jacobian −cI) ⇒ cancels in `J−Jᵀ` ⇒ the correction only
sees attention's **non-reciprocal** part.
- Verified: attention-param gradient cosine vs backprop = **0.99–1.0** (plain EP: 0.25–0.78).
- Hardware note: `jvp/vjp` = two probe directions; **non-reciprocal coupling is exactly what real
analog devices have** — AEP removes the usual "symmetric weights" requirement of EP hardware.
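
The correction needs only the two probes, never the full Jacobian. In this sketch a hypothetical non-reciprocal map `g(z) = tanh(W·z)` stands in for attention (an assumption, chosen because its `vjp` has a closed form), and the forward probe is a central finite difference. It checks that `s·J·v − corr = s·Jᵀ·v` and that adding the damping `−c·z` leaves `corr` unchanged. The `clip` step is omitted.

```python
import numpy as np

rng = np.random.default_rng(3)
n, s, c = 5, 1.0, 0.5                      # toy size, strength, damping (assumed)
W = rng.normal(size=(n, n))                # asymmetric coupling, W != W.T


def g(z):
    """Hypothetical non-reciprocal stand-in for Attn."""
    return np.tanh(W @ z)


def jvp(f, z, v, h=1e-5):
    """Forward probe J v via central finite differences."""
    return (f(z + h * v) - f(z - h * v)) / (2 * h)


def vjp_g(z, v):
    """Backward probe J^T v for g, where J = diag(1 - tanh^2(Wz)) W."""
    return W.T @ ((1.0 - np.tanh(W @ z) ** 2) * v)


z_star = rng.normal(size=n)                # linearization point z*
v = 1e-2 * rng.normal(size=n)              # deviation v = z - z*

Jv, JTv = jvp(g, z_star, v), vjp_g(z_star, v)
corr = s * (Jv - JTv)                      # = 2 s A v with A = (J - J^T)/2

# Nudged linearization of the attention term: s J v - corr equals s J^T v.
nudged_lin = s * Jv - corr


# Damping adds the symmetric -c I to J, so its J^T v is JTv - c v and it cancels.
def g_damped(z):
    return g(z) - c * z


corr_damped = s * (jvp(g_damped, z_star, v) - (JTv - c * v))
print(np.abs(nudged_lin - s * JTv).max(), np.abs(corr - corr_damped).max())
```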
---
## 3. End-to-end unified training
**One relaxation, one estimator, trains the whole block.** Key fact:
> The antisymmetric Jacobian of the **full** force, `A_J`, equals the antisymmetric part of
> (conservative + attention) = **attention's antisymmetric part alone** (the conservative FFN/clamp
> have symmetric Jacobians → contribute 0 to `A_J`).
So **the AEP correction needs to act on the attention term only**; the FFN/clamp ride along in the
conservative part and are trained correctly by standard EP — **one relaxation, one correction, trains
everything.**
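
The key fact can be checked on a toy block. In this NumPy sketch (sizes, weight scales, `s` and `c` are assumptions) the finite-difference Jacobian of the full force (clamp + FFN + damped attention) has an antisymmetric part equal to `s` times that of `Attn` alone.

```python
import numpy as np

rng = np.random.default_rng(4)
T, C, M, d = 3, 3, 5, 3                    # toy sizes (assumed)
s, c = 0.7, 0.5                            # strength and damping (assumed)
x_in = rng.normal(size=(T, C))
Wm = 0.3 * rng.normal(size=(C, M))
WQ, WK, WV = (rng.normal(size=(C, d)) for _ in range(3))
WO = rng.normal(size=(d, C))
causal = np.tril(np.ones((T, T), dtype=bool))


def attn(z):
    """Causal softmax self-attention on token states z of shape (T, C)."""
    Q, K, V = z @ WQ, z @ WK, z @ WV
    scores = np.where(causal, Q @ K.T / np.sqrt(d), -np.inf)
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return (w @ V) @ WO


def full_force(z):
    clamp = -(z - x_in)                              # Jacobian -I (symmetric)
    memory = 2.0 * np.maximum(z @ Wm, 0.0) @ Wm.T    # Jacobian = Hessian (symmetric)
    return clamp + memory + s * (attn(z) - c * z)    # only Attn is asymmetric


def numeric_jacobian(f, z, h=1e-5):
    """Central-difference Jacobian of f at z, in flattened coordinates."""
    J = np.zeros((z.size, z.size))
    for k in range(z.size):
        e = np.zeros(z.size)
        e[k] = h
        e = e.reshape(z.shape)
        J[:, k] = ((f(z + e) - f(z - e)) / (2 * h)).ravel()
    return J


def antisym(J):
    return 0.5 * (J - J.T)


z = rng.normal(size=(T, C))
A_full = antisym(numeric_jacobian(full_force, z))
A_attn = antisym(numeric_jacobian(attn, z))
print(f"max |A_full - s*A_attn| = {np.abs(A_full - s * A_attn).max():.1e}, "
      f"max |A_attn| = {np.abs(A_attn).max():.2f}")
```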
**Training step (`ep_step`) = 3 phases + 1 local update:**
```
① free phase: z* = relax(F, x_in, T1) # to the fixed point
② nudged ±: z₊ = relax( F − β·∂C/∂z − corr , from z*, T2 ) # +β
z₋ = relax( F + β·∂C/∂z − corr , from z*, T2 ) # −β (centered EP)
③ adjoint: a = (z₋ − z₊) / (2β) # read from the two nudged equilibria
④ local update:
• equilibrium params (W_Q,W_K,W_V,W_O, W_m, embeddings tok/pos):
∇_θ L = ⟨ a , ∂F/∂θ(z*) ⟩ # vector-field EP estimator — one formula for attn+FFN+embed
(code: autograd.grad( (a·F(z*,θ)).sum(), θ ) , θ live, z* fixed)
• readout W_h (outside the equilibrium): direct local gradient ∂C/∂W_h |_{z*}
```
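The four steps can be exercised end to end on a system small enough to check against backprop. This is a toy sketch, not `ep_step` itself: a linear force `F(z) = −M·z + b` with `M = I + K` and `K` antisymmetric, so the Jacobian `J = −M` is non-reciprocal. Here `b` plays the role of the clamped input and the cost is `C(z) = ½‖z − y‖²`. Because `J` is constant, `jvp`/`vjp` reduce to `J @ v` and `J.T @ v`. Using the full `J` in `corr` is equivalent to using only its non-reciprocal part, since the symmetric part cancels. The AEP adjoint should match the exact adjoint `M⁻ᵀ(z* − y)`, while plain EP returns `M⁻¹(z* − y)` instead.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
G = 0.5 * rng.normal(size=(n, n))
K = 0.5 * (G - G.T)                        # antisymmetric => non-reciprocal coupling
M = np.eye(n) + K                          # F(z) = -M z + b, Jacobian J = -M
b = rng.normal(size=n)                     # stands in for the clamped input x_in
y = rng.normal(size=n)                     # target; cost C(z) = 1/2 ||z - y||^2


def F(z):
    return -M @ z + b


def dC(z):
    return z - y


def relax(f, z0, steps=2000, eps=0.1):
    z = z0.copy()
    for _ in range(steps):
        z = z + eps * f(z)
    return z


def ep_step(beta=1e-4, aep=True):
    z_star = relax(F, np.zeros(n))                          # (1) free phase
    J = -M                                                  # Jacobian of F (constant here)

    def corr(z):                                            # AEP correction: jvp - vjp
        v = z - z_star
        return (J @ v - J.T @ v) if aep else np.zeros(n)

    z_plus = relax(lambda z: F(z) - beta * dC(z) - corr(z), z_star)   # (2) +beta
    z_minus = relax(lambda z: F(z) + beta * dC(z) - corr(z), z_star)  #     -beta
    a = (z_minus - z_plus) / (2 * beta)                     # (3) adjoint
    grad_b = a                                              # (4) <a, dF/db>, dF/db = I
    grad_M = -np.outer(a, z_star)                           #     <a, dF/dM>
    return z_star, grad_b, grad_M


z_star, gb_aep, gM_aep = ep_step(aep=True)
_, gb_plain, _ = ep_step(aep=False)

lam = np.linalg.solve(M.T, z_star - y)                     # exact adjoint M^{-T}(z* - y)
exact_b, exact_M = lam, -np.outer(lam, z_star)             # backprop gradients


def cosine(u, w):
    return float(u.ravel() @ w.ravel() / (np.linalg.norm(u) * np.linalg.norm(w)))


print(f"AEP:      cos(grad_b, exact) = {cosine(gb_aep, exact_b):.6f}")
print(f"plain EP: cos(grad_b, exact) = {cosine(gb_plain, exact_b):.3f}")
```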
- Why attention (non-conservative) and FFN (conservative) train under the *same* estimator:
`⟨a, ∂F/∂θ⟩` is uniform over all equilibrium params; the AEP correction only modifies `a` (making it
the correct adjoint for the non-reciprocal system); the FFN's `∂F/∂θ` is already correct.
- Embeddings `tok/pos` enter the force through the clamp `½‖z−x_in‖²` (`∂F/∂x_in = +I`), so the same
`⟨a, ∂F/∂θ⟩` yields their gradient.
- **Stability (feedback regulation, from FRE-RNN 2508.11659):** at each step, measure the free-phase
residual `res = ‖relax(z*,1)−z*‖/‖z*‖`; if `res` is too large, **increase the damping `c`** (lowering the
spectral radius so the relaxation keeps converging); if it is very small, decrease `c`. This maintains the "free phase has
converged" condition (Ernoult 2019: EP ≡ BPTT in the β→0, converged limit) throughout training.
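
A minimal sketch of the feedback rule on a deliberately unstable toy. A hypothetical linear non-reciprocal map `W·z` (eigenvalues 1.5 ± 0.2i) stands in for attention, so with small `c` the fixed-length free phase does not converge (the toy needs `c > 0.5`). The thresholds `hi`/`lo`, the multiplicative factors, `T1` and the helper names are assumptions, not values from the implementation or from FRE-RNN.

```python
import numpy as np

n, s, eps, T1 = 4, 1.0, 0.1, 50           # toy size, strength, step size, free steps
R = np.array([[0.0, 1.0, 0.0, 0.0],
              [-1.0, 0.0, 0.0, 0.0],
              [0.0, 0.0, 0.0, 1.0],
              [0.0, 0.0, -1.0, 0.0]])
W = 1.5 * np.eye(n) + 0.2 * R             # hypothetical non-reciprocal stand-in for Attn
x_in = np.ones(n)


def force(z, c):
    return -(z - x_in) + s * (W @ z - c * z)


def relax(z, c, steps):
    for _ in range(steps):
        z = z + eps * force(z, c)
    return z


def free_residual(c):
    """res = ||relax(z*, 1) - z*|| / ||z*|| after a T1-step free phase."""
    z_star = relax(x_in.copy(), c, T1)
    return np.linalg.norm(relax(z_star, c, 1) - z_star) / np.linalg.norm(z_star)


def regulate(c, res, hi=1e-3, lo=1e-8, up=1.1, down=0.98, c_min=1e-3):
    if res > hi:                  # not converged: more damping, smaller spectral radius
        return c * up
    if res < lo:                  # comfortably converged: give some damping back
        return max(c * down, c_min)
    return c


c = 0.1                           # start in the non-convergent regime
for _ in range(200):              # one regulation update per training step
    c = regulate(c, free_residual(c))
res = free_residual(c)
print(f"c = {c:.3f}, free-phase residual = {res:.1e}")
```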
**Measured (this implementation):** end-to-end EP trains a char-LM to **val CE 2.95** (random 4.17,
backprop on the same architecture 2.10), with **zero non-finite steps** under feedback regulation.

---
## One-line summary
> **One energy/force, one relaxation, one local estimator.** FFN = conservative Hopfield energy →
> *standard EP* (exact). Attention = a *non-reciprocal force*; AEP turns the nudged-phase `J` into `Jᵀ`
> via two probes (jvp/vjp) → exact gradient. Since the full force's antisymmetric part comes only from
> attention, **one AEP correction + standard EP train the whole block end-to-end**; the readout trains
> directly on the cost; damping + feedback-regulation keep the system convergent.
## Hardware-relevant primitives
- **local, no weight transport**: every weight updates from locally-available equilibrium states.
- **compute = relaxation to a fixed point**: maps to oscillators / memristor crossbars / optics / Ising.
- **two phases, same circuit**: free + nudged differ only by a small output nudge β.
- **non-reciprocal coupling OK (a feature)**: AEP handles asymmetric `J`; `jvp`/`vjp` = two probe directions.
- **dissipation `c` is a physical knob**: feedback-regulated to keep the system in the convergent regime.