# Architecture: EP/AEP-trained Equilibrium Transformer — implementation details

Based on the actual implementation in `/tmp/lt_ep/lt_ep_train.py` (`EQBlock` + `ep_step`).

## 0. Overview — one unified force, one relaxation

The whole block is **one dynamical system**: token state `z ∈ R^{B,T,C}` relaxes to a fixed
point under a **single force** that bundles input-clamp + FFN + attention:

```
F(z) = − ∂/∂z [ ½‖z − x_in‖²  +  E_mem(z) ]      +   s·( Attn(z) − c·z )
        └────── conservative part (symmetric Jacobian) ──┘    └── non-conservative (non-reciprocal) ──┘
          input clamp        Hopfield memory = FFN              causal self-attention + damping
```
- `x_in = tok[idx] + pos` — input embedding, clamped as a boundary condition.
- **Forward = relaxation**: `z ← z + ε·F(z)` for `T1` steps (an iteration count, not the sequence length `T`) → fixed point `z*`; read out `logits = z*·W_h`.
- Conservative (FFN/clamp) and non-conservative (attention) live in **one force, one relaxation** — the basis of "unified" training.
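
The forward pass described above can be sketched as a plain Euler loop. This is a minimal NumPy illustration, not the code in `lt_ep_train.py`: `toy_force`, the matrix `W` (a linear stand-in for attention), and the values of `eps`, `s`, `c` are assumptions chosen so the map is contractive.

```python
import numpy as np

def relax(force, z0, eps=0.1, steps=200):
    """Explicit Euler relaxation: z <- z + eps * F(z), repeated `steps` times."""
    z = z0.copy()
    for _ in range(steps):
        z = z + eps * force(z)
    return z

# Toy force with the same shape as the overview: input clamp plus a damped,
# non-symmetric linear coupling standing in for attention (hypothetical).
rng = np.random.default_rng(0)
C = 8
x_in = rng.normal(size=C)
W = rng.normal(size=(C, C))
W = 0.5 * W / np.linalg.norm(W, 2)   # spectral norm 0.5 keeps the map contractive
s, c = 1.0, 1.0

def toy_force(z):
    return -(z - x_in) + s * (W @ z - c * z)

z_star = relax(toy_force, np.zeros(C))
# Relative size of the remaining force: small means the fixed point was reached.
residual = np.linalg.norm(toy_force(z_star)) / np.linalg.norm(z_star)
```

For this linear toy the fixed point can also be solved in closed form, `(2I − W)·z* = x_in`, which gives a quick check that the loop converged.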

---

## 1. FFN — standard EP (the clean part)

The FFN is rewritten as a **modern Hopfield memory energy**:
```
E_mem(z) = − Σ_{token, m} relu( z · W_m )²_m        # W_m ∈ R^{C×M}, M memories
```
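Its force `−∂E_mem/∂z = 2·relu(zW_m)·W_mᵀ` (the indicator `1_{zW_m>0}` is absorbed, since `relu(h)·1_{h>0} = relu(h)`) is exactly a
**tied-weight 2-layer MLP (W_in = W_m, W_out = W_mᵀ)**: the squared-ReLU sits in the energy, so the force carries a plain ReLU. This is the FFN.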

- **Conservative** (scalar energy, symmetric Jacobian) → **standard EP is exact in the β→0, converged limit, no correction needed**.
- Verified: FFN-param gradient cosine vs backprop = **1.000** (`lt_ep_ffn.py`).
- This is textbook EP / Hopfield — already demonstrated on memristor crossbars in the literature.
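
The two properties claimed for the FFN (the force is the negative energy gradient, and its Jacobian is symmetric) can be checked numerically. The sketch below handles a single token and uses finite differences; the helper names and sizes are illustrative assumptions, not taken from `lt_ep_ffn.py`.

```python
import numpy as np

def e_mem(z, W_m):
    """Energy E_mem(z) = -sum_m relu(z @ W_m)_m ** 2 for one token z of shape (C,)."""
    h = np.maximum(z @ W_m, 0.0)
    return -np.sum(h ** 2)

def ffn_force(z, W_m):
    """Analytic force -dE_mem/dz = 2 * relu(z @ W_m) @ W_m.T (a tied-weight MLP)."""
    h = np.maximum(z @ W_m, 0.0)
    return 2.0 * h @ W_m.T

def numerical_grad(f, z, h=1e-6):
    """Central-difference gradient of a scalar function f at z."""
    g = np.zeros_like(z)
    for i in range(z.size):
        dz = np.zeros_like(z)
        dz[i] = h
        g[i] = (f(z + dz) - f(z - dz)) / (2 * h)
    return g

def numerical_jacobian(F, z, h=1e-6):
    """Central-difference Jacobian of a vector function F at z (column k = dF/dz_k)."""
    cols = [(F(z + h * e) - F(z - h * e)) / (2 * h) for e in np.eye(z.size)]
    return np.stack(cols, axis=1)

rng = np.random.default_rng(0)
C, M = 6, 10                       # hypothetical channel and memory counts
W_m = rng.normal(size=(C, M))
z = rng.normal(size=C)

force_fd = -numerical_grad(lambda u: e_mem(u, W_m), z)
J = numerical_jacobian(lambda u: ffn_force(u, W_m), z)
```

`force_fd` should match `ffn_force(z, W_m)`, and `J` should equal its own transpose up to finite-difference error, which is the symmetric-Jacobian condition that lets standard EP apply without a correction.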

---

## 2. Attention — how it is "EP-ified" (the novel part)

**Step A — rewrite attention as a FORCE** (not a feedforward layer): tokens relax under it.
```
Attn(z) = [ softmax( Q(z)K(z)ᵀ/√d , causal ) · V(z) ] · W_O
          Q=zW_Q, K=zW_K, V=zW_V    (independent projections ⇒ NON-reciprocal: i→j ≠ j→i)
force term = s·( Attn(z) − c·z )      # −c·z damping ⇒ contraction ⇒ a stable fixed point exists
```
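
To make the non-reciprocity concrete, the force term above can be written out for one sequence and its Jacobian inspected. This is a single-head NumPy sketch with random weights; the sizes, the values of `s` and `c`, and the helper names are assumptions for illustration only.

```python
import numpy as np

def causal_attention(z, W_Q, W_K, W_V, W_O):
    """Single-head causal self-attention for one sequence z of shape (T, C)."""
    T, d = z.shape[0], W_Q.shape[1]
    Q, K, V = z @ W_Q, z @ W_K, z @ W_V
    scores = Q @ K.T / np.sqrt(d)
    causal = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(causal, scores, -np.inf)            # token t sees tokens <= t only
    scores = scores - scores.max(axis=1, keepdims=True)   # stabilise the softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ V @ W_O

def attention_force(z, params, s=0.5, c=1.0):
    """The force term s * (Attn(z) - c * z)."""
    return s * (causal_attention(z, *params) - c * z)

def flat_jacobian(f, z, h=1e-6):
    """Central-difference Jacobian of f with respect to the flattened state."""
    cols = []
    for k in range(z.size):
        dz = np.zeros(z.size)
        dz[k] = h
        dz = dz.reshape(z.shape)
        cols.append(((f(z + dz) - f(z - dz)) / (2 * h)).ravel())
    return np.stack(cols, axis=1)

rng = np.random.default_rng(0)
T, C = 4, 3
params = [rng.normal(size=(C, C)) for _ in range(4)]      # W_Q, W_K, W_V, W_O
z = rng.normal(size=(T, C))

J = flat_jacobian(lambda u: attention_force(u, params), z)
asymmetry = np.linalg.norm(J - J.T) / np.linalg.norm(J)
# Block (t, t') of J is d force_t / d z_t'; causality zeroes the blocks with t' > t.
future_block = J[:C, C:]
```

The causal mask alone already makes `J` block lower-triangular, so `J ≠ Jᵀ` even before the independent Q/K/V projections are taken into account.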

**Step B — fix the non-reciprocity bias (AEP correction).** Because Q≠K and V is independent,
the attention Jacobian `J_attn` is **asymmetric** — it is not the gradient of any scalar. Plain EP's
nudged phase relaxes under `J`, but the correct adjoint needs `Jᵀ`, so plain EP gives a **biased**
gradient. AEP adds, in the nudged phase:
```
v   = z − z*                                  # deviation from the free equilibrium
corr = s·( J_attn·v − J_attnᵀ·v )             # = 2·A_J·v ,  A_J = (J − Jᵀ)/2  (antisymmetric part)
       J_attn·v  = jvp(Attn, z*, v)           # forward-mode  (one forward probe)
       J_attnᵀ·v = vjp(Attn, z*, v)           # reverse-mode  (one backward probe)
f_nudged = F(z) − sign·β·∂C/∂z  −  clip(corr)
```
Effect: the attention part of the nudged linearization becomes `s·J·v − s·(J−Jᵀ)v = s·Jᵀ·v`
— i.e. **J is turned into Jᵀ**, giving the correct adjoint.
- The damping `−c·z` is **symmetric** (Jacobian −cI) ⇒ cancels in `J−Jᵀ` ⇒ the correction only
  sees attention's **non-reciprocal** part.
- Verified: attention-param gradient cosine vs backprop = **0.99–1.0** (plain EP: 0.25–0.78).
- Hardware note: `jvp/vjp` = two probe directions; **non-reciprocal coupling is exactly what real
  analog devices have** — AEP removes the usual "symmetric weights" requirement of EP hardware.
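
The effect of the correction can be reproduced on a small linear system, where the true adjoint is available in closed form. The sketch below is an assumption-laden toy, not the repository code: the force is `F(z) = A·z + b` with a non-symmetric `A`, the cost is `C(z) = ½‖z − y‖²`, the Jacobian products are written as explicit matrix products instead of `jvp`/`vjp`, and the `clip` on the correction is omitted.

```python
import numpy as np

def relax(force, z0, eps=0.5, steps=400):
    """Explicit Euler relaxation to a fixed point."""
    z = z0.copy()
    for _ in range(steps):
        z = z + eps * force(z)
    return z

rng = np.random.default_rng(0)
n = 6
R = rng.normal(size=(n, n))
A = -np.eye(n) + 0.5 * R / np.linalg.norm(R, 2)   # non-symmetric, stable Jacobian
b = rng.normal(size=n)
y = rng.normal(size=n)
beta = 1e-3

free_force = lambda z: A @ z + b
cost_grad = lambda z: z - y                       # dC/dz for C = 0.5 * ||z - y||^2

z_star = relax(free_force, np.zeros(n))           # free phase

def adjoint(use_correction):
    """Centered two-sided nudge; returns a = (z_minus - z_plus) / (2 * beta)."""
    def nudged(sign):
        def force(z):
            v = z - z_star                        # deviation from free equilibrium
            corr = (A @ v - A.T @ v) if use_correction else 0.0
            return free_force(z) - sign * beta * cost_grad(z) - corr
        return relax(force, z_star)
    z_plus, z_minus = nudged(+1.0), nudged(-1.0)
    return (z_minus - z_plus) / (2 * beta)

# Exact adjoint: dL/db for L = C(z*(b)) equals -A^{-T} dC/dz at z*.
true_adjoint = -np.linalg.solve(A.T, cost_grad(z_star))
err_aep = np.linalg.norm(adjoint(True) - true_adjoint)
err_plain = np.linalg.norm(adjoint(False) - true_adjoint)
```

With the correction the nudged dynamics linearize to `Aᵀ·v`, so `err_aep` is limited only by the O(β²) bias of the centered estimate; without it the estimate converges to `−A⁻¹·∂C/∂z` instead, which is the bias described above.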

---

## 3. End-to-end unified training

**One relaxation, one estimator, trains the whole block.** Key fact:

> The antisymmetric Jacobian of the **full** force, `A_J`, equals the antisymmetric part of
> (conservative + attention) = **attention's antisymmetric part alone** (the conservative FFN/clamp
> have symmetric Jacobians → contribute 0 to `A_J`).

So **the AEP correction needs to act on the attention term only**; the FFN/clamp ride along in the
conservative part and are trained correctly by standard EP — **one relaxation, one correction, trains
everything.**
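
The key fact above is a one-line matrix identity and can be checked directly. The matrices below are random stand-ins (a symmetric `H` for the Hessian of `E_mem`, a generic `J_attn`), introduced only for illustration.

```python
import numpy as np

def antisym(J):
    """Antisymmetric part (J - J^T) / 2."""
    return 0.5 * (J - J.T)

rng = np.random.default_rng(0)
n, s, c = 5, 0.7, 1.3
H = rng.normal(size=(n, n))
H = H + H.T                                      # symmetric, stands in for d2 E_mem / dz2
J_cons = -np.eye(n) - H                          # Jacobian of the clamp + FFN force
J_attn = rng.normal(size=(n, n))                 # generic non-symmetric attention Jacobian
J_full = J_cons + s * (J_attn - c * np.eye(n))   # Jacobian of the full force

A_full = antisym(J_full)
A_attn = s * antisym(J_attn)
```

`A_full` and `A_attn` coincide: the clamp, the FFN Hessian, and the damping `−c·I` are all symmetric and drop out of `J − Jᵀ`.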

**Training step (`ep_step`) = 3 relaxations (free, +β, −β) + adjoint read-out + 1 local update:**
```
① free phase:    z* = relax(F, x_in, T1)                              # to the fixed point
② nudged ±:      z₊ = relax( F − β·∂C/∂z − corr , from z*, T2 )       # +β
                 z₋ = relax( F + β·∂C/∂z − corr , from z*, T2 )       # −β  (centered EP)
③ adjoint:       a  = (z₋ − z₊) / (2β)                                # read from the two nudged equilibria
④ local update:
   • equilibrium params (W_Q,W_K,W_V,W_O, W_m, embeddings tok/pos):
        ∇_θ L = ⟨ a , ∂F/∂θ(z*) ⟩          # vector-field EP estimator — one formula for attn+FFN+embed
        (code:  autograd.grad( (a·F(z*,θ)).sum(), θ ) ,  θ live, z* fixed)
   • readout W_h (outside the equilibrium): direct local gradient  ∂C/∂W_h |_{z*}
```
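
The estimator `⟨a, ∂F/∂θ(z*)⟩` in the local-update step can be checked by hand on a linear toy force `F(z) = A·z + b`, where `∂F/∂A_ij = z_j·e_i` turns the inner product into an outer product. In this sketch the adjoint `a` is computed by a direct solve (the quantity that the adjoint read-out approximates), and all names and sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5
R = rng.normal(size=(n, n))
A = -np.eye(n) + 0.5 * R / np.linalg.norm(R, 2)   # non-symmetric, stable
b = rng.normal(size=n)
y = rng.normal(size=n)

def equilibrium(A, b):
    """Fixed point of F(z) = A z + b, i.e. the solution of A z* + b = 0."""
    return np.linalg.solve(A, -b)

def loss(A, b):
    z = equilibrium(A, b)
    return 0.5 * np.sum((z - y) ** 2)

z_star = equilibrium(A, b)
a = -np.linalg.solve(A.T, z_star - y)   # exact adjoint, a = -A^{-T} dC/dz

# <a, dF/dtheta> for each parameter: an outer product for A, a itself for b.
grad_A = np.outer(a, z_star)
grad_b = a

def finite_diff_A(i, j, h=1e-6):
    """Central-difference dL/dA_ij, re-solving the equilibrium each time."""
    dA = np.zeros_like(A)
    dA[i, j] = h
    return (loss(A + dA, b) - loss(A - dA, b)) / (2 * h)

fd_A = np.array([[finite_diff_A(i, j) for j in range(n)] for i in range(n)])
```

Each entry `grad_A[i, j] = a_i · z*_j` uses only the adjoint at the output side and the state at the input side of that weight, which is the locality that the update relies on.
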
- Why attention (non-conservative) and FFN (conservative) train under the *same* estimator:
  `⟨a, ∂F/∂θ⟩` is uniform over all equilibrium params; the AEP correction only modifies `a` (making it
  the correct adjoint for the non-reciprocal system); the FFN's `∂F/∂θ` is already correct.
- Embeddings `tok/pos` enter the force through the clamp `½‖z−x_in‖²` (`∂F/∂x_in = +I`), so the same
  `⟨a, ∂F/∂θ⟩` yields their gradient.
- **Stability (feedback regulation, from FRE-RNN 2508.11659):** at each training step, measure the free-phase
  residual `res = ‖relax(z*,1)−z*‖/‖z*‖`. If `res` is too large, **increase the damping `c`** (this lowers the
  spectral radius so the relaxation keeps converging); if `res` is very small, relax `c` again. This maintains the
  "free phase has converged" condition (Ernoult 2019: EP ≡ BPTT in the β→0, converged limit) throughout training.
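
The feedback rule in the last bullet can be sketched as a small controller around the free phase. Everything numeric here is a guess made for illustration: the thresholds `res_hi` and `res_lo`, the multiplicative gains, the toy coupling `W`, and the fixed `gain` that makes the system unstable at the initial damping.

```python
import numpy as np

def relax(force, z0, eps=0.1, steps=50):
    """Explicit Euler relaxation with a fixed step budget (plays the role of T1)."""
    z = z0.copy()
    for _ in range(steps):
        z = z + eps * force(z)
    return z

def residual(force, z_star, eps=0.1):
    """res = ||relax(z*, 1) - z*|| / ||z*||, the relative size of one more step."""
    return np.linalg.norm(eps * force(z_star)) / np.linalg.norm(z_star)

def regulate(c, res, res_hi=1e-3, res_lo=1e-6, up=1.5, down=0.95, c_min=0.1):
    """Hypothetical controller: raise damping when unconverged, ease it when over-damped."""
    if res > res_hi:
        return c * up
    if res < res_lo:
        return max(c_min, c * down)
    return c

rng = np.random.default_rng(0)
n = 8
M = rng.normal(size=(n, n))
W = M @ M.T
W = W / np.linalg.norm(W, 2)     # symmetric coupling with top eigenvalue 1
x_in = rng.normal(size=n)
gain, s = 3.0, 1.0               # coupling strong enough to be unstable at c = 1

c, history = 1.0, []
for step in range(20):
    force = lambda z, c=c: -(z - x_in) + s * (gain * (W @ z) - c * z)
    z_star = relax(force, np.zeros(n))   # free phase under the current damping
    res = residual(force, z_star)
    history.append((c, res))
    c = regulate(c, res)
```

`history` records the damping used and the residual seen at each step; in this toy the damping rises until the free phase converges within the step budget and then holds.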

**Measured (this implementation):** end-to-end EP trains a char-LM to **val CE 2.95** (random 4.17,
backprop on the same architecture 2.10), with **zero non-finite steps** under feedback regulation.

---

## One-line summary
> **One energy/force, one relaxation, one local estimator.** FFN = conservative Hopfield energy →
> *standard EP* (exact). Attention = a *non-reciprocal force*; AEP turns the nudged-phase `J` into `Jᵀ`
> via two probes (jvp/vjp) → the correct adjoint (measured gradient cosine vs backprop 0.99–1.0). Since the full force's antisymmetric part comes only from
> attention, **one AEP correction + standard EP train the whole block end-to-end**; the readout trains
> directly on the cost; damping + feedback-regulation keep the system convergent.

## Hardware-relevant primitives
- **local, no weight transport**: every weight updates from locally-available equilibrium states.
- **compute = relaxation to a fixed point**: maps to oscillators / memristor crossbars / optics / Ising.
- **two phases, same circuit**: free + nudged differ only by a small output nudge β.
- **non-reciprocal coupling OK (a feature)**: AEP handles asymmetric `J`; `jvp`/`vjp` = two probe directions.
- **dissipation `c` is a physical knob**: feedback-regulated to keep the system in the convergent regime.