diff options
Diffstat (limited to 'docs/method/EP_DERIVATION.md')
| -rw-r--r-- | docs/method/EP_DERIVATION.md | 215 |
1 files changed, 215 insertions, 0 deletions
diff --git a/docs/method/EP_DERIVATION.md b/docs/method/EP_DERIVATION.md new file mode 100644 index 0000000..9adcc38 --- /dev/null +++ b/docs/method/EP_DERIVATION.md @@ -0,0 +1,215 @@ +# 从标准 EP 到当前版本 —— 逐层修改的因果路径 + +伴随文档:`METHODS.md`(按主题组织的完整方法) / `FINDINGS.md`(项目时间线)。 +本文档是**差分视图**:从教科书 EP 出发,每一层修改对应标准 EP 的**一条隐含前提被 +transformer 破坏**,记录"标准怎么做 → 为什么失效 → 我们怎么改 → 代码在哪"。 + +--- + +## 0. 标准 EP(Scellier–Bengio 2017)与它的四条隐含前提 + +标准 EP 在一个**能量函数** `E(z, θ)` 上运行: + +- **自由相**:动力学 `ż = −∇_z E`,relax 到能量极小 `z*`(自由不动点)。 +- **钳制相**:把损失以强度 β 加进能量,`E_β = E + β·ℓ(z)`,relax 到 `z_β`。 +- **梯度**:`∂L/∂θ ≈ (1/β)[ ∂E/∂θ(z_β) − ∂E/∂θ(z*) ]`(单边,bias O(β))。 + +它能成立,**默认了四条前提**: + +| | 前提 | 含义 | +|---|---|---| +| **A** | **保守 / 对称** | 存在标量能量 E,所以 Jacobian `J = ∂F/∂z` 对称(`J = Jᵀ`)。 | +| **B** | **自由相已收敛** | 读出在真不动点 `z*` 上;残差 ≈ 0。 | +| **C** | **小 β 线性响应 + 干净 nudge** | β→0,钳制只是微扰;为防发散加的 clamp 不影响估计。 | +| **D** | **训练中不动点始终稳定存在** | 每一步权重更新后,自由相仍收敛到一个稳定不动点。 | + +**Transformer 把这四条全破了。** 当前版本的每一处修改,都是在修复其中一条。下面分四块。 + +--- + +## 破缺 A:保守性 —— 力形式 EP + AEP 修正 + +**标准怎么做**:从能量 E 出发,`F = −∇E`,梯度用 `∂E/∂θ`。 + +**为什么失效**:transformer block 没有能量。注意力 `Q≠K`(非互易耦合)、untied FFN +(`W1≠W2ᵀ`)使 `J ≠ Jᵀ`,**写不出 E**。强行用能量路线(`energy` 模式:tied-value +LSE 能量)可以保守化,但代价是表达力(实测 thick 1.95 vs energy/mono 2.11,差 0.15–0.2 CE)。 + +**改法(两步)**: + +1. **力形式 EP(VF,vector-field readout)**——丢掉能量,直接把动力学写成力 `F(z)`,relax + `ż=F(z)` 到不动点。梯度不再用 `∂E/∂θ`,改用**向量场读出**: + ``` + ∂L/∂θ ≈ ∂/∂θ ⟨a, F(z*; θ)⟩ a = 对比态(见破缺 C) + ``` + 这一步在不动点处只调一次 autograd(逐项局部记账,不是 BPTT)。注意力、FFN、LN、 + embedding 全是同一个 F 的项,**联合训练,无分模块调度**。 + ⚠️ **更正(非我们发明)**:力形式 VF 是**已有方法**,正是 AsymEP(arXiv:2602.03670)论文里的 + baseline。而且 **VF 单独用在非保守系统上会崩**(他们 CIFAR-10 上 VF=10% 随机、MNIST 64% vs + AsymEP 92.7%)——这恰好对应我们实测的"无修正注意力 cos≈0.25"。所以这一步**不是我们的贡献**, + 它是"会崩的起点";真正修好它的是下面第 2 步的反对称修正(也是 AsymEP 的)。 + +2. **AEP 修正(非保守修复)**——力形式下,naive nudge 围绕 z* 线性化用的是 J,但正确的 + 伴随(adjoint)需要 **Jᵀ**。不修正 → 注意力梯度 cos ≈ 0.25(基本是错的)。 + 修正:给 nudge 力加上 `−(Jv − Jᵀv)`,`v = z − z*`,无矩阵实现 = 一次 jvp + 一次 vjp。 + 作用:把 nudge 相 Jacobian `J → Jᵀ` ⇒ a 逼近真正的伴随响应 ⇒ **`Q≠K` 注意力梯度 + cos 恢复到 0.99–1.0**(实测:attn 0.99 / ffn 1.00 / 整块 0.99)。 + - 来源:AEP "EP for non-conservative systems"(arXiv:2602.03670)。 + - 关键性质:修正项在 (z−z*) 上**线性、实系数** → 不破坏后面全纯估计的解析性。 + - 注意:修正在 **z* 处线性化**,所以 nudge 轨迹必须留在线性响应窗内(T2≈20 在窗内; + T2≳150–300 会逸出,见硬件孪生那段的 horizon 限制)。 + +**代码**:`lt_ep_train.py::force`(thick 力)、`ep_step` 的 `nudge()` 内 AEP 块、 +`nc_force`(非保守部分,供 AEP/jacreg 用)。`--attn_mode thick`。 + +--- + +## 破缺 B:收敛性 —— 残差是健康信号 + 自适应 T1 + +**标准怎么做**:固定 T1,假定已收敛,读出直接在"z*"上。 + +**为什么失效**:EP 估计器有一个**有效性阈值**(实测,非假设):梯度 cos vs 精确参考随自由相 +相对残差 `res = ‖z⁺−z*‖/‖z*‖` 急剧退化: +``` +res ≈ 5e-5 → cos 0.85–0.88 +res ≈ 1e-3 → cos 0.2–0.9(看 batch) +res ≈ 3e-3 → cos ≈ 0–0.5 +res ≈ 1e-2 → 噪声 +``` +**BPTT 没有这个阈值**(它对实际有限计算求导,收不收敛都给一个自洽梯度)。这条不对称—— +而非任何更深的东西——就是 EP 专属的难点。 + +**改法**: +1. **把 res 立为头号健康信号**(不是 spectral radius——见破缺 D)。每步多走一步 relax 测 res。 +2. **自适应 T1**:固定 T1=150 测到 res 后,若仍 > `res_est`(默认 1e-4),按 50 步一段继续 + relax 直到 res≤阈值或到 `t1max` 上限。**用算力买紧致**。 + - 重要细节:λ-控制器的 res 信号**仍在固定 T1=150 处采样**(保持控制器语义不变,不引入新的 + λ 战争);只有送进 nudge 的 z* 被 refine 到更紧。 + +**代码**:`ep_step` 顶部 `res` 计算 + `t1max/res_est` 的 while-refine 块。 +`--t1max 500 --res_est 1e-4`。 + +--- + +## 破缺 C:小 β 线性响应 + 干净 nudge —— 对称 nudge + 全纯估计 + 自适应 T2 + +**标准怎么做**:单边 +β(bias O(β)),且为防 nudge 把 relax 推爆,在 nudge 里加 **clamp** +(对输出做硬投影 `g.clamp(±2)`、对 AEP 修正做 `‖corr‖≤‖F‖` 裁剪)。 + +**为什么失效**: +- 单边 β 的 O(β) bias 逼着 β 很小,而估计噪声 ∝ (收敛误差)/β,小 β 放大噪声。 +- **clamp 是非解析的硬投影**。实测:在边缘残差(res 1.6e-3)处,clamp 是**估计误差的主因** + ——plain EP cos 0.27,去掉 clamp 后 0.89。clamp 本是为早期训练护航,却在中期残差一漂高就 + **悄悄毒化每一步更新**。 + +**改法(三步)**: + +1. **对称(两边)nudge**——`a = (z₋ − z₊)/(2β)`,centered ⇒ bias O(β²)(Laborieux 2021)。 + +2. **全纯 EP(clamp-free,复圆 Cauchy 读出)**(Laborieux–Zenke 2022)——把 ±β 换成复平面圆 + `|β|=r` 上的 N 个点 `β_k = r·e^{2πik/N}`,relax **全纯延拓**的力,读 + ``` + a = −Re[ (1/Nr) Σ_k e^{−iφ_k} (z_k − z*) ] bias O(r^N) + ``` + bias 从 O(r²) 降到 O(r^N) ⇒ r 可大 5–10×(等 bias),1/β 噪声放大同比例下降。 + - 力的全纯延拓:手写复 LN(非共轭方差)、softmax 用 exp 比值、GELU 用 tanh 形(整函数)。 + - **nudge 里完全无 clamp**(clamp 非解析,会毁掉 O(r^N) 阶);改为监控 `max|z−z*|`。 + - AEP 修正实系数线性 ⇒ 保解析,对实/虚部分别施加即可。 + - 实测扫描:N(2…8)与 r(0.02…0.2)基本持平 ⇒ **有限-β bias 和 1/β 噪声在此尺度都不是 + 瓶颈**;剩余 ~0.12 错位是 T2 截断(下一步)。 + +3. **自适应 T2(后见之明快照选择)**——T2 截断值 ~0.12 cos;但慢混合 batch 上长 T2 会发散 + (非正规瞬态增长;**基于步长的早停会失败**,瞬态在 t≈6–39 误触发)。解法:lockstep 跑到 + T2max,每 K 步快照对比态 `a_t`,**返回增量最小(最稳定)的那个快照**,只在明确爆炸 + (增量 > 5× 运行最小值)时早停。**判据是"关心的量 a 的增量",不是步长** ⇒ 瞬态增长无害。 + 实测:never worse than 固定 T2=20;mean cos 0.871 → 0.932。 + +**代码**:`holo_ep.py`——`cln/csoftmax_masked/cgelu/cforce`(全纯力)、`holo_a`(Cauchy 读出)、 +`holo_a_select`(自适应 T2)、`holo_a_select2`(N=2 相位-batched 快路,与 select 数值等价)。 +旧 clamp 在 `ep_step::nudge` 内 `g.clamp(±2)`/`corr` 裁剪——**已被全纯路线取代**(那是 legacy 路径)。 +`--holo 2 --hr 0.2 --t2sel 120`。 + +--- + +## 破缺 D:训练中不动点始终稳定存在 —— λ控制器 + 验证门 + 熔断 + 架构稳定器 + +**标准怎么做**:假定每步更新后自由相仍收敛到稳定不动点。 + +**为什么失效**:**训练会把动力学推离收缩流形**。这不是 EP 特有的——实测连 bare BPTT(精确梯度) +跑到 14k 也会走出收缩流形(res→4.7e-2,best 退化到 2.021,比它 3k 还差)。一旦离开,EP 估计器 +进入无效区(破缺 B 的阈值),更新方向变错,正反馈把权重推得更远。 + +**改法(四件,从软到硬)**: + +1. **残差驱动的连续 λ-控制器(软 Jacobian 惩罚)**——核心稳定器。 + - 惩罚项:`λ‖J_nc(z*)‖²_F`,Hutchinson 估计(一次 jvp on 随机探针,对 θ 求导;Bai 2021)。 + 保持自由相收缩 ⇒ 把估计器留在有效区。 + - 控制律(每步):`λ ← clip( λ · (res_EMA / target)^0.3 , floor , max )`。 + - **为什么控 res 不控谱半径**:block Jacobian 高度**非正规**——瞬态增长对特征值不可见 + (实测 ρ(J)=0.94"稳定"而 relax 发散到 res 0.21)。一步残差**就是**那个瞬态,控它。 + - **信号上的 EMA(0.9)**:原始 res 噪声大,乘性控制器在噪声上会随机游走(实测 λ 在 0.5↔13 + 抖),抖动的 λ 本身又扰动训练。EMA 去掉抖动。 + - **target ≈ 5e-4**:刚进有效阈值内(few×1e-4)留余量;不更紧(更紧白费算力且伤表达力)。 + 参考:BPTT 自己的最优在 res 1e-3–2e-2——好解只是**轻度收缩**;我们比 BPTT 多要一点, + 因为是**估计器**需要。 + - **floor 是承重的(不可退火到 0)**:两次独立实验证明 λ≲0.02 在任何阶段都致命 + (R2 从头 λ→0、R6 λ-floor 随 lr 退火,都死于同一种死法:val CE 60–77 且 res≡0)。 + 事后剖析:这是**被浮点伪装成收敛的爆炸**——‖z*‖与无 cap 参数在临时无效梯度下涨到 + `ε·F < ulp(z)`,relax 冻结(res=0 是吸收造成的),logits 巨大且自信地错。λ 惩罚的 + θ-梯度触及 fc/pj/LN/attn,正是把这个盆地挡在门外的机制。 +2. **验证门(validity gate)**——当 `res_used > res_gate`,EP 更新在数学上无定义 ⇒ **只施加 + homeostat(jacreg),完全跳过 nudge**(快速恢复步)。S1 尺度实测它是承重的:门之前死三次, + 门之后活。`--res_gate`。 +3. **熔断(abort fuse)**——`res > abort_res`(默认 0.1)**连续 100 步** ⇒ 停,保留 best ckpt。 + 纯安全网,防止无效区里烧机时。`--abort_res 0.1`。 +4. **架构层稳定器(尺度变大时才需要)**: + - **resinit**(ReZero/Fixup:WO、pj 乘 0.1)——初始化即近恒等收缩块,大宽度起步稳。`--resinit 0.1`。 + - **qknorm**(Qwen3 q/k RMSNorm)——bound 注意力 logits/Jacobian。`--qknorm`。 + - **阻尼 −c·z**——给原始注意力造/强化不动点;对 thin/real 必需。但 thick 里 LN 在内部,阻尼 + 反而抬高有效 Jacobian(`J∝1/σ(z)`),故 thick 把 c 设小(=1),稳定器交给 λ 惩罚。 + - **权重范数 cap**(WQ/K/V/O/Wm/Wh/fc/pj renorm 到 3× init)——瞬态期的钝安全网,健康时少触发。 + +**代码**:`ep_step` 的 jacreg 块 + 验证门分支;`main` 训练循环里的控制律(`jr = min(jr_max, +max(flo, jr*exp(0.3*log(rs/rtgt))))`)、`badct` 熔断、cap renorm、`resinit/qknorm` 注入。 + +--- + +## 外壳:与 EP 正交但必需的工程层 + +这些不是"EP 的修改",但当前版本依赖它们才能跑到当前数字: + +- **读出头 Wh**:用它自己的局部 CE 梯度 `∂CE/∂Wh`(在自由 z* 上),**不**穿过动力学。 + 这是 EP 设定的标准做法,不是 BP。 +- **参数 EMA**(decay 0.999,与裸权重并行评估)——压late-phase 估计器噪声漂移,不碰稳定环。`--pema`。 +- **优化器 / 调度**:AdamW(lr 1e-3, wd 1e-4)、warmup→cosine、grad-norm clip 5.0、跳过 non-finite 步。 + - **warmup 对大模型承重**:让控制器先建立收缩,再放大步长把权重踢出盆地。(BP baseline 不用 + warmup 也稳——这是 EP↔BP 的一个不对称。) +- **lr 标定(k 标定)**:`k = |g_EP|/|g_BPTT|` 每组,`lr_EP = lr_BPTT/k`。**但 AdamW 逐坐标归一化 + ⇒ 吸收掉 k 的尺度 ⇒ 对 Adam k 失效**;k 只在 SGD/硬件(固定增益线)下才重要。 + - 理论基础:Ernoult 2019——自由相收敛 + β→0 时 **EP ≡ BPTT**(逐步),所以 lr 的对应是 + EP↔BPTT,**不是** EP↔BP(BP 与 BPTT 物理含义不同,lr 本就不该直接对应)。 + +--- + +## 一页速查表:修改 → 破坏的前提 → flag → 实测效果 + +| 修改 | 破坏的标准前提 | flag | 实测效果 | +|---|---|---|---| +| 力形式 VF *(已有,非我们)* | A 保守 | `--attn_mode thick` | 写出无能量 EP,但**单独用会崩**(cos≈0.25) | +| AEP 反对称修正 J→Jᵀ *(AsymEP 的,非我们)* | A 保守 | (thick 内置) | 注意力梯度 cos 0.25 → 0.99 | +| ↳ 我们:无矩阵化 + 上注意力 + 全纯结合 + 共模追踪 | (scale/工程) | jvp−vjp / `--holo` / `holo_a_track` | 让上面两条能跑 transformer LM | +| 对称 nudge | C 小β | (默认) | bias O(β)→O(β²) | +| 全纯 + clamp-free | C 干净nudge | `--holo 2 --hr 0.2` | 边缘残差 cos 0.27 → 0.89;r 可放大 10× | +| 自适应 T2 选择 | C 线性响应 | `--t2sel 120` | mean cos 0.871 → 0.932;训练 −0.064 | +| 残差为健康信号 | B 收敛 | (内置) | 暴露有效性阈值 res≲few×1e-4 | +| 自适应 T1 | B 收敛 | `--t1max 500 --res_est 1e-4` | 把 z* refine 进有效区,长 T2 才有收益 | +| λ-控制器(软Jac惩罚) | D 稳定 | `--jacreg .. --res_target 5e-4 --res_ema 0.9` | 保持收缩;floor 不可退火(否则 fp-吸收爆炸) | +| 验证门 | D 稳定 | `--res_gate` | S1:门前死3次,门后活 | +| 熔断 | D 稳定 | `--abort_res 0.1` | 连续100步 res>0.1 即停,保 best | +| resinit / qknorm | D 稳定(大宽度) | `--resinit 0.1 --qknorm` | 大 C 起步稳;bound 注意力 Jacobian | + +**一句话**:标准 EP 假定"保守、已收敛、小β干净、训练中恒稳";transformer 四条全破。 +A 用**力形式 + AEP** 修(把 J 变 Jᵀ);B 用**残差信号 + 自适应 T1** 修(进有效区); +C 用**全纯 clamp-free + 自适应 T2** 修(干净估计 + 不截断不发散);D 用**残差驱动 λ-控制器 +(floor 承重)+ 门 + 熔断 + 架构稳定器** 修(训练中拽回收缩流形)。其中 A 的 AEP 与 D 的 +λ-控制器是两处最实质的偏离;其余多是"把估计器修干净 / 留在有效区"。 |
