summaryrefslogtreecommitdiff
path: root/ONBOARDING.md
diff options
context:
space:
mode:
Diffstat (limited to 'ONBOARDING.md')
-rw-r--r--ONBOARDING.md95
1 files changed, 95 insertions, 0 deletions
diff --git a/ONBOARDING.md b/ONBOARDING.md
new file mode 100644
index 0000000..dfa758b
--- /dev/null
+++ b/ONBOARDING.md
@@ -0,0 +1,95 @@
+# ept — onboarding for a new collaborator (algorithm / experiments side)
+
+*Single entry point, kept current. Deeper docs linked at the bottom. — 2026-07-03*
+
+## 1. What this project is (in three sentences)
+We train a **transformer as a fixed-point (equilibrium) system with Equilibrium Propagation (EP) — no
+backpropagation**. The forward pass is a damped relaxation `z ← z + ε·F(z)` that settles to a fixed point
+`z*`; the weight update is **local**, computed from the contrast between a *free* settle and a slightly
+*nudged* settle (no backward pass, no stored activations, no weight transport). This is exactly the
+computation analog in-memory hardware does natively — so the north star is a **backprop-free, on-chip-trainable
+path to language models**, with the GPU work here de-risking the algorithm before hardware.
+
+## 2. Architecture (current — the `thick` block in `ep_run/lt_ep_train.py`)
+One block = one dynamical system on the token state `z ∈ R^{B,T,C}`:
+```
+F(z) = −(z − x_in) + Attn(LN(z)) + FFN(LN(z)) − c·z
+ └ input clamp ┘ └ causal softmax self-attn ┘ └ untied 4×GELU FFN ┘ └ damping (contraction) ┘
+```
+- `x_in = tok[idx] + pos`, clamped as a boundary condition. Forward = relax T1 steps → `z*`; readout `logits = z*·W_h`.
+- **This is a standard pre-LN transformer block run as a Deep-Equilibrium fixed point**, with a `−c·z` leak for contraction.
+- It is **non-conservative**: the attention (untied Q/K/V/O) is not the gradient of any energy → the Jacobian is
+ non-normal (we measure `|Jv−Jᵀv|/|Jv| ≈ 1.4`). *That non-conservativity is the source of both the expressivity and
+ the central difficulty (below).* (An older energy-formulation variant is in `docs/method/ARCHITECTURE.md`, now superseded.)
+- **EP training loop** (`ep_step`): free settle → *nudged* settle (output pulled toward the target by `β`) with an
+ **AsymEP correction** (an antisymmetric-Jacobian term that makes the gradient *exact* for the non-conservative
+ operator); the update is the state-difference contracted with `∂F/∂θ`. The nudged phase uses a fast
+ adaptive-`T2` "holo a-select" estimator (`holo_ep.py`).
+
+## 3. Where we are (results, C512, TinyStories-BPE)
+- **EP gradient ≈ exact BPTT gradient** (cosine ≈ 0.92–0.99 per component when the free phase is converged) — the
+ learning rule is validated, not approximate.
+- **Best val cross-entropy 1.9313** (vs a same-parameter BP transformer ~1.79); generates coherent children's stories.
+- The recipe **trains stably and matches/approaches BP** at this scale. Model = C512 / H16 / T256, damped DEQ block.
+- ⚠️ **The 1.93 number is warm-started** from a stable early checkpoint (`s2000`); a single from-scratch run currently
+ plateaus at **~2.10** (see §5, the crux).
+
+## 4. The one hard problem (and the paper it spun off)
+The binding constraint is **NOT the gradient** — it's **forward fixed-point STABILITY during training**. As
+optimization makes attention more expressive/non-conservative, the operator loses contraction, a complex-eigenvalue
+pair of its Jacobian crosses the imaginary axis (**a supercritical Hopf bifurcation**), the relaxation stops
+converging (residual → 0.1+), and training breaks. Controls that hold it: **`resreg`** (penalize the T1 residual),
+**`jacreg`** (penalize the Jacobian norm), and the new **`eigreg`** (leading-abscissa / log-norm control, §5).
+> This stability question generalized into a **standalone paper** — *"Dynamics and Convergence of Equilibrium
+> Learning"* (the report we shared with Ben Scellier is that spin-off, in `/home/yurenh2/aep-dynamics/`): the Hopf +
+> a leading-spectral-signal cure, shown across MLP/CNN/RNN and across learning rules (EP and DEQ/RBP). ept is the
+> language-model-scale instance of the same phenomenon.
+
+## 5. Open problems — where you can plug in (ranked)
+1. **★ Crack from-scratch below 2.0 (the crux).** We *ultimately need* from-scratch (no magic warm checkpoint) for a
+ real / hardware result. Diagnosis (via the new `--fingerprint`): the warm source `s2000` is a **deeply contractive**
+ operator (numerical abscissa −10) with a well-aligned EP gradient; a from-scratch plateau operator sits **near the
+ Hopf boundary** (abscissa +1.11) with a modestly worse gradient — and *training drifts the operator toward the
+ boundary as it learns* (val 3.16→2.24 tracks abscissa −10→+1.11). **Hypothesis to test:** hold the operator
+ deeply-contractive from scratch with `--eigreg` (leading-abscissa control) → crack the plateau without a warm start.
+ Tools are built and default-off: `diag_cos.py` (`--diag_cos N`, `--fingerprint`), `eig_control.py` (`--eigreg`).
+2. **Scaling** to hundreds-of-M / small-LLM (gated on cloud compute — a Scellier/AWS path is in progress).
+3. **Speed** (`ep_run/profile_ep.py`, `cos_sweep.py`): the holo a-select is ~56% of the step; `t2sel` is a
+ cosine-preserving speed lever (160→80 ≈ 1.8× free); multi-GPU data-parallel EP is untried.
+4. **Analog realism**: device noise / low-bit quantization / asymmetric update in the simulator (not yet added; the
+ Yu-Neng Wang hardware conversation is about exactly this device model).
+
+## 6. Codebase map (`ep_run/`)
+- **`lt_ep_train.py`** — everything: the block, `ep_step` (EP training), `bptt_step` (exact-gradient control),
+ `relax`, `evaluate`, the residual/jacreg controllers, the training loop. The one file to read first.
+- **`holo_ep.py`** — the adaptive-T2 nudged-phase estimator (`holo_a_select`, `holo_a_track`).
+- **`diag_cos.py`** (new) — `cos(EP, BPTT)` trajectory + operator `fingerprint` (res / cos / numerical-abscissa / val).
+- **`eig_control.py`** (new) — the `--eigreg` leading-abscissa control (power-iteration, scalable, analog-compatible).
+- `eig_probe.py`, `cos_sweep.py`, `profile_ep.py`, `bp_transformer.py` (BP baseline) — probes / baselines.
+- `data/` (TinyStories-BPE, ~712M) and `runs/` (~8G checkpoints) — **git-ignored; get these separately.**
+
+## 7. How to run
+Canonical C512 recipe (one block, EP, ~holo fast path):
+```
+python3 lt_ep_train.py --mode ep --attn_mode thick --B 24 --C 512 --H 16 --T 256 --c 1.0 \
+ --jacreg 0.1 --jr_floor 0.1 --jr_max 16 --holo 2 --hr 0.02 --t2sel 80 --track --pema 0.999 \
+ --t1max 150 --res_est 1e-4 --resreg 0.2 --qknorm --T1 150 --T2 20 --lr 6e-4 --wsd 0.25 \
+ --steps 32000 --data data/tinystories_bpe --ckpt runs/myrun.pt --state runs/myrun.state
+```
+Diagnostics: add `--diag_cos 500` (log cos-to-BPTT over training) · `--init_ckpt <ckpt> --fingerprint` (print an
+operator's 4-D fingerprint) · `--eigreg 0.1 --eig_margin 1.0` (leading-abscissa control, alt to `--jacreg`).
+BP baseline (fair control): `--mode bptt`. **All experiment processes must use `nohup`.**
+
+## 8. Deeper docs (organized under `docs/`)
+- **`docs/method/`** — `METHODS.md`, `EP_DERIVATION.md` (the EP/AsymEP gradient derivation), `ARCHITECTURE.md`
+ (implementation detail; older energy-formulation, partly superseded by §2 above), `READING.md`.
+- **`docs/campaign/`** — `FINDINGS.md` (running log of what worked / didn't) + the full plateau history:
+ `C512_PLATEAU_CAMPAIGN.md`, `C512_ROUND2_ABCD.md`, `EP_BELOW210_DIAGNOSIS_FIX.md`,
+ `SESSION_2026-06-24_HOPF_DIAGNOSIS_RESREG_FIX.md` (the Hopf diagnosis + resreg fix).
+- **`docs/hardware/`** — `SCALING_AND_HARDWARE_PLAN.md` (scaling + analog end goal), `COLLABORATOR_BRIEF.md`
+ (hardware-collaborator one-pager), `HW_RESEARCH_FINDINGS.md`, the physics Q&A docs.
+- **`docs/outreach/`** — `OUTREACH_TARGETS.md`, `SCELLIER_OUTREACH.md`, `EMAIL_DRAFT_BEN.md`.
+- **`docs/paper/PAPER_A_OUTLINE.md`** — the ept paper outline. The dynamics spin-off lives in `../aep-dynamics/`.
+
+*Repo layout:* `ep_run/` = code (start at `lt_ep_train.py`) · `docs/` = the above · `scripts/` = standalone
+analysis/probe scripts · `assets/` = PDFs/figures · `refs/` = external paper texts · `archive/` = stale snapshots.