# ept — onboarding for a new collaborator (algorithm / experiments side)

*Single entry point, kept current. Deeper docs linked at the bottom. — 2026-07-03*

## 1. What this project is (in three sentences)
We train a **transformer as a fixed-point (equilibrium) system with Equilibrium Propagation (EP) — no
backpropagation**. The forward pass is a damped relaxation `z ← z + ε·F(z)` that settles to a fixed point
`z*`; the weight update is **local**, computed from the contrast between a *free* settle and a slightly
*nudged* settle (no backward pass, no stored activations, no weight transport). This is exactly the
computation analog in-memory hardware does natively — so the north star is a **backprop-free, on-chip-trainable
path to language models**, with the GPU work here de-risking the algorithm before hardware.
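
The free/nudged contrast can be checked on a toy small enough to solve by hand. Below is a minimal pure-Python sketch. It assumes a 2-unit *conservative* system with symmetric coupling `w`, energy `E(z) = ½|z − x|² + ½c|z|² − w·z₁z₂` (so `F = −∂E/∂z`), and cost `½(z₁ − y)²`. None of these names come from `lt_ep_train.py`. Because the toy is conservative, it needs no AsymEP correction. The EP estimate should match a finite-difference gradient through the fixed point up to an O(β) bias:

```python
# Toy EP check on a 2-unit conservative system (illustration only, not ept's operator).
# Assumed energy: E(z) = 0.5*|z - x|^2 + 0.5*c*|z|^2 - w*z1*z2, so F(z) = -dE/dz.
# Assumed cost:   C(z) = 0.5*(z1 - y)^2 on unit 1 only.

def relax(x, w, c=1.0, beta=0.0, y=0.0, eps=0.2, steps=500):
    """Damped relaxation z <- z + eps*F(z); beta > 0 adds the nudge -beta*dC/dz."""
    z = [0.0, 0.0]
    for _ in range(steps):
        f1 = -(z[0] - x[0]) - c * z[0] + w * z[1] - beta * (z[0] - y)
        f2 = -(z[1] - x[1]) - c * z[1] + w * z[0]
        z = [z[0] + eps * f1, z[1] + eps * f2]
    return z

def loss(x, y, w):
    z = relax(x, w)
    return 0.5 * (z[0] - y) ** 2

def ep_grad(x, y, w, beta=1e-4):
    """EP estimate of dL/dw from a free and a nudged settle (no backward pass).
    dE/dw = -z1*z2, so the contrast is -(z1b*z2b - z1f*z2f) / beta."""
    zf = relax(x, w)                    # free phase
    zb = relax(x, w, beta=beta, y=y)    # nudged phase
    return -(zb[0] * zb[1] - zf[0] * zf[1]) / beta

def fd_grad(x, y, w, h=1e-5):
    """Reference gradient by central finite differences through the fixed point."""
    return (loss(x, y, w + h) - loss(x, y, w - h)) / (2 * h)

x, y, w = [1.0, 0.5], 0.2, 0.5
print(relax(x, w))                          # free fixed point, approx [0.6, 0.4]
print(ep_grad(x, y, w), fd_grad(x, y, w))   # both approx 0.1173
```

Both numbers agree with the implicit-differentiation value `(z₁* − y)·∂z₁*/∂w = 0.4 × 0.55/1.875`. EP gets there from two forward settles, with no stored trajectory.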

## 2. Architecture (current — the `thick` block in `ep_run/lt_ep_train.py`)
One block = one dynamical system on the token state `z ∈ R^{B,T,C}`:
```
F(z) =   −(z − x_in)      ← input clamp
       + Attn(LN(z))      ← causal softmax self-attn
       + FFN(LN(z))       ← untied 4×GELU FFN
       − c·z              ← damping (contraction)
```
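
The operator can be made concrete with a minimal pure-Python sketch of one evaluation of `F` and of the damped relaxation. It assumes a single attention head, LayerNorm without learned affine, no `--qknorm`, and no batching; the real block uses H=16 heads at C=512. All function names here are hypothetical, not the ones in `lt_ep_train.py`, and starting the settle from `x_in` is also an assumption:

```python
import math
import random

def layer_norm(t, eps=1e-5):
    """Per-token LayerNorm (assumption: no learned scale/shift)."""
    mu = sum(t) / len(t)
    var = sum((a - mu) ** 2 for a in t) / len(t)
    return [(a - mu) / math.sqrt(var + eps) for a in t]

def matvec(W, v):
    return [sum(w * a for w, a in zip(row, v)) for row in W]

def gelu(a):
    return 0.5 * a * (1.0 + math.erf(a / math.sqrt(2.0)))

def causal_attn(h, P):
    """Single-head causal softmax self-attention with untied Q/K/V/O projections."""
    q = [matvec(P["Wq"], t) for t in h]
    k = [matvec(P["Wk"], t) for t in h]
    v = [matvec(P["Wv"], t) for t in h]
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i in range(len(h)):
        s = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in range(i + 1)]  # keys 0..i only
        m = max(s)
        e = [math.exp(x - m) for x in s]  # numerically stable softmax
        tot = sum(e)
        mix = [sum(e[j] / tot * v[j][d] for j in range(i + 1)) for d in range(len(v[0]))]
        out.append(matvec(P["Wo"], mix))
    return out

def ffn(t, P):
    """Untied FFN with a 4*C GELU hidden layer."""
    return matvec(P["W2"], [gelu(a) for a in matvec(P["W1"], t)])

def F(z, x_in, P, c=1.0):
    """F(z) = -(z - x_in) + Attn(LN(z)) + FFN(LN(z)) - c*z, evaluated token by token."""
    h = [layer_norm(t) for t in z]
    a = causal_attn(h, P)
    out = []
    for i, t in enumerate(z):
        f = ffn(h[i], P)
        out.append([-(t[d] - x_in[i][d]) + a[i][d] + f[d] - c * t[d] for d in range(len(t))])
    return out

def settle(x_in, P, c=1.0, eps=0.1, steps=200):
    """Damped relaxation z <- z + eps*F(z), started from the clamped input (assumption)."""
    z = [row[:] for row in x_in]
    for _ in range(steps):
        f = F(z, x_in, P, c)
        z = [[zd + eps * fd for zd, fd in zip(zr, fr)] for zr, fr in zip(z, f)]
    return z

def random_params(C, scale, seed=0):
    """Hypothetical Gaussian init for the sketch (not the project's init)."""
    rng = random.Random(seed)
    def mat(rows, cols):
        return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]
    return {"Wq": mat(C, C), "Wk": mat(C, C), "Wv": mat(C, C), "Wo": mat(C, C),
            "W1": mat(4 * C, C), "W2": mat(C, 4 * C)}
```

Two sanity properties follow directly from the formula. With all weights zero, `F(z) = −(z − x_in) − c·z`, so the settle lands on `x_in/(1 + c)`. And because attention is causal, token 0's update never depends on later tokens.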
- `x_in = tok[idx] + pos`, clamped as a boundary condition. Forward = relax T1 steps → `z*`; readout `logits = z*·W_h`.
- **This is a standard pre-LN transformer block run as a Deep-Equilibrium fixed point**, with a `−c·z` leak for contraction.
- It is **non-conservative**: the attention (untied Q/K/V/O) is not the gradient of any energy, so the Jacobian
  `J = ∂F/∂z` is not symmetric, and in general non-normal (we measure the asymmetry `|Jv−Jᵀv|/|Jv| ≈ 1.4`). *That
  non-conservativity is the source of both the expressivity and the central difficulty (§4).* (An older
  energy-formulation variant is in `docs/method/ARCHITECTURE.md`, now superseded.)
- **EP training loop** (`ep_step`): free settle → *nudged* settle (output pulled toward the target by `β`) with an
  **AsymEP correction** (an antisymmetric-Jacobian term that makes the estimate match the true gradient of the
  non-conservative operator in the limit of small `β` and fully converged settles); the update is the state difference
  contracted with `∂F/∂θ`. The nudged phase uses a fast adaptive-`T2` "holo a-select" estimator (`holo_ep.py`).
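
The asymmetry figure quoted above needs only a Jacobian-vector and a vector-Jacobian product, never the full Jacobian. Here is a minimal pure-Python sketch on explicit small matrices. In the real model both products would come from autodiff, and `asymmetry_ratio` / `mean_asymmetry` are hypothetical helpers, not functions in the repo:

```python
import math
import random

def jvp(J, v):
    """J v."""
    return [sum(a * b for a, b in zip(row, v)) for row in J]

def vjp(J, v):
    """J^T v, computed without forming J^T."""
    return [sum(J[i][j] * v[i] for i in range(len(J))) for j in range(len(J[0]))]

def asymmetry_ratio(J, v):
    """|Jv - J^T v| / |Jv|; exactly 0 when J is symmetric (a conservative field)."""
    jv, jtv = jvp(J, v), vjp(J, v)
    num = math.sqrt(sum((a - b) ** 2 for a, b in zip(jv, jtv)))
    return num / math.sqrt(sum(a * a for a in jv))

def mean_asymmetry(J, probes=64, seed=0):
    """Average the ratio over random Gaussian probe vectors."""
    rng = random.Random(seed)
    n = len(J)
    return sum(asymmetry_ratio(J, [rng.gauss(0.0, 1.0) for _ in range(n)])
               for _ in range(probes)) / probes

sym = [[2.0, 1.0], [1.0, 3.0]]    # Jacobian of an energy gradient: symmetric
rot = [[0.0, 1.0], [-1.0, 0.0]]   # antisymmetric: ratio is exactly 2 for every v
print(mean_asymmetry(sym), mean_asymmetry(rot))   # 0.0 and ~2.0
```

Note that adding the damping term `−c·I` to `J` leaves the numerator `Jv − Jᵀv` unchanged. Damping moves the spectrum but cannot remove the asymmetry.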

## 3. Where we are (results, C512, TinyStories-BPE)
- **EP gradient ≈ exact BPTT gradient** (cosine ≈ 0.92–0.99 per component when the free phase is converged): the
  learning rule is validated as a faithful estimate of the true gradient, not a heuristic surrogate.
- **Best val cross-entropy 1.9313** (vs ~1.79 for a same-parameter BP transformer); the model generates coherent
  children's stories.
- The recipe **trains stably and approaches BP** at this scale (a val-CE gap of ≈0.14). Model = C512 / H16 / T256,
  damped DEQ block.
- ⚠️ **The 1.93 number is warm-started** from a stable early checkpoint (`s2000`); a single from-scratch run currently
  plateaus at **~2.10** (see §5, the crux).
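
The "per component" cosine above can be read as one cosine per parameter tensor. That grouping is an assumption about how `diag_cos.py` reports it. Here is a minimal sketch with made-up gradient values and hypothetical tensor names, with each tensor flattened to a list:

```python
import math

def cosine(a, b):
    """Cosine similarity of two flattened gradient vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def per_component_cos(g_ep, g_bptt):
    """One cosine per named parameter tensor (names here are hypothetical)."""
    return {name: cosine(g_ep[name], g_bptt[name]) for name in g_ep}

g_bptt = {"attn.Wq": [0.2, -0.1, 0.4], "ffn.W1": [1.0, 0.0, -2.0]}
g_ep = {"attn.Wq": [0.21, -0.08, 0.37], "ffn.W1": [0.9, 0.1, -2.1]}
print(per_component_cos(g_ep, g_bptt))
```

Reporting per tensor rather than one global cosine keeps a badly aligned small tensor from hiding behind a well-aligned large one.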

## 4. The one hard problem (and the paper it spun off)
The binding constraint is **NOT the gradient**; it is **forward fixed-point STABILITY during training**. As
optimization makes attention more expressive and more non-conservative, the operator loses contraction: a
complex-conjugate eigenvalue pair of its Jacobian crosses the imaginary axis (**a supercritical Hopf bifurcation**),
the relaxation stops converging (the T1 residual climbs to 0.1 and above), and training breaks. Three controls hold
it: **`resreg`** (penalize the T1 residual), **`jacreg`** (penalize the Jacobian norm), and the new **`eigreg`**
(leading-abscissa / log-norm control, §5).
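
The failure mode can be reproduced with a 2-D linear toy. This is an illustration, not the ept operator. Give `F(z) = J z` the Jacobian `J = [[a, −b], [b, a]]`, whose eigenvalues `a ± ib` form a complex pair, and run the same damped relaxation for 150 steps. The discrete iteration contracts only while `|1 + ε·λ| < 1`, so with an explicit step `ε` it already degrades slightly before the continuous-time crossing at `a = 0`:

```python
import math

def step_gain(a, b, eps=0.1):
    """Per-step factor |1 + eps*lambda| of z <- z + eps*J z, for lambda = a + i*b."""
    return abs(complex(1.0 + eps * a, eps * b))

def settle_residual(a, b, eps=0.1, steps=150):
    """Relax z <- z + eps*F(z) with F(z) = J z, J = [[a, -b], [b, a]] (fixed point z* = 0).
    Returns the final residual |F(z_T)|."""
    def F(z):
        return [a * z[0] - b * z[1], b * z[0] + a * z[1]]
    z = [1.0, 0.0]
    for _ in range(steps):
        f = F(z)
        z = [z[0] + eps * f[0], z[1] + eps * f[1]]
    return math.hypot(*F(z))

for a in (-1.0, -0.2, 0.2):   # deeply stable, near the boundary, past the Hopf crossing
    print(f"a={a:+.1f}  gain={step_gain(a, 1.0):.4f}  residual={settle_residual(a, 1.0):.2e}")
```

Deep in the stable region (`a = −1`) the residual falls below 1e−5. Near the boundary (`a = −0.2`) it is still around 0.1 after 150 steps. Past the crossing it grows without bound. That is the same progression the training runs show as the operator drifts toward the Hopf point.
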
> This stability question generalized into a **standalone paper**, *"Dynamics and Convergence of Equilibrium
> Learning"* (the report we shared with Ben Scellier is that spin-off, in `/home/yurenh2/aep-dynamics/`): the Hopf
> diagnosis plus a cure driven by the leading spectral signal, shown across MLP/CNN/RNN and across learning rules
> (EP and DEQ/RBP). ept is the language-model-scale instance of the same phenomenon.

## 5. Open problems — where you can plug in (ranked)
1. **★ Crack from-scratch below 2.0 (the crux).** We *ultimately need* a from-scratch run (no magic warm checkpoint)
   for a real / hardware result. Diagnosis (via the new `--fingerprint`): the warm source `s2000` is a **deeply
   contractive** operator (numerical abscissa −10) with a well-aligned EP gradient, while a from-scratch plateau
   operator sits **near the Hopf boundary** (abscissa +1.11) with a modestly worse gradient, and *training drifts the
   operator toward the boundary as it learns* (val 3.16→2.24 tracks abscissa −10→+1.11). **Hypothesis to test:** hold
   the operator deeply contractive from scratch with `--eigreg` (leading-abscissa control) and crack the plateau
   without a warm start. Tools are built and off by default: `diag_cos.py` (`--diag_cos N`, `--fingerprint`),
   `eig_control.py` (`--eigreg`).
2. **Scaling** to hundreds of millions of parameters / small-LLM scale (gated on cloud compute; a Scellier/AWS path
   is in progress).
3. **Speed** (`ep_run/profile_ep.py`, `cos_sweep.py`): the holo a-select is ~56% of the step; `t2sel` is a
   cosine-preserving speed lever (160→80 ≈ 1.8× free); multi-GPU data-parallel EP is untried.
4. **Analog realism**: device noise / low-bit quantization / asymmetric update in the simulator (not yet added; the
   Yu-Neng Wang hardware conversation is about exactly this device model).
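
We take the numerical abscissa used in the fingerprint to be the largest eigenvalue of the symmetric part `(J + Jᵀ)/2`, i.e. the 2-norm log-norm of `J`; that identification is our assumption. It bounds the instantaneous growth rate of `|z − z*|` under the linearized flow, so it is a stricter contraction test than the eigenvalues alone. Below is a matrix-free power-iteration sketch in the spirit of `eig_control.py`, which this document describes as power-iteration based. The shift, the iteration count, and the hinge `eig_penalty` (one plausible reading of `--eigreg` / `--eig_margin`) are hypothetical, not the repo's code:

```python
import math

def numerical_abscissa(jvp, vjp, n, shift=10.0, iters=300):
    """Leading eigenvalue of S = (J + J^T)/2 by power iteration on S + shift*I,
    using only J v and J^T v products (matrix-free). `shift` must exceed the
    magnitude of S's most negative eigenvalue so the top eigenvalue dominates."""
    v = [1.0 / math.sqrt(n)] * n
    for _ in range(iters):
        sv = [(a + b) / 2 for a, b in zip(jvp(v), vjp(v))]
        w = [s + shift * x for s, x in zip(sv, v)]       # (S + shift*I) v
        norm = math.sqrt(sum(x * x for x in w))
        v = [x / norm for x in w]
    sv = [(a + b) / 2 for a, b in zip(jvp(v), vjp(v))]
    return sum(a * b for a, b in zip(v, sv))            # Rayleigh quotient v^T S v

def eig_penalty(mu, margin=1.0):
    """Hypothetical hinge: penalize only when mu > -margin (not contractive enough)."""
    return max(0.0, mu + margin) ** 2

# J's eigenvalues are -3 and -1, but its numerical abscissa is -2 + sqrt(2) (about -0.586):
# non-normality makes worst-case instantaneous growth faster than the eigenvalues suggest.
J = [[-3.0, 2.0], [0.0, -1.0]]
jvp = lambda v: [sum(a * b for a, b in zip(row, v)) for row in J]
vjp = lambda v: [sum(J[i][j] * v[i] for i in range(2)) for j in range(2)]
mu = numerical_abscissa(jvp, vjp, 2)
print(mu, eig_penalty(mu, margin=1.0))
```

Only `J v` and `Jᵀ v` are ever needed, which is what makes this kind of control scale to the full block.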

## 6. Codebase map (`ep_run/`)
- **`lt_ep_train.py`** — everything: the block, `ep_step` (EP training), `bptt_step` (exact-gradient control),
  `relax`, `evaluate`, the residual/jacreg controllers, the training loop. The one file to read first.
- **`holo_ep.py`** — the adaptive-T2 nudged-phase estimator (`holo_a_select`, `holo_a_track`).
- **`diag_cos.py`** (new) — `cos(EP, BPTT)` trajectory + operator `fingerprint` (res / cos / numerical-abscissa / val).
- **`eig_control.py`** (new) — the `--eigreg` leading-abscissa control (power-iteration, scalable, analog-compatible).
- `eig_probe.py`, `cos_sweep.py`, `profile_ep.py`, `bp_transformer.py` (BP baseline) — probes / baselines.
- `data/` (TinyStories-BPE, ~712M) and `runs/` (~8G checkpoints) — **git-ignored; get these separately.**

## 7. How to run
Canonical C512 recipe (one block, EP, holo fast path; run from `ep_run/`):
```
python3 lt_ep_train.py --mode ep --attn_mode thick --B 24 --C 512 --H 16 --T 256 --c 1.0 \
  --jacreg 0.1 --jr_floor 0.1 --jr_max 16 --holo 2 --hr 0.02 --t2sel 80 --track --pema 0.999 \
  --t1max 150 --res_est 1e-4 --resreg 0.2 --qknorm --T1 150 --T2 20 --lr 6e-4 --wsd 0.25 \
  --steps 32000 --data data/tinystories_bpe --ckpt runs/myrun.pt --state runs/myrun.state
```
Diagnostics: add `--diag_cos 500` (log cos-to-BPTT over training) · `--init_ckpt <ckpt> --fingerprint` (print an
operator's 4-D fingerprint) · `--eigreg 0.1 --eig_margin 1.0` (leading-abscissa control, an alternative to `--jacreg`).
BP baseline (fair control): `--mode bptt`. **All experiment processes must run under `nohup`** (prefix the command
with `nohup` and background it with `&`).

## 8. Deeper docs (organized under `docs/`)
- **`docs/method/`** — `METHODS.md`, `EP_DERIVATION.md` (the EP/AsymEP gradient derivation), `ARCHITECTURE.md`
  (implementation detail; older energy-formulation, partly superseded by §2 above), `READING.md`.
- **`docs/campaign/`** — `FINDINGS.md` (running log of what worked / didn't) + the full plateau history:
  `C512_PLATEAU_CAMPAIGN.md`, `C512_ROUND2_ABCD.md`, `EP_BELOW210_DIAGNOSIS_FIX.md`,
  `SESSION_2026-06-24_HOPF_DIAGNOSIS_RESREG_FIX.md` (the Hopf diagnosis + resreg fix).
- **`docs/hardware/`** — `SCALING_AND_HARDWARE_PLAN.md` (scaling + analog end goal), `COLLABORATOR_BRIEF.md`
  (hardware-collaborator one-pager), `HW_RESEARCH_FINDINGS.md`, the physics Q&A docs.
- **`docs/outreach/`** — `OUTREACH_TARGETS.md`, `SCELLIER_OUTREACH.md`, `EMAIL_DRAFT_BEN.md`.
- **`docs/paper/PAPER_A_OUTLINE.md`** — the ept paper outline. The dynamics spin-off lives in `../aep-dynamics/`.

*Repo layout:* `ep_run/` = code (start at `lt_ep_train.py`) · `docs/` = the above · `scripts/` = standalone
analysis/probe scripts · `assets/` = PDFs/figures · `refs/` = external paper texts · `archive/` = stale snapshots.