-rw-r--r-- CLAUDE.md | 2
-rw-r--r-- NOTE.md | 22
-rw-r--r-- configs/cifar10.yaml | 15
-rw-r--r-- configs/toy_lq.yaml | 14
-rw-r--r-- experiments/__init__.py | 0
-rw-r--r-- experiments/__pycache__/__init__.cpython-313.pyc | bin 0 -> 137 bytes
-rw-r--r-- experiments/__pycache__/toy_lq.cpython-313.pyc | bin 0 -> 19620 bytes
-rw-r--r-- experiments/cifar_resmlp.py | 775
-rw-r--r-- experiments/plot_results.py | 327
-rw-r--r-- experiments/plot_toy_final.py | 183
-rw-r--r-- experiments/toy_lq.py | 395
-rw-r--r-- experiments/toy_lq_sweep.py | 243
-rw-r--r-- experiments/toy_lq_v2.py | 327
-rw-r--r-- methods/__init__.py | 0
-rw-r--r-- metrics/__init__.py | 0
-rw-r--r-- metrics/__pycache__/__init__.cpython-313.pyc | bin 0 -> 133 bytes
-rw-r--r-- metrics/__pycache__/credit_metrics.cpython-313.pyc | bin 0 -> 6418 bytes
-rw-r--r-- metrics/credit_metrics.py | 156
-rw-r--r-- models/__init__.py | 0
-rw-r--r-- models/__pycache__/__init__.cpython-313.pyc | bin 0 -> 132 bytes
-rw-r--r-- models/__pycache__/residual_mlp.cpython-313.pyc | bin 0 -> 4692 bytes
-rw-r--r-- models/__pycache__/state_bridge.cpython-313.pyc | bin 0 -> 2468 bytes
-rw-r--r-- models/__pycache__/value_net.cpython-313.pyc | bin 0 -> 5308 bytes
-rw-r--r-- models/residual_mlp.py | 73
-rw-r--r-- models/state_bridge.py | 35
-rw-r--r-- models/value_net.py | 77
-rw-r--r-- report/toy_bridge_residual.png | bin 0 -> 70000 bytes
-rw-r--r-- report/toy_per_layer_diagnostics.png | bin 0 -> 166670 bytes
-rw-r--r-- report/toy_term_grad_effect.png | bin 0 -> 80994 bytes
-rw-r--r-- report/toy_training_curves.png | bin 0 -> 179278 bytes
-rw-r--r-- results/smoke_test/results_fashionmnist.json | 511
-rw-r--r-- results/smoke_test2/results_fashionmnist.json | 721
-rw-r--r-- results/toy_lq/state_bridge_seed42.pt | bin 0 -> 150069 bytes
-rw-r--r-- results/toy_lq/sweep_results.json | 1070
-rw-r--r-- results/toy_lq/toy_lq_seed42.json | 335
-rw-r--r-- results/toy_lq/toy_lq_v2_seed123_lam0.1_sig0.1_tgw1.0_fm0.0.json | 330
-rw-r--r-- results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.0.json | 458
-rw-r--r-- results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.1.json | 330
-rw-r--r-- results/toy_lq/toy_lq_v2_seed42_lam1.0_sig0.3_tgw0.0_fm0.0.json | 282
-rw-r--r-- results/toy_lq/toy_lq_v2_seed456_lam0.1_sig0.1_tgw1.0_fm0.0.json | 330
-rw-r--r-- results/toy_lq/value_net_seed42.pt | bin 0 -> 117515 bytes
41 files changed, 7011 insertions, 0 deletions
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..6d9d08f
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,2 @@
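+chor premise: can we learn a useful, terminal-conditioned local credit signal that replaces the coarse credit assignment produced by the fixed random feedback matrices in Feedback Alignment (FA/DFA)? Our goal is not to chase the highest classification accuracy right away, but to test whether a new hypothesis is viable: rather than bridging the hidden state, bridge the credit / costate / value field. In other words, the object we want to learn is not h_l \mapsto h_{l+1} or h_l \mapsto h_L, but a terminal-conditioned scalar value / desirability field V_\phi(h_l, t_l, s), \qquad t_l = l/L, with the local credit assignment at each layer defined as a_l := \nabla_{h_l} V_\phi(h_l,t_l,s). Here s is a global modulation code derived from the terminal error. The simplest choice is e_T = \nabla_{\hat y}\ell(\hat y,y),\qquad s = P^\top \mathrm{sg}(e_T), where \mathrm{sg}[\cdot] denotes stop-gradient. For a 10-class classification task, simply take P=I at first, so s=e_T\in\mathbb R^{10}. The project has two phases. Phase one is a linear-quadratic residual sanity check; its purpose is not task performance but to verify "whether the a_l learned by the credit bridge actually resembles the true local gradient". Phase two is the main experiment on a deep residual MLP, comparing BP, DFA, state bridge, and credit bridge. The most important constraint of the whole project is: no BP anchor may be used at hidden layers during training. That is, intermediate layers must not use the exact-backprop hidden gradient as a supervision signal. Only two things are allowed: first, the local exact gradient at the output layer, because the output layer sees the loss directly anyway; second, computing the BP hidden gradient during offline evaluation as a diagnostic metric, which must never be used for training. Please proceed in the following order: first finish the linear sanity check and obtain clear diagnostic results, then run the main experiment. Do not reverse the order. ⸻ I. Project background and core hypothesis Standard BP uses \delta_l^{\mathrm{BP}} = J_{l\to L}^\top e_T at hidden layers, where J_{l\to L} is the Jacobian from layer l to the output. FA/DFA instead approximates this object with a fixed random feedback matrix; for example, DFA uses \delta_l^{\mathrm{DFA}} = B_l e_T. Our suspicion is that the crux is not "finding a better static B_l" but learning a credit field that depends on state, sample, and depth. That is, we want to treat the top-down signal a hidden layer should receive as the gradient of some terminal-conditioned value field: a_l = \nabla_{h_l} V_\phi(h_l,t_l,s). 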
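This idea can be understood as a credit bridge with no intermediate anchor. It is a different object from a state bridge that merely predicts hidden states: the state bridge tries to predict h_L or intermediate states from h_l, whereas the credit bridge directly learns the sensitivity of the loss to the local state. We need to test two concrete claims: First, whether a state bridge is insufficient to produce useful local credit. That is, even if it predicts h_L well, the a_l^{\mathrm{state}} = \nabla_{h_l} \ell(W_{\rm out}\hat h_L,y) it yields may still fail to guide local updates. Second, whether a credit bridge, without any intermediate BP anchor, can produce more useful local credit than DFA. Here "more useful" is judged mainly through diagnostic metrics such as local perturbation correlation, the nudging test, and offline BP cosine, not just final test accuracy. ⸻ II. Methods you need to implement You must implement at least the following 4 methods and ensure that all of them use, as far as possible, the same forward architecture, optimizer family, and training setup. Method 1: BP baseline This is the upper-bound baseline. Standard end-to-end backprop is allowed in this method. Its main role is to provide an accuracy upper bound and a hidden-gradient reference at evaluation time. Training this method may of course call loss.backward() on the whole network. Method 2: DFA baseline This is the strong baseline without a hidden BP anchor. Let the network have L residual blocks, each hidden state have dimension d, and the number of classes be C. For each block, sample a fixed random matrix B_l \in \mathbb R^{d\times C}, and keep it unchanged during training. The output error is e_T = \nabla_{\hat y}\ell(\hat y,y). The layer then uses a_l^{\mathrm{DFA}} = B_l e_T. The parameters of block l are updated with the local surrogate: \mathcal L_l^{\mathrm{local}} = \langle F_l(h_l;\theta_l),\mathrm{sg}[a_{l+1}^{\mathrm{DFA}}]\rangle. Note: in the DFA method, do not take the shortcut of letting loss.backward() pass through hidden layers. Each block's update must depend only on its own local forward pass and the fixed feedback received at that layer. Method 3: State bridge baseline This is a deliberately included "wrong object" baseline. Its purpose is not to get the best results but to verify that "bridging state is not the same as bridging credit". The procedure is as follows. Define a shared or semi-shared predictor G_\psi(h_l,t_l,s)\to \hat h_L^{(l)}. Its inputs are the current layer's hidden state, the layer position t_l=l/L, and the terminal modulation code s. Its training objective is terminal-state regression: \mathcal L_{\mathrm{state}} = \sum_{l=0}^{L-1}\|G_\psi(h_l,t_l,s)-\mathrm{sg}[h_L]\|_2^2. Supervising only layers with l<L is recommended, because l=L is the trivial point. When training the state bridge, the input hidden state must be a detach().requires_grad_(True) copy, so that training the state predictor does not backpropagate into the forward network. Once G_\psi has been trained (or while it is trained in sync), define the state-bridge credit: a_l^{\mathrm{state}} = \nabla_{h_l}\ell(W_{\rm out}G_\psi(h_l,t_l,s),y). Here a_l^{\mathrm{state}} is the gradient with respect to the predictor input h_l, not with respect to the forward network's parameters. Then, as in DFA, use it for the block's local surrogate update: \mathcal L_l^{\mathrm{local}} = \langle F_l(h_l;\theta_l),\mathrm{sg}[a_{l+1}^{\mathrm{state}}]\rangle. Method 4: Credit bridge (main method) This is the core method of the project. Define a scalar value network V_\phi(h_l,t_l,s)\in\mathbb R. 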
Its inputs are again the hidden state, the depth time t_l=l/L, and the terminal modulation code s. The credit signal at that layer is then defined as a_l = \nabla_{h_l}V_\phi(h_l,t_l,s). This a_l is sent to the corresponding block for its local surrogate update. Training of the credit bridge starts from the simplest version, which implements only two loss terms. The first is the terminal boundary term: \mathcal L_{\mathrm{term}} = \Big( V_\phi(h_L,1,s)-\mathrm{sg}[\ell(\hat y,y)] \Big)^2. The second is first-order bridge consistency. Let the forward block have the residual form h_{l+1}=h_l+F_l(h_l;\theta_l). Introduce small-noise reference dynamics \tilde h_{l+1}^{(k)} = h_l + F_l(h_l;\theta_l) + \sigma_l \xi_k,\qquad \xi_k\sim\mathcal N(0,I). Then use an EMA target network \bar\phi to construct the target: \hat V_l^{\mathrm{tgt}} = -\lambda \log \left( \frac1K\sum_{k=1}^K \exp\Big( -\frac{V_{\bar\phi}(\tilde h_{l+1}^{(k)},t_{l+1},s)}{\lambda} \Big) \right). The bridge loss is defined as \mathcal L_{\mathrm{bridge}} = \sum_{l=0}^{L-1} \Big( V_\phi(h_l,t_l,s)-\mathrm{sg}[\hat V_l^{\mathrm{tgt}}] \Big)^2. The full feedback loss is \mathcal L_\phi = \mathcal L_{\mathrm{term}} + \mathcal L_{\mathrm{bridge}}. In this minimal version, do not add the FM auxiliary yet, and do not add the MaxCal smoothness regularizer yet. The reason is practical: adding every second-order term from the start makes the engineering much more complex and makes it harder to judge quickly whether the direction shows any signal. When the forward network is updated, the local surrogate for block l is \mathcal L_l^{\mathrm{local}} = \langle F_l(h_l;\theta_l),\mathrm{sg}[a_{l+1}]\rangle. The output-layer weights W_{\rm out} may be updated directly with the exact output gradient of CE; this is allowed because the output layer sees the loss locally. If the minimal version shows signal, build a second version: on the toy and small models, add an FM-style auxiliary that makes the credit field smoother from layer to layer. Define the random intermediate point \bar h_{l,\tau} = h_l+\tau F_l(h_l;\theta_l)+\sqrt{\tau(1-\tau)}\sigma_l\epsilon, \qquad \tau\sim U(0,1),\ \epsilon\sim\mathcal N(0,I), and the interpolated target \bar a_{l,\tau}^{\rm tgt} = (1-\tau)\,\mathrm{sg}[a_l] + \tau\,\mathrm{sg}[a_{l+1}]. Then add \mathcal L_{\mathrm{fm}} = \gamma\sum_{l=0}^{L-1} \left\| \nabla_h V_\phi(\bar h_{l,\tau},(l+\tau)/L,s) - \bar a_{l,\tau}^{\rm tgt} \right\|_2^2. 
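Note: this term trains through \nabla_h V_\phi, i.e. it requires second-order autograd. Run it only on the toy and small-scale main experiments first; do not enable it at full scale from the start. ⸻ III. Experiment phase A: linear-quadratic residual sanity check The purpose of this experiment is to verify, in a system whose exact costate is available analytically, whether the a_l learned by the credit bridge is close to the true local gradient. In this phase all forward-dynamics parameters are fixed: the forward net is not trained, only the feedback / bridge models are. This separates "is credit learning correct" from "is forward training stable". Please implement the following system. Let the hidden dimension be d=64, the terminal output dimension m=10, and the number of layers L=12. For each layer, sample a stable linear map M_l = I + A_l, where A_l is a random matrix scaled to a small spectral norm, e.g. \|A_l\|_2\le 0.05. The forward dynamics are h_{l+1}=M_l h_l + \sigma \xi_l,\qquad \xi_l\sim\mathcal N(0,I),\ \sigma=0.03. Data can simply be generated online. Take h_0 \sim \mathcal N(0,I_d),\qquad y\sim \mathcal N(0,I_m), and define the terminal loss as \Phi(h_L,y)=\frac12\|Ch_L-y\|_2^2, where C\in\mathbb R^{m\times d} is a fixed random matrix. The exact costate of this system is available analytically. The terminal gradient is a_L^{\mathrm{exact}} = C^\top(Ch_L-y), and the recursion is a_l^{\mathrm{exact}} = M_l^\top a_{l+1}^{\mathrm{exact}}. Use this exact costate for evaluation, but do not use it to train the credit bridge. This phase compares at least 3 methods: DFA, state bridge, credit bridge. BP does not need to be trained here because the forward net is fixed; you only need the exact costate as an evaluation reference. The main evaluation metrics for this phase are as follows. First, exact costate cosine: \Gamma_l = \mathbb E\Big[ \cos(a_l,a_l^{\mathrm{exact}}) \Big]. Report it per layer and also as an across-layer average. Second, local perturbation correlation. For each layer, sample M random directions v_{l,j}\sim \mathcal N(0,I), normalize them, and take a small perturbation \epsilon. The predicted first-order loss change is \Delta_{l,j}^{\mathrm{pred}} = \langle a_l,\epsilon v_{l,j}\rangle. The true change is the terminal-loss difference obtained by re-rolling the subsequent dynamics from that layer onward: \Delta_{l,j}^{\mathrm{true}} = \Phi(h_l+\epsilon v_{l,j})-\Phi(h_l). Compute the per-layer Pearson correlation \rho_l = \mathrm{corr}\!\left(\Delta_{l,j}^{\mathrm{pred}},\Delta_{l,j}^{\mathrm{true}}\right). This metric is very important because it does not depend on BP; it only tests whether a_l really is a valid linear approximation of the loss with respect to the local state. Third, the nudging test. For each layer apply h_l' = h_l - \eta \frac{a_l}{\mathrm{RMS}(a_l)+10^{-6}}, then continue rolling the dynamics from that layer and check whether the terminal loss decreases: \Delta_l^{\mathrm{nudge}} = \Phi(h_l')-\Phi(h_l). A good credit should give \Delta_l^{\mathrm{nudge}}<0 and beat both a random direction and DFA. Fourth, bridge residual. For the credit bridge, compute R_l = \left| V_\phi(h_l,t_l,s) + \lambda\log \left( \frac1K\sum_k e^{-V_{\bar\phi}(\tilde h_{l+1}^{(k)},t_{l+1},s)/\lambda} \right) \right|. 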
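This metric checks whether the bridge recursion has been learned. The goal of this phase is not perfection but to answer one question: with no intermediate BP anchor at all, can the credit bridge at least look more like a true local-gradient object than the state bridge does? If the toy shows no signal at all, do not move on to the main experiment; if on the toy \Gamma_l, \rho_l, and \Delta_l^{\mathrm{nudge}} are all clearly better than for the state bridge and DFA, continue. ⸻ IV. Experiment phase B: deep residual MLP main experiment The main experiment uses image classification. The preferred dataset is CIFAR-10. If CIFAR-10 is too slow to debug, run a smoke test on FashionMNIST first, but the final main results must be on CIFAR-10. Use a deep residual MLP for the forward architecture, not a Transformer. The reason is that this phase aims to validate credit semantics, not to take on large-model training. The recommended configuration is as follows: the input is flattened and then passed through a linear embedding to hidden dimension d=512 or d=768. There are L=12 residual blocks in total. Each block uses pre-LayerNorm plus a two-layer MLP: h_{l+1} = h_l + W_{2,l}\,\mathrm{GELU}(W_{1,l}\,\mathrm{LN}(h_l)). The output head is \hat y = W_{\rm out}\,\mathrm{LN}(h_L). The training loss is standard cross-entropy. All methods use the same data augmentation, batch size, optimizer family, weight decay, and number of training epochs. A suggested starting point is AdamW, batch size 128, and 100 to 200 training epochs. Run every method with 3 random seeds. The main experiment must compare these 4 methods: BP, DFA, state bridge, credit bridge. If the engineering cost is acceptable, add a credit bridge + FM auxiliary variant as an additional experiment. This phase has two very important implementation constraints. First, for non-BP methods, never let the global loss backpropagate through hidden layers. The correct procedure is: first run the forward network to obtain h_0,\dots,h_L and the terminal error e_T. Then build detached hidden copies for the feedback network: \tilde h_l = \mathrm{detach}(h_l),\qquad \tilde h_l.\mathrm{requires\_grad}=True. Train the state bridge or value net on these detached copies. This way the feedback network's gradients cannot silently train the forward network. Second, updating each block's parameters must also depend only on the block's own local forward pass and the credit coming from the feedback. Concretely, when block l is updated it should take a fresh \bar h_l = \mathrm{detach}(h_l), compute F_l(\bar h_l;\theta_l), and then minimize \mathcal L_l^{\mathrm{local}} = \langle F_l(\bar h_l;\theta_l), \mathrm{sg}[a_{l+1}] \rangle. This way autograd flows only into \theta_l and never into earlier blocks. Each block may have its own optimizer, or the blocks may share one optimizer provided zero_grad() and step() are handled carefully block by block. If you find this error-prone, I recommend a separate optimizer per block. The output head W_{\rm out} may be updated directly with the exact output gradient of CE, but detach h_L for that update so that nothing backpropagates into the hidden layers. ⸻ V. Main-experiment evaluation metrics Final accuracy must of course be reported, but it is not the only core metric. What we really want to know is whether the local credit produced by the credit bridge resembles the true sensitivity of the loss to the local state more closely than DFA and the state bridge do. Please compute and report at least the following metrics. 1. Validation accuracy / train loss / test accuracy These are the standard metrics, but do not look only at them. Report mean±std over 3 seeds for every method. 2. 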
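Local perturbation correlation \rho_l This is the most important BP-free diagnostic metric. The procedure is as follows: for a validation batch, take the hidden state h_l at some layer and sample M random directions v_{l,j}, e.g. M=16 or 32. For each direction, compute the predicted loss change: \Delta_{l,j}^{\mathrm{pred}} = \langle a_l, \epsilon v_{l,j}\rangle. Then replace the hidden state at that layer with h_l+\epsilon v_{l,j}, rerun only the tail from that layer to the output, and obtain the true loss change: \Delta_{l,j}^{\mathrm{true}} = \ell(h_l+\epsilon v_{l,j})-\ell(h_l). Compute the Pearson correlation: \rho_l = \mathrm{corr}(\Delta_{l,j}^{\mathrm{pred}}, \Delta_{l,j}^{\mathrm{true}}). Report this metric per layer as it evolves over training. A useful credit should keep a positive correlation across many layers and beat DFA. 3. Hidden nudging test For each layer define the normalized nudge: h_l' = h_l - \eta \frac{a_l}{\mathrm{RMS}(a_l)+10^{-6}}. Then rerun the tail from that layer onward and compare the loss change: \Delta_l^{\mathrm{nudge}} = \ell(h_l') - \ell(h_l). A good credit should make \Delta_l^{\mathrm{nudge}} negative and beat both a random direction and DFA. Fix a set of \eta values, e.g. \{10^{-3},3\cdot10^{-3},10^{-2}\}, and test all of them, to avoid a misjudgment caused by a poorly chosen step size. 4. Offline BP cosine This metric must not be used during training, but it may be used for evaluation. Run one separate full BP pass on a validation batch to obtain the true hidden gradient \delta_l^{\mathrm{BP}} = \frac{\partial \ell}{\partial h_l}. Then compute \Gamma_l = \mathbb E\left[\cos(a_l,\delta_l^{\mathrm{BP}})\right]. Report it per layer as an auxiliary diagnostic. This metric is valuable but secondary, because our core requirement is training without a hidden BP anchor, not never computing BP offline. 5. Bridge residual Report this only for the credit bridge, to judge whether the value recursion has been learned. The definition is the same as above. 6. Feature drift For each block's parameters, compute M_l = \frac{\|W_l^{\mathrm{final}}-W_l^{\mathrm{init}}\|_F}{\|W_l^{\mathrm{init}}\|_F}. This helps show whether the different methods all stay in an approximately lazy regime. Report it per layer. 7. 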
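Terminal prediction quality of the state bridge Additionally report the state bridge's \|G_\psi(h_l,t_l,s)-h_L\|_2^2 or its average relative error. What we want to show is not "it cannot learn anything" but "even if it predicts the terminal state reasonably well, it does not necessarily produce good credit". ⸻ VI. Implementation advice and engineering details Use PyTorch 2.x. For the toy and small models, prefer float32 and do not start with mixed precision, because the second-order graph of the credit bridge is more prone to instability under mixed precision. For the value network V_\phi, start with a small shared MLP: the input is the concatenation [\mathrm{LN}(h_l), \mathrm{time\_embed}(t_l), s], the hidden width is 256 or 512, there are 2 to 3 layers, and the output is a scalar. The time embedding can be a small MLP applied to the scalar t_l, or a sinusoidal positional encoding. What matters is stability, not sophistication. For the state predictor G_\psi, a shared MLP is also recommended, but with output dimension d. Do not make it too large, or it will overshadow the rest. For \lambda, try 0.1 and 1.0 first. The bridge noise \sigma_l can initially be a constant between 0.01 and 0.05. A target-network EMA momentum of 0.99 or 0.995 is suggested. Start with K=4 Monte Carlo samples for each \hat V_l^{\mathrm{tgt}}; the toy can use 8. Initialize the DFA random feedback matrices B_l from a Gaussian and scale them to a suitable magnitude, so that the credit magnitude neither explodes nor becomes too small. A simple normalization is suggested so that the RMS of the DFA credit is of roughly the same order at every layer. For the credit bridge's a_l, also run a normalization experiment before feeding it into the local surrogate: \tilde a_l = \frac{a_l}{\mathrm{RMS}(a_l)+10^{-6}}. You may keep both the raw and the normalized local-update variants, but at minimum compare them in a pilot experiment; otherwise you may be misled by pure numerical-scale issues. Pay special attention to one point: if you train the feedback/value network directly on the h_l from the forward graph, autograd may send the gradient of \mathcal L_\phi back into the forward network, which breaks the "no hidden BP anchor" setting. Therefore, training the feedback/value network must use detached hidden copies. Updating block parameters must also use detached hidden inputs. Write this into the code comments and state it explicitly in the README. ⸻ VII. Recommended code organization Structure the code so that experiments are clear and reproducible. At least the following files are suggested: models/residual_mlp.py: defines the forward residual MLP. models/value_net.py: defines V_\phi. models/state_bridge.py: defines G_\psi. methods/bp.py: BP baseline training logic. methods/dfa.py: DFA baseline training logic. methods/state_bridge_train.py: state bridge training and local-update logic. methods/credit_bridge_train.py: credit bridge training and local-update logic. experiments/toy_lq.py: linear-quadratic sanity check. experiments/cifar_resmlp.py: main experiment. metrics/credit_metrics.py: implements \rho_l, nudging, BP cosine, bridge residual. configs/*.yaml: all experiment configs. README_experiments.md: how to run. report/: final figures and report. Every experiment must save its config, seed, git commit hash, training log, and intermediate checkpoints. Do not run experiments that were "tuned by hand many times without being recorded". ⸻ VIII. Priorities and expected outputs Execute in the following priority order; do not reorder. First priority: complete the toy linear-quadratic sanity check and obtain clear figures. At minimum, provide per-layer curves or tables of \Gamma_l, \rho_l, \Delta_l^{\mathrm{nudge}}, and bridge residual, comparing DFA, state bridge, and credit bridge. In this phase, if 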
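the credit bridge resembles the true local gradient more than the state bridge does, that is already a positive signal. Second priority: complete the main experiment with the 4 methods BP, DFA, state bridge, and credit bridge on a deep residual MLP on CIFAR-10. At least 3 seeds are required. Output accuracy curves, per-layer \rho_l curves, the nudging test, offline BP cosine, and feature drift. The most important comparison is not "who has the highest accuracy" but "who produces a local signal with more credit semantics in the absence of a hidden BP anchor". Third priority: if the first two steps go well, do credit bridge + FM auxiliary. Trying it only on the toy and on smaller main experiments is enough; a full-scale run is not needed at first, because this step requires second-order autograd and has a higher engineering cost. ⸻ IX. Success criteria and how to judge whether to continue The success criterion of this project is not "the credit bridge's final accuracy exceeds BP"; that is unrealistic. More reasonable stage-wise success criteria are: First, on the toy linear system, the credit bridge's exact costate cosine and local perturbation correlation are clearly higher than the state bridge's and generally higher than DFA's. As long as \rho_l>0 and \Delta_l^{\mathrm{nudge}}<0 on most layers, and the results beat the state bridge, this line of work has value. Second, on the main experiment, even if final accuracy does not yet clearly exceed DFA, as long as the credit bridge is consistently better than DFA on \rho_l, nudging, and offline BP cosine in early and middle layers, it is indeed closer to the correct credit-assignment object. Third, if the state bridge does reasonably well on terminal state regression but its \rho_l and nudging remain clearly worse than the credit bridge's, that strongly supports our core thesis: bridging state is not enough, and bridging the credit / value field is the better route. ⸻ X. What you must finally deliver You must deliver at least the following: 1. A concise but complete README_experiments.md explaining how to run the toy and main experiments. 2. Reproducible experiment code with clear configs, able to select method, dataset, and seed from the command line. 3. A results report, preferably PDF or markdown, containing: method definitions, implementation notes, key figures, main findings, failure points, and suggested next steps. 4. At least the following figures: • toy: per-layer exact costate cosine, \rho_l, nudging, bridge residual; • main experiment: train/test accuracy, per-layer \rho_l, nudging, offline BP cosine, feature drift; • a plot contrasting the state bridge's terminal prediction error with its credit quality. 5. A final conclusion paragraph that explicitly answers three questions: • Without a hidden BP anchor, can the credit bridge learn useful credit? • Does it fit the definition of a "credit-assignment object" better than the state bridge? • Does it beat DFA on the main diagnostic metrics? 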
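⸻ XI. Things not to do Do not treat offline BP cosine as a training objective. Do not secretly use BP supervision at hidden layers. Do not break locality in order to chase accuracy. Do not switch on every complex regularizer from the start. Do not report only accuracy without the credit diagnostics. Do not neglect random seeds and experiment records. Do not skip the toy sanity check because the main experiment is hard. ⸻ XII. How to degrade gracefully when blocked If the credit bridge is unstable at the start of training, degrade in the following order instead of changing things at random: First use only \mathcal L_{\mathrm{term}}+\mathcal L_{\mathrm{bridge}}, without the FM auxiliary. If it is still unstable, reduce the model width, reduce \sigma, and reduce the learning rate. If the credit magnitude is very unstable, first use normalized credit in the local surrogate. If the main experiment is too slow, run a smoke test on FashionMNIST first, but return to CIFAR-10 in the end. If the main experiment is very hard to get working, prioritize completing the toy experiment and the diagnostic metrics, because they are the key evidence for whether the direction is worth pursuing. ⸻ Once you start, complete the following two minimal goals before expanding further: The first minimal goal is: exact costate cosine and \rho_l curves for credit bridge vs state bridge vs DFA on the toy linear system. The second minimal goal is: a comparison of \rho_l and nudging for BP vs DFA vs state bridge vs credit bridge on a deep residual MLP on CIFAR-10. If the results are very poor after you finish these minimal goals, do not hide the failure. Report it honestly and try to distinguish between two kinds of failure: "the theoretical direction shows no signal" and "the engineering implementation is not yet stable". Do not stop before all results have been obtained; you may reread CLAUDE.md or search the web for answers. Record all results, your attempts, and any changes in NOTE.md. Use nvidia-smi to check GPU availability. Killing other users' processes is strictly forbidden, and before the work is finished, handing control back to me or asking me questions is strictly forbidden. I am about to go to sleep, so questions to me will get no reply; keep looping until all results are obtained.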
+
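The bridge target in CLAUDE.md, \hat V_l^{\mathrm{tgt}} = -\lambda \log\big(\frac1K\sum_k \exp(-V_k/\lambda)\big), underflows if evaluated naively when the values V_k are large relative to \lambda. The following is a minimal sketch of a numerically stable evaluation in plain Python, not the code from this commit: the function name `softmin_target` is hypothetical, and scalar floats stand in for the batched tensors the real training loop would use.

```python
import math


def softmin_target(values, lam):
    """Return -lam * log(mean(exp(-v / lam))) for a list of K sampled values.

    Hypothetical helper for illustration; `values` plays the role of the K
    target-network outputs V_k and `lam` the temperature lambda.
    """
    if lam <= 0:
        raise ValueError("lam must be positive")
    if not values:
        raise ValueError("values must be non-empty")
    # Shift by the minimum so the largest exponent is exp(0) = 1 and the
    # mean inside the log can never underflow to zero.
    v_min = min(values)
    mean_exp = sum(math.exp(-(v - v_min) / lam) for v in values) / len(values)
    return v_min - lam * math.log(mean_exp)
```

The result always lies between the smallest sampled value and the arithmetic mean: a small \lambda pulls it toward the minimum, while a large \lambda pulls it toward the mean, which is one way to read the choice between `lam` = 0.1 and 1.0 suggested above.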
diff --git a/NOTE.md b/NOTE.md
new file mode 100644
index 0000000..2e37841
--- /dev/null
+++ b/NOTE.md
@@ -0,0 +1,22 @@
+# Experiment Notes
+
+## 2026-03-23: Initial Implementation and Experiments
+
+### Setup
+- GPU: NVIDIA RTX A6000 x4 (using GPU 1)
+- PyTorch 2.10.0+cu128
+- All code written from scratch following CLAUDE.md specifications
+
+### Phase A: Toy LQ Sanity Check
+- Status: Running...
+- Config: d=64, m=10, L=12, sigma=0.03, 5000 steps, batch=256
+- Methods: DFA, State Bridge, Credit Bridge
+
+### Changes Log
+- Created full project structure: models/, methods/, experiments/, metrics/, configs/
+- models/residual_mlp.py: ResidualMLP with pre-LayerNorm residual blocks
+- models/value_net.py: ValueNet V_phi with sinusoidal time embedding
+- models/state_bridge.py: StateBridgeNet G_psi
+- experiments/toy_lq.py: Linear-quadratic sanity check
+- experiments/cifar_resmlp.py: CIFAR-10 main experiment
+- metrics/credit_metrics.py: All diagnostic metrics
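The changes log mentions a sinusoidal time embedding for V_phi, but `models/value_net.py` is not shown in this part of the diff. As an illustration only, a common form of such an embedding for a scalar depth time t_l = l/L could look like the sketch below. The function name, the `max_period` constant, and the sin/cos layout are all assumptions, and plain Python lists stand in for tensors.

```python
import math


def sinusoidal_time_embedding(t, dim=32, max_period=10000.0):
    """Map a scalar t to `dim` features: sin and cos at geometric frequencies.

    Hypothetical sketch; the actual ValueNet embedding may differ, for
    example by rescaling t before applying the frequencies.
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    # Frequencies decay geometrically from 1 down to roughly 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]
```

With t restricted to [0, 1] the higher-index features vary slowly, so implementations often multiply t by a constant first; whether this commit does so is not visible here.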
diff --git a/configs/cifar10.yaml b/configs/cifar10.yaml
new file mode 100644
index 0000000..6429287
--- /dev/null
+++ b/configs/cifar10.yaml
@@ -0,0 +1,15 @@
+dataset: cifar10
+d_hidden: 512
+num_blocks: 12
+batch_size: 128
+epochs: 100
+lr: 0.001
+lr_fb: 0.001
+wd: 0.01
+lam: 0.1
+K: 4
+sigma_bridge: 0.03
+ema_momentum: 0.995
+seeds: [42, 123, 456]
+gpu: 1
+output_dir: results/cifar10
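The `ema_momentum: 0.995` entry controls the target network \bar\phi used in the bridge target. The script imports `update_ema` from `models/value_net.py`, whose body is not shown here, so the following is only a sketch of the usual exponential-moving-average rule under that assumption. The name `ema_update` is hypothetical, and a dict of floats stands in for a model's parameters.

```python
def ema_update(target, online, momentum=0.995):
    """In-place EMA: target <- momentum * target + (1 - momentum) * online.

    Hypothetical stand-in for an EMA target-network update; `target` and
    `online` are dicts mapping parameter names to floats.
    """
    if not 0.0 <= momentum < 1.0:
        raise ValueError("momentum must be in [0, 1)")
    for name, value in online.items():
        target[name] = momentum * target[name] + (1.0 - momentum) * value
    return target
```

Under this rule a momentum of 0.995 corresponds to an averaging horizon of about 1 / (1 - 0.995) = 200 steps: after 200 updates toward a constant online value, the target has closed roughly 63% of the gap.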
diff --git a/configs/toy_lq.yaml b/configs/toy_lq.yaml
new file mode 100644
index 0000000..ad1bba0
--- /dev/null
+++ b/configs/toy_lq.yaml
@@ -0,0 +1,14 @@
+d_hidden: 64
+output_dim: 10
+num_layers: 12
+sigma: 0.03
+batch_size: 256
+num_steps: 5000
+lr_fb: 0.001
+lam: 0.1
+K: 8
+ema_momentum: 0.995
+sigma_bridge: 0.03
+eval_every: 200
+gpu: 1
+output_dir: results/toy_lq
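For the toy system this config describes, CLAUDE.md gives the exact costate as a_L = C^T (C h_L - y) with the backward recursion a_l = M_l^T a_{l+1}. The sketch below illustrates that recursion in dependency-free Python; it is not the code in `experiments/toy_lq.py`. All function names are hypothetical, the dimensions are much smaller than d_hidden=64 and num_layers=12, the rollout is noise-free (sigma = 0), and the entrywise scaling of A_l is a simple stand-in for the spectral-norm scaling in the spec.

```python
import random


def matvec(M, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def make_system(d=4, m=2, num_layers=3, scale=0.05, seed=0):
    """Sample M_l = I + A_l and a fixed readout C (hypothetical helper)."""
    rng = random.Random(seed)
    Ms = []
    for _ in range(num_layers):
        # Entries of A_l are bounded by scale / d, so ||A_l||_2 <= ||A_l||_F <= scale.
        Ms.append([[(1.0 if i == j else 0.0) + scale * rng.uniform(-1, 1) / d
                    for j in range(d)] for i in range(d)])
    C = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(m)]
    return Ms, C


def terminal_loss(h, Ms, C, y, start=0):
    """Roll h forward from layer `start` and return (Phi, h_L)."""
    for M in Ms[start:]:
        h = matvec(M, h)
    residual = [p - t for p, t in zip(matvec(C, h), y)]
    return 0.5 * sum(r * r for r in residual), h


def exact_costates(h0, Ms, C, y):
    """Return [a_0, ..., a_L] from the backward recursion."""
    _, hL = terminal_loss(h0, Ms, C, y)
    residual = [p - t for p, t in zip(matvec(C, hL), y)]
    a = matvec(transpose(C), residual)      # a_L = C^T (C h_L - y)
    costates = [a]
    for M in reversed(Ms):
        a = matvec(transpose(M), a)         # a_l = M_l^T a_{l+1}
        costates.append(a)
    return costates[::-1]
```

Because the loss is quadratic and the dynamics are linear, the predicted change <a_l, eps * v> should match the true change in the terminal loss up to a term of order eps squared, which is exactly what the perturbation-correlation metric probes.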
diff --git a/experiments/__init__.py b/experiments/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/experiments/__init__.py
diff --git a/experiments/__pycache__/__init__.cpython-313.pyc b/experiments/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..5966841
--- /dev/null
+++ b/experiments/__pycache__/__init__.cpython-313.pyc
Binary files differ
diff --git a/experiments/__pycache__/toy_lq.cpython-313.pyc b/experiments/__pycache__/toy_lq.cpython-313.pyc
new file mode 100644
index 0000000..d8710a8
--- /dev/null
+++ b/experiments/__pycache__/toy_lq.cpython-313.pyc
Binary files differ
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py
new file mode 100644
index 0000000..1582f6d
--- /dev/null
+++ b/experiments/cifar_resmlp.py
@@ -0,0 +1,775 @@
+"""
+Phase B: Deep Residual MLP on CIFAR-10.
+Compare BP, DFA, State Bridge, Credit Bridge.
+
+CRITICAL CONSTRAINT: No hidden BP anchor for non-BP methods.
+All block updates use detached hidden states and local surrogates.
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+    cosine_similarity_batch, perturbation_correlation, nudging_test,
+    offline_bp_cosine, feature_drift
+)
+
+
+def get_data(dataset='cifar10', batch_size=128):
+    if dataset == 'cifar10':
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+        ])
+        transform_test = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+        ])
+        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+        input_dim = 32 * 32 * 3
+        num_classes = 10
+    elif dataset == 'fashionmnist':
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(28, padding=2),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.2860,), (0.3530,)),
+        ])
+        transform_test = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.2860,), (0.3530,)),
+        ])
+        trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train)
+        testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_test)
+        input_dim = 28 * 28
+        num_classes = 10
+    else:
+        raise ValueError(f"Unknown dataset: {dataset}")
+
+    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+    return train_loader, test_loader, input_dim, num_classes
+
+
+def evaluate(model, test_loader, device):
+    model.eval()
+    correct, total = 0, 0
+    with torch.no_grad():
+        for x, y in test_loader:
+            x = x.view(x.size(0), -1).to(device)
+            y = y.to(device)
+            logits = model(x)
+            correct += (logits.argmax(1) == y).sum().item()
+            total += x.size(0)
+    return correct / total
+
+
+# =============================================================================
+# BP Baseline
+# =============================================================================
+def train_bp(model, train_loader, test_loader, device, args):
+    """Standard end-to-end backprop training."""
+    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
+    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+
+    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+    for epoch in range(1, args.epochs + 1):
+        model.train()
+        total_loss, correct, total = 0, 0, 0
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device)
+            y = y.to(device)
+            logits = model(x)
+            loss = F.cross_entropy(logits, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            total_loss += loss.item() * x.size(0)
+            correct += (logits.argmax(1) == y).sum().item()
+            total += x.size(0)
+
+        scheduler.step()
+        train_loss = total_loss / total
+        train_acc = correct / total
+        test_acc = evaluate(model, test_loader, device)
+        log['train_loss'].append(train_loss)
+        log['train_acc'].append(train_acc)
+        log['test_acc'].append(test_acc)
+        if epoch % 10 == 0 or epoch == 1:
+            print(f" [BP] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
+
+    return log
+
+
+# =============================================================================
+# DFA Baseline
+# =============================================================================
+def train_dfa(model, train_loader, test_loader, device, args):
+    """
+    DFA training with fixed random feedback matrices.
+    Each block updated with local surrogate: L_l = <F_l(h_l), sg[a_{l+1}^DFA]>.
+    Output head updated with exact CE gradient (h_L detached).
+    Embedding updated via DFA credit at h_0.
+    """
+    d = model.d_hidden
+    num_classes = args.num_classes
+    L = model.num_blocks
+
+    # Fixed random feedback matrices, one per block
+    Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+
+    # Separate optimizers
+    block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+                  for block in model.blocks]
+    embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+    head_opt = optim.AdamW(
+        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+        lr=args.lr, weight_decay=args.wd
+    )
+
+    all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+                      + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+                         optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+    for epoch in range(1, args.epochs + 1):
+        model.train()
+        total_loss, correct, total = 0, 0, 0
+
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device)
+            y = y.to(device)
+            batch = x.size(0)
+
+            # Forward pass (no grad for hidden states)
+            with torch.no_grad():
+                logits, hiddens = model(x, return_hidden=True)
+                loss_val = F.cross_entropy(logits, y)
+                # e_T = softmax(logits) - one_hot(y)
+                e_T = logits.softmax(dim=-1)
+                e_T[torch.arange(batch), y] -= 1  # (batch, num_classes)
+
+            # 1. Update output head: exact CE gradient, h_L detached
+            hL_det = hiddens[-1].detach()
+            logits_out = model.out_head(model.out_ln(hL_det))
+            loss_out = F.cross_entropy(logits_out, y)
+            head_opt.zero_grad()
+            loss_out.backward()
+            head_opt.step()
+
+            # 2. Update each block with DFA local surrogate
+            for l in range(L):
+                h_l = hiddens[l].detach()
+                # DFA credit: a_{l+1} = B_l @ e_T^T -> (d, batch) -> transpose
+                a_dfa = (e_T @ Bs[l].T).detach()  # (batch, d) = (batch, C) @ (C, d)
+                # Normalize
+                rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+                a_dfa_norm = a_dfa / rms
+                # Local surrogate
+                f_l = model.blocks[l](h_l)
+                local_loss = (f_l * a_dfa_norm).sum(dim=-1).mean()
+                block_opts[l].zero_grad()
+                local_loss.backward()
+                block_opts[l].step()
+
+            # 3. Update embedding with DFA credit at h_0
+            a_0_dfa = (e_T @ Bs[0].T).detach()
+            rms_0 = (a_0_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+            a_0_norm = a_0_dfa / rms_0
+            h0 = model.embed(x)
+            embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+            embed_opt.zero_grad()
+            embed_loss.backward()
+            embed_opt.step()
+
+            total_loss += loss_val.item() * batch
+            correct += (logits.argmax(1) == y).sum().item()
+            total += batch
+
+        for s in all_schedulers:
+            s.step()
+
+        train_loss = total_loss / total
+        train_acc = correct / total
+        test_acc = evaluate(model, test_loader, device)
+        log['train_loss'].append(train_loss)
+        log['train_acc'].append(train_acc)
+        log['test_acc'].append(test_acc)
+        if epoch % 10 == 0 or epoch == 1:
+            print(f" [DFA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
+
+    return log, Bs
+
+
+# =============================================================================
+# State Bridge
+# =============================================================================
+def train_state_bridge(model, train_loader, test_loader, device, args):
+    """
+    State Bridge: predict terminal h_L from (h_l, t_l, s), derive credit as
+    a_l = grad_{h_l} CE(W_out * LN(G_psi(h_l, t_l, s)), y).
+    """
+    d = model.d_hidden
+    num_classes = args.num_classes
+    L = model.num_blocks
+
+    state_pred = StateBridgeNet(
+        d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3
+    ).to(device)
+
+    block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+                  for block in model.blocks]
+    embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+    head_opt = optim.AdamW(
+        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+        lr=args.lr, weight_decay=args.wd
+    )
+    state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb)
+
+    all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+                      + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+                         optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+    log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'state_pred_error': []}
+
+    for epoch in range(1, args.epochs + 1):
+        model.train()
+        state_pred.train()
+        total_loss, correct, total = 0, 0, 0
+        total_se = 0
+
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device)
+            y = y.to(device)
+            batch = x.size(0)
+
+            with torch.no_grad():
+                logits, hiddens = model(x, return_hidden=True)
+                loss_val = F.cross_entropy(logits, y)
+                e_T = logits.softmax(dim=-1)
+                e_T[torch.arange(batch), y] -= 1
+                s = e_T.detach()
+
+            hL_det = hiddens[-1].detach()
+
+            # Train state predictor: G_psi(h_l, t_l, s) -> h_L
+            # Predict the *residual* from h_l to h_L for numerical stability
+            state_loss = 0.0
+            for l in range(L):
+                h_l_det = hiddens[l].detach()
+                t_l = torch.full((batch,), l / L, device=device)
+                pred_hL = state_pred(h_l_det, t_l, s)
+                # Target: h_L (use normalized MSE for stability)
+                target = hL_det
+                target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0)
+                state_loss = state_loss + (((pred_hL - target) / target_norm) ** 2).sum(dim=-1).mean()
+            state_loss = state_loss / L
+            state_opt.zero_grad()
+            state_loss.backward()
+            state_opt.step()
+            total_se += state_loss.item() * batch
+
+            # Compute credits: a_l = grad_{h_l} CE(out_head(LN(G(h_l, t_l, s))), y)
+            credits = []
+            for l in range(L):
+                h_l_det = hiddens[l].detach().requires_grad_(True)
+                t_l = torch.full((batch,), l / L, device=device)
+                pred_hL = state_pred(h_l_det, t_l, s)
+                pred_logits = model.out_head(model.out_ln(pred_hL))
+                pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+                a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0]
+                credits.append(a_l.detach())
+
+            # Update output head
+            logits_out = model.out_head(model.out_ln(hL_det))
+            loss_out = F.cross_entropy(logits_out, y)
+            head_opt.zero_grad()
+            loss_out.backward()
+            head_opt.step()
+
+            # Update blocks
+            for l in range(L):
+                h_l = hiddens[l].detach()
+                a = credits[l]
+                rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+                a_norm = a / rms
+                f_l = model.blocks[l](h_l)
+                local_loss = (f_l * a_norm).sum(dim=-1).mean()
+                block_opts[l].zero_grad()
+                local_loss.backward()
+                block_opts[l].step()
+
+            # Update embedding with credit at layer 0
+            a_0 = credits[0]
+            rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+            a_0_norm = a_0 / rms_0
+            h0 = model.embed(x)
+            embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+            embed_opt.zero_grad()
+            embed_loss.backward()
+            embed_opt.step()
+
+            total_loss += loss_val.item() * batch
+            correct += (logits.argmax(1) == y).sum().item()
+            total += batch
+
+        for sch in all_schedulers:
+            sch.step()
+
+        train_loss = total_loss / total
+        train_acc = correct / total
+        test_acc = evaluate(model, test_loader, device)
+        se = total_se / total
+        log['train_loss'].append(train_loss)
+        log['train_acc'].append(train_acc)
+        log['test_acc'].append(test_acc)
+        log['state_pred_error'].append(se)
+        if epoch % 10 == 0 or epoch == 1:
+            print(f" [SB] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, "
+                  f"test={test_acc:.4f}, state_err={se:.4f}")
+
+    return log, state_pred
+
+
+# =============================================================================
+# Credit Bridge
+# =============================================================================
+def train_credit_bridge(model, train_loader, test_loader, device, args):
+ """
+ Credit Bridge: learn V_phi(h_l, t_l, s) -> scalar value.
+ Credit: a_l = grad_{h_l} V_phi.
+ Training: terminal boundary + bridge consistency + terminal gradient matching.
+ The terminal gradient is local (output layer only), NOT hidden BP.
+
+ Uses a warmup phase: for the first warmup_epochs, the value net and output head
+ train normally while blocks and the embedding get DFA-style fallback updates.
+ After warmup, credits blend linearly from DFA to credit bridge signals.
+ """
+ d = model.d_hidden
+ num_classes = args.num_classes
+ L = model.num_blocks
+ warmup_epochs = max(1, args.epochs // 5) # 20% warmup
+
+ value_net = ValueNet(
+ d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3
+ ).to(device)
+ value_net_ema = create_ema_model(value_net)
+
+ # DFA fallback matrices for warmup
+ Bs_fallback = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes)
+ for _ in range(L)]
+
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+ value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ lam = args.lam
+ K_samples = args.K
+ sigma_bridge = args.sigma_bridge
+ ema_momentum = args.ema_momentum
+ term_grad_weight = args.term_grad_weight
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []}
+
+ print(f" [CB] Warmup phase: {warmup_epochs} epochs (DFA fallback + value net training)")
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ value_net.train()
+ total_loss, correct, total = 0, 0, 0
+ total_vloss = 0
+
+ # Blend factor: 0 during warmup, then ramps linearly to 1 over the next warmup_epochs epochs
+ if epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # ---- Train value net (always) ----
+ t_L = torch.ones(batch, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Terminal gradient matching
+ loss_tgrad = torch.tensor(0.0, device=device)
+ if term_grad_weight > 0:
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ hL_req2 = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req2))
+ ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_l_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K_samples):
+ noise = sigma_bridge * torch.randn_like(h_next_det)
+ V_next = value_net_ema(h_next_det + noise, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples))  # soft-min: -lam * log-mean-exp over K noisy samples
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
+
+ value_opt.zero_grad()
+ value_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+ total_vloss += value_loss.item() * batch
+
+ # ---- Compute credits ----
+ # Credit bridge credits
+ cb_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
+ cb_credits.append(a_l.detach())
+
+ # DFA fallback credits
+ dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]
+
+ # Blend credits
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ a = cb_credits[l]
+ elif credit_blend <= 0.0:
+ a = dfa_credits[l]
+ else:
+ # Normalize both before blending
+ cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms)
+ credits.append(a)
+
+ # ---- Update output head ----
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # ---- Update blocks ----
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step()
+
+ # ---- Update embedding ----
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_0 / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in all_schedulers:
+ sch.step()
+
+ train_loss = total_loss / total
+ train_acc = correct / total
+ test_acc = evaluate(model, test_loader, device)
+ vloss = total_vloss / total
+ log['train_loss'].append(train_loss)
+ log['train_acc'].append(train_acc)
+ log['test_acc'].append(test_acc)
+ log['value_loss'].append(vloss)
+ if epoch % 10 == 0 or epoch == 1:
+ phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
+ print(f" [CB] Epoch {epoch} ({phase}): loss={train_loss:.4f}, train={train_acc:.4f}, "
+ f"test={test_acc:.4f}, vloss={vloss:.6f}")
+
+ return log, value_net, value_net_ema
+
+
+# =============================================================================
+# Diagnostics
+# =============================================================================
+def compute_diagnostics(model, method_name, test_loader, device, args,
+ value_net=None, state_predictor=None, dfa_Bs=None):
+ """Compute all diagnostic metrics for a trained model."""
+ model.eval()
+ if value_net is not None:
+ value_net.eval()
+ if state_predictor is not None:
+ state_predictor.eval()
+
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+
+ # Get one batch for diagnostics
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ break
+
+ batch = x.size(0)
+
+ # Forward with hidden states, need grad for BP cosine
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y)
+ loss_bp.backward()
+ bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
+
+ # Forward again without grad for clean hidden states
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ results = {
+ 'bp_cosine': [],
+ 'perturbation_rho': [],
+ 'nudging': {'0.001': [], '0.003': [], '0.01': []},
+ }
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ # Get credit
+ if method_name == 'bp':
+ a_l = bp_grads[l]
+ elif method_name == 'dfa':
+ a_l = (e_T @ dfa_Bs[l].T).detach()
+ elif method_name == 'state_bridge':
+ h_l_req = h_l.clone().requires_grad_(True)
+ pred_hL = state_predictor(h_l_req, t_l, s)
+ pred_logits = model.out_head(model.out_ln(pred_hL))
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach()
+ elif method_name == 'credit_bridge':
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+ else:
+ raise ValueError(f"Unknown method: {method_name}")
+
+ # BP cosine
+ bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
+ results['bp_cosine'].append(bp_cos)
+
+ # Forward function for perturbation and nudging
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ curr = h
+ for i in range(start_l, L):
+ curr = curr + model.blocks[i](curr)
+ out = model.out_head(model.out_ln(curr))
+ return F.cross_entropy(out, y, reduction='none')
+ return fwd_fn
+
+ fwd_fn = make_fwd_fn(l)
+ rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16)
+ results['perturbation_rho'].append(rho)
+
+ for eta in [0.001, 0.003, 0.01]:
+ nud = nudging_test(h_l, a_l, fwd_fn, eta=eta)
+ results['nudging'][str(eta)].append(nud)
+
+ return results
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ all_results = {}
+
+ for seed in args.seeds:
+ print(f"\n{'='*60}")
+ print(f"Seed {seed}")
+ print(f"{'='*60}")
+
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ train_loader, test_loader, input_dim, num_classes = get_data(args.dataset, args.batch_size)
+ args.num_classes = num_classes
+
+ seed_results = {}
+
+ # ---- BP ----
+ print("\n--- BP ---")
+ model_bp = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_bp = {n: p.clone().detach() for n, p in model_bp.named_parameters()}
+ bp_log = train_bp(model_bp, train_loader, test_loader, device, args)
+ bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args)
+ bp_drift = feature_drift(init_bp, {n: p.detach() for n, p in model_bp.named_parameters()})
+ seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag, 'drift': bp_drift}
+ print(f" Final test acc: {bp_log['test_acc'][-1]:.4f}")
+
+ # ---- DFA ----
+ print("\n--- DFA ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_dfa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_dfa = {n: p.clone().detach() for n, p in model_dfa.named_parameters()}
+ dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args)
+ dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs)
+ dfa_drift = feature_drift(init_dfa, {n: p.detach() for n, p in model_dfa.named_parameters()})
+ seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift}
+ print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}")
+
+ # ---- State Bridge ----
+ print("\n--- State Bridge ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_sb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_sb = {n: p.clone().detach() for n, p in model_sb.named_parameters()}
+ sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args)
+ sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args,
+ state_predictor=state_pred)
+ sb_drift = feature_drift(init_sb, {n: p.detach() for n, p in model_sb.named_parameters()})
+ seed_results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag, 'drift': sb_drift}
+ print(f" Final test acc: {sb_log['test_acc'][-1]:.4f}")
+
+ # ---- Credit Bridge ----
+ print("\n--- Credit Bridge ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_cb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_cb = {n: p.clone().detach() for n, p in model_cb.named_parameters()}
+ cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args)
+ cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args,
+ value_net=vnet)
+ cb_drift = feature_drift(init_cb, {n: p.detach() for n, p in model_cb.named_parameters()})
+ seed_results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag, 'drift': cb_drift}
+ print(f" Final test acc: {cb_log['test_acc'][-1]:.4f}")
+
+ all_results[seed] = seed_results
+
+ # Save
+ def serialize(obj):
+ if isinstance(obj, dict):
+ return {str(k): serialize(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [serialize(v) for v in obj]
+ elif isinstance(obj, (np.floating, np.integer)):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, torch.Tensor):
+ return obj.cpu().numpy().tolist()
+ return obj
+
+ save_data = serialize(all_results)
+ save_data['config'] = serialize(vars(args))
+ out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print(f"\nAll results saved to {out_path}")
+ return all_results
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dataset', type=str, default='cifar10')
+ parser.add_argument('--d_hidden', type=int, default=512)
+ parser.add_argument('--num_blocks', type=int, default=12)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--sigma_bridge', type=float, default=0.05)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--term_grad_weight', type=float, default=1.0)
+ parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/cifar10')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()
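Note on the bridge-consistency target in `train_credit_bridge` above: `V_target` is a log-mean-exp (soft-min) over `K` noisy evaluations of the EMA value net. The standalone sketch below reproduces only that arithmetic for a single example, in plain Python with made-up values. `soft_min_target` is a hypothetical helper name used for illustration; it is not part of the patch.

```python
import math


def soft_min_target(v_next, lam):
    """Return -lam * log(mean(exp(-v / lam))) for a list of sampled values.

    Same formula as V_target in the patch, for one example. The max-shift
    keeps the exponentials from overflowing when lam is small.
    """
    if not v_next:
        raise ValueError("need at least one sample")
    logs = [-v / lam for v in v_next]
    m = max(logs)
    lse = m + math.log(sum(math.exp(x - m) for x in logs))
    return -lam * (lse - math.log(len(v_next)))


samples = [1.0, 1.2, 0.8, 1.1]  # hypothetical V_next values, K = 4
for lam in (0.1, 1.0, 100.0):
    print(lam, soft_min_target(samples, lam))
```

The result always lies between the minimum and the mean of the samples: a small `lam` pulls the target toward the minimum, a large `lam` toward the plain average.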
diff --git a/experiments/plot_results.py b/experiments/plot_results.py
new file mode 100644
index 0000000..e3e2754
--- /dev/null
+++ b/experiments/plot_results.py
@@ -0,0 +1,327 @@
+"""Generate plots for toy LQ and CIFAR-10 experiments."""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+
+def plot_toy_results(results_dir='results/toy_lq', output_dir='report'):
+ """Plot toy LQ experiment results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Collect results across seeds
+ files = [f for f in os.listdir(results_dir) if f.startswith('toy_lq_seed') and f.endswith('.json')]
+ if not files:
+ print(f"No toy results found in {results_dir}")
+ return
+
+ all_data = []
+ for f in sorted(files):
+ with open(os.path.join(results_dir, f)) as fp:
+ all_data.append(json.load(fp))
+
+ # Use the last result (by sorted filename) for per-layer plots; seeds are not averaged here
+ data = all_data[-1]
+ per_layer = data['final_per_layer']
+ log_data = data['log']
+
+ num_layers = len(per_layer['dfa_costate_cos'])
+ layers = list(range(num_layers))
+
+ # 1. Per-layer costate cosine
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(layers, per_layer['dfa_costate_cos'], 'o-', label='DFA', color='blue')
+ ax.plot(layers, per_layer['state_costate_cos'], 's-', label='State Bridge', color='orange')
+ ax.plot(layers, per_layer['credit_costate_cos'], '^-', label='Credit Bridge', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Cosine Similarity with Exact Costate')
+ ax.set_title('Exact Costate Cosine (Toy LQ)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(-0.2, 1.05)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_costate_cosine.png'), dpi=150)
+ plt.close(fig)
+
+ # 2. Per-layer perturbation correlation
+ num_rho_layers = len(per_layer['dfa_rho'])
+ rho_layers = list(range(num_rho_layers))
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(rho_layers, per_layer['dfa_rho'], 'o-', label='DFA', color='blue')
+ ax.plot(rho_layers, per_layer['state_rho'], 's-', label='State Bridge', color='orange')
+ ax.plot(rho_layers, per_layer['credit_rho'], '^-', label='Credit Bridge', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Perturbation Correlation (rho)')
+ ax.set_title('Local Perturbation Correlation (Toy LQ)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_perturbation_rho.png'), dpi=150)
+ plt.close(fig)
+
+ # 3. Per-layer nudging test
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(rho_layers, per_layer['dfa_nudge'], 'o-', label='DFA', color='blue')
+ ax.plot(rho_layers, per_layer['state_nudge'], 's-', label='State Bridge', color='orange')
+ ax.plot(rho_layers, per_layer['credit_nudge'], '^-', label='Credit Bridge', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Nudge Delta (negative = good)')
+ ax.set_title('Nudging Test (Toy LQ)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_nudging.png'), dpi=150)
+ plt.close(fig)
+
+ # 4. Bridge residual over training
+ if log_data.get('bridge_residual'):
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(log_data['steps'], log_data['bridge_residual'], '-', color='green')
+ ax.set_xlabel('Training Step')
+ ax.set_ylabel('Bridge Residual')
+ ax.set_title('Bridge Residual Over Training (Toy LQ)')
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_bridge_residual.png'), dpi=150)
+ plt.close(fig)
+
+ # 5. Training curves (costate cosine over time)
+ fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+ for ax, key, title in zip(axes,
+ ['dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos'],
+ ['DFA', 'State Bridge', 'Credit Bridge']):
+ ax.plot(log_data['steps'], log_data[key], '-')
+ ax.set_xlabel('Training Step')
+ ax.set_ylabel('Avg Costate Cosine')
+ ax.set_title(f'{title} - Costate Cosine Over Training')
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_cosine_training.png'), dpi=150)
+ plt.close(fig)
+
+ # 6. Per-layer bridge residual
+ if per_layer.get('bridge_residual'):
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ br_layers = list(range(len(per_layer['bridge_residual'])))
+ ax.plot(br_layers, per_layer['bridge_residual'], '^-', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Bridge Residual')
+ ax.set_title('Per-Layer Bridge Residual (Toy LQ)')
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_bridge_residual_per_layer.png'), dpi=150)
+ plt.close(fig)
+
+ print(f"Toy LQ plots saved to {output_dir}/")
+
+
+def plot_cifar_results(results_path='results/cifar10/results_cifar10.json', output_dir='report'):
+ """Plot CIFAR-10 experiment results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ if not os.path.exists(results_path):
+ print(f"No CIFAR results found at {results_path}")
+ return
+
+ with open(results_path) as f:
+ data = json.load(f)
+
+ config = data.pop('config', {})
+ methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+ colors = {'bp': 'red', 'dfa': 'blue', 'state_bridge': 'orange', 'credit_bridge': 'green'}
+ labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}
+
+ seeds = [k for k in data.keys() if k != 'config']
+
+ # 1. Accuracy curves (mean ± std across seeds)
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+ for method in methods:
+ train_accs = []
+ test_accs = []
+ for seed in seeds:
+ if method in data[seed]:
+ log = data[seed][method]['log']
+ train_accs.append(log['train_acc'])
+ test_accs.append(log['test_acc'])
+
+ if train_accs:
+ train_arr = np.array(train_accs)
+ test_arr = np.array(test_accs)
+ epochs = np.arange(1, train_arr.shape[1] + 1)
+
+ mean_train = train_arr.mean(0)
+ std_train = train_arr.std(0)
+ mean_test = test_arr.mean(0)
+ std_test = test_arr.std(0)
+
+ axes[0].plot(epochs, mean_train, '-', color=colors[method], label=labels[method])
+ axes[0].fill_between(epochs, mean_train - std_train, mean_train + std_train,
+ alpha=0.15, color=colors[method])
+ axes[1].plot(epochs, mean_test, '-', color=colors[method], label=labels[method])
+ axes[1].fill_between(epochs, mean_test - std_test, mean_test + std_test,
+ alpha=0.15, color=colors[method])
+
+ axes[0].set_xlabel('Epoch')
+ axes[0].set_ylabel('Train Accuracy')
+ axes[0].set_title('Train Accuracy')
+ axes[0].legend()
+ axes[0].grid(True, alpha=0.3)
+ axes[1].set_xlabel('Epoch')
+ axes[1].set_ylabel('Test Accuracy')
+ axes[1].set_title('Test Accuracy')
+ axes[1].legend()
+ axes[1].grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_accuracy.png'), dpi=150)
+ plt.close(fig)
+
+ # 2. Per-layer diagnostics (from last seed)
+ last_seed = seeds[-1]
+
+ # BP cosine per layer
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'diagnostics' in data[last_seed][method]:
+ diag = data[last_seed][method]['diagnostics']
+ if 'bp_cosine' in diag:
+ layers = list(range(len(diag['bp_cosine'])))
+ ax.plot(layers, diag['bp_cosine'], 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Cosine with BP Gradient')
+ ax.set_title('Offline BP Cosine (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_bp_cosine.png'), dpi=150)
+ plt.close(fig)
+
+ # Perturbation rho per layer
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'diagnostics' in data[last_seed][method]:
+ diag = data[last_seed][method]['diagnostics']
+ if 'perturbation_rho' in diag:
+ layers = list(range(len(diag['perturbation_rho'])))
+ ax.plot(layers, diag['perturbation_rho'], 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Perturbation Correlation (rho)')
+ ax.set_title('Local Perturbation Correlation (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_perturbation_rho.png'), dpi=150)
+ plt.close(fig)
+
+ # Nudging test per layer (eta=0.01)
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'diagnostics' in data[last_seed][method]:
+ diag = data[last_seed][method]['diagnostics']
+ if 'nudging' in diag and '0.01' in diag['nudging']:
+ nud = diag['nudging']['0.01']
+ layers = list(range(len(nud)))
+ ax.plot(layers, nud, 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Nudge Delta (negative = good)')
+ ax.set_title('Nudging Test eta=0.01 (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_nudging.png'), dpi=150)
+ plt.close(fig)
+
+ # Feature drift per layer
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'drift' in data[last_seed][method]:
+ drift = data[last_seed][method]['drift']
+ # Extract per-block drift (only block weights)
+ block_drifts = []
+ for l in range(config.get('num_blocks', 12)):
+ key = f'blocks.{l}.w1.weight'
+ if key in drift:
+ block_drifts.append(drift[key])
+ if block_drifts:
+ ax.plot(range(len(block_drifts)), block_drifts, 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Block')
+ ax.set_ylabel('Feature Drift (||W_final - W_init||/||W_init||)')
+ ax.set_title('Feature Drift (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_feature_drift.png'), dpi=150)
+ plt.close(fig)
+
+ print(f"CIFAR-10 plots saved to {output_dir}/")
+
+
+def print_summary_table(results_path='results/cifar10/results_cifar10.json'):
+ """Print summary table of results."""
+ if not os.path.exists(results_path):
+ print(f"No results at {results_path}")
+ return
+
+ with open(results_path) as f:
+ data = json.load(f)
+
+ config = data.pop('config', {})
+ methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+ labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}
+
+ seeds = [k for k in data.keys() if k != 'config']
+
+ print("\n" + "="*80)
+ print("SUMMARY TABLE")
+ print("="*80)
+ print(f"{'Method':<20} {'Test Acc':<15} {'Avg rho':<15} {'Avg Nudge(0.01)':<15} {'Avg BP Cos':<15}")
+ print("-"*80)
+
+ for method in methods:
+ test_accs = []
+ avg_rhos = []
+ avg_nudges = []
+ avg_bp_cos = []
+
+ for seed in seeds:
+ if method in data[seed]:
+ log = data[seed][method]['log']
+ test_accs.append(log['test_acc'][-1])
+
+ if 'diagnostics' in data[seed][method]:
+ diag = data[seed][method]['diagnostics']
+ if 'perturbation_rho' in diag:
+ avg_rhos.append(np.mean(diag['perturbation_rho']))
+ if 'nudging' in diag and '0.01' in diag['nudging']:
+ avg_nudges.append(np.mean(diag['nudging']['0.01']))
+ if 'bp_cosine' in diag:
+ avg_bp_cos.append(np.mean(diag['bp_cosine']))
+
+ ta = f"{np.mean(test_accs):.4f}±{np.std(test_accs):.4f}" if test_accs else "N/A"
+ rho = f"{np.mean(avg_rhos):.4f}" if avg_rhos else "N/A"
+ nud = f"{np.mean(avg_nudges):.4f}" if avg_nudges else "N/A"
+ bpc = f"{np.mean(avg_bp_cos):.4f}" if avg_bp_cos else "N/A"
+
+ print(f"{labels[method]:<20} {ta:<15} {rho:<15} {nud:<15} {bpc:<15}")
+
+ print("="*80)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--toy_dir', type=str, default='results/toy_lq')
+ parser.add_argument('--cifar_path', type=str, default='results/cifar10/results_cifar10.json')
+ parser.add_argument('--output_dir', type=str, default='report')
+ args = parser.parse_args()
+
+ plot_toy_results(args.toy_dir, args.output_dir)
+ plot_cifar_results(args.cifar_path, args.output_dir)
+ print_summary_table(args.cifar_path)
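Note on the accuracy curves in `plot_cifar_results` above: the per-seed curves are stacked with `np.array`, which assumes every seed logged the same number of epochs. A minimal sketch of a more defensive aggregation follows, in plain Python. `mean_std_curves` is a hypothetical helper, and the values are made up. It truncates to the shortest curve and uses the population standard deviation, which matches the default of `np.std`.

```python
from statistics import fmean, pstdev


def mean_std_curves(curves):
    """Per-epoch mean and population std across seeds.

    Curves of unequal length are truncated to the shortest one.
    """
    if not curves:
        return [], []
    n = min(len(c) for c in curves)
    cols = [[c[i] for c in curves] for i in range(n)]
    return [fmean(col) for col in cols], [pstdev(col) for col in cols]


curves = [[0.50, 0.60, 0.70], [0.52, 0.64]]  # hypothetical test_acc per seed
mean, std = mean_std_curves(curves)
print(mean, std)
```

Truncation silently drops the tail of longer runs, so it may be worth logging when curve lengths differ.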
diff --git a/experiments/plot_toy_final.py b/experiments/plot_toy_final.py
new file mode 100644
index 0000000..2f7c109
--- /dev/null
+++ b/experiments/plot_toy_final.py
@@ -0,0 +1,183 @@
+"""Generate final toy LQ experiment plots from v2 results across 3 seeds."""
+import os
+import json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+output_dir = 'report'
+os.makedirs(output_dir, exist_ok=True)
+
+# Load all v2 results with term_grad_weight=1.0, fm=0.0
+seeds = [42, 123, 456]
+all_data = []
+for seed in seeds:
+ path = f'results/toy_lq/toy_lq_v2_seed{seed}_lam0.1_sig0.1_tgw1.0_fm0.0.json'
+ if os.path.exists(path):
+ with open(path) as f:
+ all_data.append(json.load(f))
+
+if not all_data:
+ print("No results found!")
+ raise SystemExit(1)
+
+# Also load v1 baseline (no term_grad) for comparison
+v1_path = 'results/toy_lq/toy_lq_seed42.json'
+v1_data = None
+if os.path.exists(v1_path):
+ with open(v1_path) as f:
+ v1_data = json.load(f)
+
+# Aggregate final per-layer results across seeds
+methods = ['dfa', 'state', 'credit']
+colors = {'dfa': '#2196F3', 'state': '#FF9800', 'credit': '#4CAF50'}
+labels = {'dfa': 'DFA', 'state': 'State Bridge', 'credit': 'Credit Bridge'}
+
+# Per-layer costate cosine
+fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+
+for ax, metric, title, ylabel in zip(
+ axes,
+ ['costate_cos', 'rho', 'nudge'],
+ ['Exact Costate Cosine', 'Perturbation Correlation (ρ)', 'Nudging Test'],
+ ['Cosine Similarity', 'Pearson Correlation', 'Loss Change (negative=good)']
+):
+ for method in methods:
+ key = f'{method}_{metric}'
+ values_per_seed = []
+ for data in all_data:
+ pl = data['final_per_layer']
+ if key in pl:
+ values_per_seed.append(pl[key])
+
+ if values_per_seed:
+ arr = np.array(values_per_seed)
+ mean = arr.mean(axis=0)
+ std = arr.std(axis=0)
+ layers = np.arange(len(mean))
+ ax.plot(layers, mean, 'o-', color=colors[method], label=labels[method], markersize=5)
+ ax.fill_between(layers, mean - std, mean + std, alpha=0.15, color=colors[method])
+
+ ax.set_xlabel('Layer', fontsize=12)
+ ax.set_ylabel(ylabel, fontsize=12)
+ ax.set_title(title, fontsize=13)
+ ax.legend(fontsize=11)
+ ax.grid(True, alpha=0.3)
+ if metric == 'costate_cos':
+ ax.set_ylim(-0.15, 1.05)
+ elif metric == 'rho':
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ elif metric == 'nudge':
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+fig.suptitle(f'Toy LQ Sanity Check: Per-Layer Diagnostics ({len(all_data)} seeds)', fontsize=14, y=1.02)
+fig.tight_layout()
+fig.savefig(os.path.join(output_dir, 'toy_per_layer_diagnostics.png'), dpi=150, bbox_inches='tight')
+plt.close(fig)
+print("Saved toy_per_layer_diagnostics.png")
+
+# Training curves
+fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+metric_keys = [
+ ('costate_cos', 'Avg Costate Cosine', 'Cosine Similarity'),
+ ('rho', 'Avg Perturbation ρ', 'Pearson Correlation'),
+ ('nudge', 'Avg Nudging', 'Loss Change'),
+]
+
+for ax, (metric, title, ylabel) in zip(axes, metric_keys):
+ for method in methods:
+ key = f'{method}_{metric}'
+ all_curves = []
+ for data in all_data:
+ log = data['log']
+ full_key = key  # log keys use the same '<method>_<metric>' naming
+ if full_key in log:
+ all_curves.append(np.array(log[full_key]))
+
+ if all_curves:
+ # All should have same length, use shortest
+ min_len = min(len(c) for c in all_curves)
+ arr = np.array([c[:min_len] for c in all_curves])
+ steps = np.array(all_data[0]['log']['steps'][:min_len])
+ mean = arr.mean(axis=0)
+ std = arr.std(axis=0)
+ ax.plot(steps, mean, '-', color=colors[method], label=labels[method])
+ ax.fill_between(steps, mean - std, mean + std, alpha=0.15, color=colors[method])
+
+ ax.set_xlabel('Training Step', fontsize=12)
+ ax.set_ylabel(ylabel, fontsize=12)
+ ax.set_title(title, fontsize=13)
+ ax.legend(fontsize=11)
+ ax.grid(True, alpha=0.3)
+
+fig.suptitle(f'Toy LQ: Training Curves ({len(all_data)} seeds)', fontsize=14, y=1.02)
+fig.tight_layout()
+fig.savefig(os.path.join(output_dir, 'toy_training_curves.png'), dpi=150, bbox_inches='tight')
+plt.close(fig)
+print("Saved toy_training_curves.png")
+
+# Compare v1 (no term grad) vs v2 (with term grad) for credit bridge
+if v1_data:
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+
+ # v1 credit bridge (no term grad matching)
+ v1_log = v1_data['log']
+ ax.plot(v1_log['steps'], v1_log['credit_costate_cos'],
+ '--', color='red', label='Credit Bridge (w/o terminal grad)', alpha=0.8)
+
+ # v2 credit bridge (with term grad)
+ v2_log = all_data[0]['log'] # first loaded seed (42 if its results file exists)
+ ax.plot(v2_log['steps'], v2_log['credit_costate_cos'],
+ '-', color='green', label='Credit Bridge (w/ terminal grad)')
+
+ # State bridge for reference
+ ax.plot(v2_log['steps'], v2_log['state_costate_cos'],
+ '-', color='orange', label='State Bridge')
+
+ ax.set_xlabel('Training Step', fontsize=12)
+ ax.set_ylabel('Avg Costate Cosine', fontsize=12)
+ ax.set_title('Effect of Terminal Gradient Matching', fontsize=13)
+ ax.legend(fontsize=11)
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(-0.1, 1.05)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_term_grad_effect.png'), dpi=150)
+ plt.close(fig)
+ print("Saved toy_term_grad_effect.png")
+
+# Bridge residual (from v1 which has it)
+if v1_data and v1_data['log'].get('bridge_residual'):
+    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+    ax.plot(v1_data['log']['steps'], v1_data['log']['bridge_residual'], '-', color='green')
+    ax.set_xlabel('Training Step', fontsize=12)
+    ax.set_ylabel('Bridge Residual', fontsize=12)
+    ax.set_title('Credit Bridge: Bridge Residual Over Training', fontsize=13)
+    ax.grid(True, alpha=0.3)
+    fig.tight_layout()
+    fig.savefig(os.path.join(output_dir, 'toy_bridge_residual.png'), dpi=150)
+    plt.close(fig)
+    print("Saved toy_bridge_residual.png")
+
+# Print summary table
+print("\n" + "="*80)
+print("TOY LQ FINAL RESULTS (3 seeds, 8000 steps)")
+print("="*80)
+
+for method in methods:
+    cos_vals = []
+    rho_vals = []
+    nudge_vals = []
+    for data in all_data:
+        pl = data['final_per_layer']
+        cos_vals.append(np.mean(pl[f'{method}_costate_cos']))
+        rho_vals.append(np.mean(pl[f'{method}_rho']))
+        nudge_vals.append(np.mean(pl[f'{method}_nudge']))
+
+    cos_mean, cos_std = np.mean(cos_vals), np.std(cos_vals)
+    rho_mean, rho_std = np.mean(rho_vals), np.std(rho_vals)
+    nudge_mean, nudge_std = np.mean(nudge_vals), np.std(nudge_vals)
+
+    print(f"{labels[method]:<20} Cosine: {cos_mean:.4f}±{cos_std:.4f} "
+          f"ρ: {rho_mean:.4f}±{rho_std:.4f} "
+          f"Nudge: {nudge_mean:.4f}±{nudge_std:.4f}")
diff --git a/experiments/toy_lq.py b/experiments/toy_lq.py
new file mode 100644
index 0000000..4fd8919
--- /dev/null
+++ b/experiments/toy_lq.py
@@ -0,0 +1,395 @@
+"""
+Phase A: Linear-Quadratic Residual Sanity Check.
+
+Fixed forward dynamics (no forward net training).
+Only train feedback/bridge models.
+Compare DFA, State Bridge, Credit Bridge against exact costate.
+
+System:
+    h_{l+1} = M_l h_l + sigma * xi_l,  xi_l ~ N(0, I)
+    Phi(h_L, y) = 0.5 * ||C h_L - y||^2
+    Exact costate: a_L = C^T (C h_L - y),  a_l = M_l^T a_{l+1}
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from datetime import datetime
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test, bridge_residual
+)
+
+
+def generate_stable_dynamics(d, L, spectral_max=0.05, seed=42):
+    """Generate stable linear maps M_l = I + A_l with ||A_l||_2 <= spectral_max."""
+    rng = np.random.RandomState(seed)
+    Ms = []
+    for _ in range(L):
+        A = rng.randn(d, d).astype(np.float32)
+        # Scale to desired spectral norm (only the singular values are needed)
+        s = np.linalg.svd(A, compute_uv=False)
+        A = A * (spectral_max / s[0])
+        M = np.eye(d, dtype=np.float32) + A
+        Ms.append(torch.from_numpy(M))
+    return Ms  # list of (d, d)
+
+
+def rollout_forward(h0, Ms, sigma, L, device):
+    """Roll out forward dynamics: h_{l+1} = M_l h_l + sigma * xi_l."""
+    batch = h0.shape[0]
+    d = h0.shape[1]
+    hiddens = [h0]
+    h = h0
+    for l in range(L):
+        M = Ms[l].to(device)
+        noise = sigma * torch.randn(batch, d, device=device)
+        h = h @ M.T + noise
+        hiddens.append(h)
+    return hiddens  # [h_0, ..., h_L]
+
+
+def terminal_loss(hL, C, y):
+    """Phi(hL, y) = 0.5 * ||C hL - y||^2, returns per-sample loss."""
+    diff = hL @ C.T - y  # (batch, m)
+    return 0.5 * (diff ** 2).sum(dim=-1)  # (batch,)
+
+
+def exact_costate(hiddens, Ms, C, y, device):
+    """Compute exact costate a_l for all layers."""
+    L = len(hiddens) - 1
+    hL = hiddens[L]
+    # Terminal: a_L = C^T (C h_L - y)
+    diff = hL @ C.T - y  # (batch, m)
+    a_L = diff @ C  # (batch, d)
+
+    costates = [None] * (L + 1)
+    costates[L] = a_L
+    for l in range(L - 1, -1, -1):
+        M = Ms[l].to(device)
+        costates[l] = costates[l + 1] @ M  # a_l = M_l^T a_{l+1} -> a_{l+1} @ M
+    return costates
+
+
+def make_forward_fn_from_layer(hiddens, Ms, C, y, sigma, start_layer, device):
+    """Create a function that rolls forward from layer start_layer and returns per-sample loss."""
+    L = len(Ms)
+
+    def forward_fn(h):
+        current = h
+        for l in range(start_layer, L):
+            M = Ms[l].to(device)
+            # No noise for perturbation test (deterministic rollout)
+            current = current @ M.T
+        return terminal_loss(current, C, y)
+
+    return forward_fn
+
+
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+
+ # Hyperparams
+ d = args.d_hidden # 64
+ m = args.output_dim # 10
+ L = args.num_layers # 12
+ sigma = args.sigma # 0.03
+ batch_size = args.batch_size # 256
+ num_steps = args.num_steps # 5000
+ lr_fb = args.lr_fb # 1e-3
+ lam = args.lam # 0.1
+ K = args.K # 8
+ ema_momentum = args.ema_momentum # 0.995
+ sigma_bridge = args.sigma_bridge # 0.03
+
+ print(f"=== Toy LQ Experiment ===")
+ print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}")
+ print(f"device={device}")
+
+ # Generate fixed dynamics
+ Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed)
+ C = torch.randn(m, d, device=device) / np.sqrt(d)
+
+ # DFA random feedback matrices
+ Bs_dfa = []
+ for l in range(L + 1):
+ B = torch.randn(d, m, device=device) / np.sqrt(m)
+ Bs_dfa.append(B)
+
+ # State Bridge model
+ state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=128, num_layers=2).to(device)
+ opt_state = optim.Adam(state_bridge.parameters(), lr=lr_fb)
+
+ # Credit Bridge value net
+ value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=128, num_layers=2).to(device)
+ value_net_ema = create_ema_model(value_net)
+ opt_value = optim.Adam(value_net.parameters(), lr=lr_fb)
+
+ # Training logs
+ log = {
+ 'steps': [],
+ 'state_bridge_loss': [],
+ 'credit_bridge_loss': [],
+ 'dfa_costate_cos': [],
+ 'state_costate_cos': [],
+ 'credit_costate_cos': [],
+ 'dfa_rho': [],
+ 'state_rho': [],
+ 'credit_rho': [],
+ 'dfa_nudge': [],
+ 'state_nudge': [],
+ 'credit_nudge': [],
+ 'bridge_residual': [],
+ }
+
+ for step in range(1, num_steps + 1):
+ # Generate data
+ h0 = torch.randn(batch_size, d, device=device)
+ y = torch.randn(batch_size, m, device=device)
+
+ # Forward rollout
+ hiddens = rollout_forward(h0, Ms, sigma, L, device)
+ hL = hiddens[L]
+
+ # Terminal error
+ e_T = (hL @ C.T - y) # (batch, m) - gradient of Phi w.r.t. prediction
+
+ # Terminal modulation code s = e_T (P=I)
+ s = e_T.detach()
+
+ # ---- Train State Bridge ----
+ state_loss = 0.0
+ hL_detached = hL.detach()
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ pred_hL = state_bridge(h_l_det, t_l, s)
+ state_loss = state_loss + ((pred_hL - hL_detached) ** 2).sum(dim=-1).mean()
+ state_loss = state_loss / L
+
+ opt_state.zero_grad()
+ state_loss.backward()
+ opt_state.step()
+
+ # ---- Train Credit Bridge (value net) ----
+ # Terminal boundary: V(h_L, 1, s) should equal Phi(h_L, y)
+ hL_det = hL.detach().requires_grad_(False)
+ t_L = torch.ones(batch_size, device=device)
+ true_loss = terminal_loss(hL_det, C, y).detach()
+
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ t_l_next = torch.full((batch_size,), (l + 1) / L, device=device)
+
+ V_l = value_net(h_l_det, t_l, s)
+
+ # Generate noisy next states
+ with torch.no_grad():
+ M = Ms[l].to(device)
+ h_next_det = hiddens[l + 1].detach()
+
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn(batch_size, d, device=device)
+ h_next_noisy = h_next_det + noise
+ V_next = value_net_ema(h_next_noisy, t_l_next, s)
+ log_terms.append(-V_next / lam)
+
+ log_terms_stack = torch.stack(log_terms, dim=-1) # (batch, K)
+ V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+
+ loss_bridge = loss_bridge / L
+ loss_value = loss_term + loss_bridge
+
+ opt_value.zero_grad()
+ loss_value.backward()
+ opt_value.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+
+ # ---- Evaluation ----
+ if step % args.eval_every == 0 or step == 1:
+ with torch.no_grad():
+ eval_batch = min(batch_size, 128)
+ h0_eval = torch.randn(eval_batch, d, device=device)
+ y_eval = torch.randn(eval_batch, m, device=device)
+ hiddens_eval = rollout_forward(h0_eval, Ms, sigma, L, device)
+ hL_eval = hiddens_eval[L]
+ e_T_eval = hL_eval @ C.T - y_eval
+ s_eval = e_T_eval.detach()
+
+ # Exact costate
+ costates_exact = exact_costate(hiddens_eval, Ms, C, y_eval, device)
+
+ # Compute credits for each method at each layer
+ dfa_cos_layers = []
+ state_cos_layers = []
+ credit_cos_layers = []
+ dfa_rho_layers = []
+ state_rho_layers = []
+ credit_rho_layers = []
+ dfa_nudge_layers = []
+ state_nudge_layers = []
+ credit_nudge_layers = []
+ bridge_res_layers = []
+
+ for l in range(L + 1):
+ h_l = hiddens_eval[l].detach()
+ a_exact = costates_exact[l].detach()
+ t_l = torch.full((eval_batch,), l / L, device=device)
+
+ # DFA credit
+ a_dfa = e_T_eval @ Bs_dfa[l].T # (batch, d)
+
+ # State bridge credit
+ h_l_req = h_l.clone().requires_grad_(True)
+ pred_hL = state_bridge(h_l_req, t_l, s_eval)
+ # Loss through state bridge prediction
+ pred_out = pred_hL @ C.T # Use C as output projection for consistency
+ pred_loss = 0.5 * ((pred_out - y_eval) ** 2).sum(dim=-1)
+ a_state = torch.autograd.grad(pred_loss.sum(), h_l_req, create_graph=False)[0]
+
+ # Credit bridge credit
+ h_l_req2 = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req2, t_l, s_eval)
+ a_credit = torch.autograd.grad(V_l.sum(), h_l_req2, create_graph=False)[0]
+
+ # Costate cosine
+ dfa_cos_layers.append(cosine_similarity_batch(a_dfa, a_exact))
+ state_cos_layers.append(cosine_similarity_batch(a_state, a_exact))
+ credit_cos_layers.append(cosine_similarity_batch(a_credit, a_exact))
+
+ # Perturbation correlation and nudging (skip terminal layer for forward_fn)
+ if l < L:
+ fwd_fn = make_forward_fn_from_layer(hiddens_eval, Ms, C, y_eval, sigma, l, device)
+
+ dfa_rho = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16)
+ state_rho = perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16)
+ credit_rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16)
+ dfa_rho_layers.append(dfa_rho)
+ state_rho_layers.append(state_rho)
+ credit_rho_layers.append(credit_rho)
+
+ dfa_nud = nudging_test(h_l, a_dfa, fwd_fn, eta=0.01)
+ state_nud = nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01)
+ credit_nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01)
+ dfa_nudge_layers.append(dfa_nud)
+ state_nudge_layers.append(state_nud)
+ credit_nudge_layers.append(credit_nud)
+
+ # Bridge residual for credit bridge
+ if l < L:
+ t_l_next = torch.full((eval_batch,), (l + 1) / L, device=device)
+ h_next = hiddens_eval[l + 1].detach()
+ noisy_list = [h_next + sigma_bridge * torch.randn_like(h_next) for _ in range(K)]
+ br = bridge_residual(value_net, value_net_ema, h_l, t_l, s_eval,
+ noisy_list, t_l_next, lam)
+ bridge_res_layers.append(br)
+
+ # Average across layers
+ avg_dfa_cos = np.mean(dfa_cos_layers)
+ avg_state_cos = np.mean(state_cos_layers)
+ avg_credit_cos = np.mean(credit_cos_layers)
+ avg_dfa_rho = np.mean(dfa_rho_layers)
+ avg_state_rho = np.mean(state_rho_layers)
+ avg_credit_rho = np.mean(credit_rho_layers)
+ avg_dfa_nudge = np.mean(dfa_nudge_layers)
+ avg_state_nudge = np.mean(state_nudge_layers)
+ avg_credit_nudge = np.mean(credit_nudge_layers)
+ avg_bridge_res = np.mean(bridge_res_layers) if bridge_res_layers else 0.0
+
+ log['steps'].append(step)
+ log['dfa_costate_cos'].append(avg_dfa_cos)
+ log['state_costate_cos'].append(avg_state_cos)
+ log['credit_costate_cos'].append(avg_credit_cos)
+ log['dfa_rho'].append(avg_dfa_rho)
+ log['state_rho'].append(avg_state_rho)
+ log['credit_rho'].append(avg_credit_rho)
+ log['dfa_nudge'].append(avg_dfa_nudge)
+ log['state_nudge'].append(avg_state_nudge)
+ log['credit_nudge'].append(avg_credit_nudge)
+ log['bridge_residual'].append(avg_bridge_res)
+ log['state_bridge_loss'].append(state_loss.item())
+ log['credit_bridge_loss'].append(loss_value.item())
+
+ print(f"Step {step}/{num_steps}")
+ print(f" Costate cos - DFA: {avg_dfa_cos:.4f}, State: {avg_state_cos:.4f}, Credit: {avg_credit_cos:.4f}")
+ print(f" Perturb rho - DFA: {avg_dfa_rho:.4f}, State: {avg_state_rho:.4f}, Credit: {avg_credit_rho:.4f}")
+ print(f" Nudging - DFA: {avg_dfa_nudge:.4f}, State: {avg_state_nudge:.4f}, Credit: {avg_credit_nudge:.4f}")
+ print(f" Bridge res - {avg_bridge_res:.4f}")
+ print(f" Losses - State: {state_loss.item():.4f}, Credit: {loss_value.item():.4f}")
+ print(f" Per-layer costate cos (credit): {['%.3f' % x for x in credit_cos_layers]}")
+
+ # Save results
+ os.makedirs(args.output_dir, exist_ok=True)
+ results = {
+ 'config': vars(args),
+ 'log': log,
+ 'final_per_layer': {
+ 'dfa_costate_cos': dfa_cos_layers,
+ 'state_costate_cos': state_cos_layers,
+ 'credit_costate_cos': credit_cos_layers,
+ 'dfa_rho': dfa_rho_layers,
+ 'state_rho': state_rho_layers,
+ 'credit_rho': credit_rho_layers,
+ 'dfa_nudge': dfa_nudge_layers,
+ 'state_nudge': state_nudge_layers,
+ 'credit_nudge': credit_nudge_layers,
+ 'bridge_residual': bridge_res_layers,
+ }
+ }
+
+ out_path = os.path.join(args.output_dir, f'toy_lq_seed{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nResults saved to {out_path}")
+
+ # Also save models
+ torch.save(value_net.state_dict(), os.path.join(args.output_dir, f'value_net_seed{args.seed}.pt'))
+ torch.save(state_bridge.state_dict(), os.path.join(args.output_dir, f'state_bridge_seed{args.seed}.pt'))
+
+ return results
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Toy LQ Sanity Check')
+    parser.add_argument('--d_hidden', type=int, default=64)
+    parser.add_argument('--output_dim', type=int, default=10)
+    parser.add_argument('--num_layers', type=int, default=12)
+    parser.add_argument('--sigma', type=float, default=0.03)
+    parser.add_argument('--batch_size', type=int, default=256)
+    parser.add_argument('--num_steps', type=int, default=5000)
+    parser.add_argument('--lr_fb', type=float, default=1e-3)
+    parser.add_argument('--lam', type=float, default=0.1)
+    parser.add_argument('--K', type=int, default=8)
+    parser.add_argument('--ema_momentum', type=float, default=0.995)
+    parser.add_argument('--sigma_bridge', type=float, default=0.03)
+    parser.add_argument('--eval_every', type=int, default=200)
+    parser.add_argument('--seed', type=int, default=42)
+    parser.add_argument('--gpu', type=int, default=1)
+    parser.add_argument('--output_dir', type=str, default='results/toy_lq')
+    args = parser.parse_args()
+    run_experiment(args)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/experiments/toy_lq_sweep.py b/experiments/toy_lq_sweep.py
new file mode 100644
index 0000000..ae82ef0
--- /dev/null
+++ b/experiments/toy_lq_sweep.py
@@ -0,0 +1,243 @@
+"""
+Sweep over credit bridge hyperparameters to find a configuration
+where the value field gradient actually aligns with the costate.
+
+Key hypothesis: the credit bridge needs sufficient noise (sigma_bridge)
+and temperature (lambda) to make V_phi sensitive to cost-relevant directions.
+"""
+import os
+import sys
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from itertools import product
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from experiments.toy_lq import (
+ generate_stable_dynamics, rollout_forward, terminal_loss,
+ exact_costate, make_forward_fn_from_layer
+)
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test
+
+
+def run_credit_bridge_config(config, device):
+ """Run credit bridge with specific hyperparameters and return final metrics."""
+ d = 64
+ m = 10
+ L = 12
+ sigma = 0.03
+ batch_size = 256
+ num_steps = config['num_steps']
+ lr = config['lr']
+ lam = config['lam']
+ K = config['K']
+ ema_momentum = config['ema_momentum']
+ sigma_bridge = config['sigma_bridge']
+ hidden_dim = config.get('hidden_dim', 128)
+ use_ln = config.get('use_ln', True)
+
+ torch.manual_seed(42)
+ np.random.seed(42)
+
+ Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=42)
+ C = torch.randn(m, d, device=device) / np.sqrt(d)
+
+ # Value net - optionally without LayerNorm
+ value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=hidden_dim, num_layers=2).to(device)
+ if not use_ln:
+ value_net.ln = nn.Identity()
+
+ value_net_ema = create_ema_model(value_net)
+ opt_value = optim.Adam(value_net.parameters(), lr=lr)
+
+ best_cos = -1.0
+ best_step = 0
+ history = []
+
+ for step in range(1, num_steps + 1):
+ h0 = torch.randn(batch_size, d, device=device)
+ y = torch.randn(batch_size, m, device=device)
+ hiddens = rollout_forward(h0, Ms, sigma, L, device)
+ hL = hiddens[L]
+ e_T = hL @ C.T - y
+ s = e_T.detach()
+ true_loss = terminal_loss(hL.detach(), C, y).detach()
+
+ # Terminal boundary
+ hL_det = hL.detach()
+ t_L = torch.ones(batch_size, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ t_l_next = torch.full((batch_size,), (l + 1) / L, device=device)
+
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn(batch_size, d, device=device)
+ h_noisy = h_next_det + noise
+ V_next = value_net_ema(h_noisy, t_l_next, s)
+ log_terms.append(-V_next / lam)
+
+ log_terms_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+
+ loss_bridge = loss_bridge / L
+ total_loss = loss_term + loss_bridge
+
+ opt_value.zero_grad()
+ total_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ opt_value.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+
+ # Quick evaluation
+ if step % 500 == 0 or step == num_steps:
+ with torch.no_grad():
+ eval_batch = 128
+ h0_e = torch.randn(eval_batch, d, device=device)
+ y_e = torch.randn(eval_batch, m, device=device)
+ hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device)
+ hL_e = hiddens_e[L]
+ e_T_e = hL_e @ C.T - y_e
+ s_e = e_T_e.detach()
+ costates = exact_costate(hiddens_e, Ms, C, y_e, device)
+
+ cos_list = []
+ rho_list = []
+ nudge_list = []
+ for l in range(L):
+ h_l = hiddens_e[l].detach()
+ t_l = torch.full((eval_batch,), l / L, device=device)
+ a_exact = costates[l].detach()
+
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s_e)
+ a_credit = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0]
+
+ cos_list.append(cosine_similarity_batch(a_credit, a_exact))
+
+ fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device)
+ rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16)
+ rho_list.append(rho)
+ nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01)
+ nudge_list.append(nud)
+
+ avg_cos = np.mean(cos_list)
+ avg_rho = np.mean(rho_list)
+ avg_nudge = np.mean(nudge_list)
+
+ if avg_cos > best_cos:
+ best_cos = avg_cos
+ best_step = step
+
+ history.append({
+ 'step': step,
+ 'avg_cos': avg_cos,
+ 'avg_rho': avg_rho,
+ 'avg_nudge': avg_nudge,
+ 'loss_term': loss_term.item(),
+ 'loss_bridge': loss_bridge.item(),
+ })
+
+ return {
+ 'best_cos': best_cos,
+ 'best_step': best_step,
+ 'final_cos': history[-1]['avg_cos'],
+ 'final_rho': history[-1]['avg_rho'],
+ 'final_nudge': history[-1]['avg_nudge'],
+ 'history': history,
+ }
+
+
+def main():
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    print(f"Device: {device}")
+
+    # Sweep configurations
+    configs = [
+        # Baseline (original)
+        {'name': 'base', 'lam': 0.1, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+        # Larger noise
+        {'name': 'noise_0.1', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+        # Much larger noise
+        {'name': 'noise_0.3', 'lam': 0.1, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+        # Larger lambda
+        {'name': 'lam_1.0', 'lam': 1.0, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+        # Large noise + large lambda
+        {'name': 'noise_lam', 'lam': 1.0, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+        # No LayerNorm
+        {'name': 'no_ln', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False},
+        # Larger value net
+        {'name': 'big_vnet', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 256, 'use_ln': True},
+        # Slower EMA
+        {'name': 'ema_0.999', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+         'ema_momentum': 0.999, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+        # More K samples
+        {'name': 'K16', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 16, 'lr': 1e-3,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+        # Larger noise + large lambda + no LN
+        {'name': 'best_combo', 'lam': 1.0, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False},
+        # Very large sigma
+        {'name': 'noise_1.0', 'lam': 1.0, 'sigma_bridge': 1.0, 'K': 8, 'lr': 1e-3,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+        # Lower lr
+        {'name': 'lr_3e-4', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 3e-4,
+         'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+    ]
+
+    results = {}
+    for cfg in configs:
+        name = cfg.pop('name')
+        print(f"\n{'='*50}")
+        print(f"Config: {name}")
+        print(f"  {cfg}")
+        res = run_credit_bridge_config(cfg, device)
+        results[name] = res
+        print(f"  Best cos: {res['best_cos']:.4f} (step {res['best_step']})")
+        print(f"  Final cos: {res['final_cos']:.4f}, rho: {res['final_rho']:.4f}, nudge: {res['final_nudge']:.4f}")
+        cfg['name'] = name  # restore
+
+    # Print summary
+    print("\n" + "="*80)
+    print("SWEEP SUMMARY")
+    print("="*80)
+    print(f"{'Config':<20} {'Best Cos':<12} {'Final Cos':<12} {'Final Rho':<12} {'Final Nudge':<12}")
+    print("-"*68)
+    for name, res in results.items():
+        print(f"{name:<20} {res['best_cos']:<12.4f} {res['final_cos']:<12.4f} "
+              f"{res['final_rho']:<12.4f} {res['final_nudge']:<12.4f}")
+
+    # Save
+    os.makedirs('results/toy_lq', exist_ok=True)
+    with open('results/toy_lq/sweep_results.json', 'w') as f:
+        json.dump(results, f, indent=2)
+    print("\nSaved to results/toy_lq/sweep_results.json")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/experiments/toy_lq_v2.py b/experiments/toy_lq_v2.py
new file mode 100644
index 0000000..ab766b6
--- /dev/null
+++ b/experiments/toy_lq_v2.py
@@ -0,0 +1,327 @@
+"""
+Phase A v2: Enhanced toy LQ experiment.
+
+Key improvements over v1:
+1. Terminal gradient matching: V_phi at terminal layer should have grad_h V matching
+ the exact terminal gradient (this is LOCAL info, no hidden BP needed).
+2. Larger noise sweep integrated.
+3. Optional FM auxiliary for gradient smoothness.
+4. Better diagnostics.
+
+The terminal gradient a_L = C^T(C h_L - y) is computed from output layer only,
+so using it is allowed under the "no hidden BP anchor" constraint.
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from experiments.toy_lq import (
+ generate_stable_dynamics, rollout_forward, terminal_loss,
+ exact_costate, make_forward_fn_from_layer
+)
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test
+)
+
+
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+
+ d = args.d_hidden
+ m = args.output_dim
+ L = args.num_layers
+ sigma = args.sigma
+ batch_size = args.batch_size
+ num_steps = args.num_steps
+ lr = args.lr_fb
+ lam = args.lam
+ K = args.K
+ ema_momentum = args.ema_momentum
+ sigma_bridge = args.sigma_bridge
+
+ print(f"=== Toy LQ v2 Experiment ===")
+ print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}")
+ print(f"lam={lam}, sigma_bridge={sigma_bridge}, K={K}")
+ print(f"terminal_grad_weight={args.term_grad_weight}")
+ print(f"fm_weight={args.fm_weight}")
+ print(f"device={device}")
+
+ Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed)
+ C = torch.randn(m, d, device=device) / np.sqrt(d)
+
+ # DFA
+ Bs_dfa = [torch.randn(d, m, device=device) / np.sqrt(m) for _ in range(L + 1)]
+
+ # State Bridge
+ state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=128, num_layers=2).to(device)
+ opt_state = optim.Adam(state_bridge.parameters(), lr=lr)
+
+ # Credit Bridge
+ value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=args.vnet_hidden, num_layers=args.vnet_layers).to(device)
+ value_net_ema = create_ema_model(value_net)
+ opt_value = optim.Adam(value_net.parameters(), lr=lr)
+
+ log = {key: [] for key in [
+ 'steps',
+ 'dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos',
+ 'dfa_rho', 'state_rho', 'credit_rho',
+ 'dfa_nudge', 'state_nudge', 'credit_nudge',
+ 'bridge_residual', 'state_bridge_loss', 'credit_bridge_loss',
+ 'term_loss', 'bridge_loss', 'term_grad_loss', 'fm_loss',
+ ]}
+
+ for step in range(1, num_steps + 1):
+ h0 = torch.randn(batch_size, d, device=device)
+ y = torch.randn(batch_size, m, device=device)
+ hiddens = rollout_forward(h0, Ms, sigma, L, device)
+ hL = hiddens[L]
+ e_T = hL @ C.T - y
+ s = e_T.detach()
+
+ # ---- Train State Bridge ----
+ state_loss = 0.0
+ hL_det = hL.detach()
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ pred_hL = state_bridge(h_l_det, t_l, s)
+ state_loss = state_loss + ((pred_hL - hL_det) ** 2).sum(dim=-1).mean()
+ state_loss = state_loss / L
+
+ opt_state.zero_grad()
+ state_loss.backward()
+ opt_state.step()
+
+ # ---- Train Credit Bridge ----
+ # 1. Terminal boundary: V(h_L, 1, s) ≈ Phi(h_L, y)
+ hL_det = hL.detach()
+ t_L = torch.ones(batch_size, device=device)
+ true_loss = terminal_loss(hL_det, C, y).detach()
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # 2. Terminal gradient matching: grad_h V(h_L, 1, s) ≈ a_L^exact
+ # This uses only terminal-local information (no hidden BP)
+ loss_term_grad = torch.tensor(0.0, device=device)
+ if args.term_grad_weight > 0:
+ hL_req = hL.detach().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ # Exact terminal gradient: C^T (C h_L - y)
+ a_L_exact = (e_T @ C).detach() # (batch, d) -- stop grad on target
+ loss_term_grad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ # 3. Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ t_l_next = torch.full((batch_size,), (l + 1) / L, device=device)
+
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn(batch_size, d, device=device)
+ h_noisy = h_next_det + noise
+ V_next = value_net_ema(h_noisy, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_terms_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ # 4. FM auxiliary (optional): enforce gradient smoothness
+ loss_fm = torch.tensor(0.0, device=device)
+ if args.fm_weight > 0:
+ for l in range(L):
+ tau = torch.rand(batch_size, 1, device=device)
+ h_l_det = hiddens[l].detach()
+ h_next_det = hiddens[l + 1].detach()
+ f_l = h_next_det - h_l_det # residual
+
+ eps = torch.randn(batch_size, d, device=device)
+ h_mid = h_l_det + tau * f_l + (tau * (1 - tau)).sqrt() * sigma_bridge * eps
+ h_mid.requires_grad_(True)
+
+ # Interpolated time in [l/L, (l+1)/L], shape (batch, 1)
+ t_mid = (l + tau) / L
+ t_mid_flat = t_mid.squeeze(-1)
+
+ V_mid = value_net(h_mid, t_mid_flat, s)
+ grad_V_mid = torch.autograd.grad(V_mid.sum(), h_mid, create_graph=True)[0]
+
+ # Interpolated target gradient
+ # Get a_l and a_{l+1} from current value net (no create_graph for targets)
+ h_l_r = h_l_det.clone().requires_grad_(True)
+ t_l_v = torch.full((batch_size,), l / L, device=device)
+ V_l_ = value_net(h_l_r, t_l_v, s)
+ a_l = torch.autograd.grad(V_l_.sum(), h_l_r, create_graph=False)[0].detach()
+
+ h_next_r = h_next_det.clone().requires_grad_(True)
+ t_next_v = torch.full((batch_size,), (l + 1) / L, device=device)
+ V_next_ = value_net(h_next_r, t_next_v, s)
+ a_next = torch.autograd.grad(V_next_.sum(), h_next_r, create_graph=False)[0].detach()
+
+ target_grad = ((1 - tau) * a_l + tau * a_next).detach()
+ loss_fm = loss_fm + ((grad_V_mid - target_grad) ** 2).sum(dim=-1).mean()
+ loss_fm = loss_fm / L
+
+        total_loss = (loss_term
+                      + loss_bridge
+                      + args.term_grad_weight * loss_term_grad
+                      + args.fm_weight * loss_fm)
+
+        opt_value.zero_grad()
+        total_loss.backward()
+        torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+        opt_value.step()
+        update_ema(value_net, value_net_ema, ema_momentum)
+
+        # ---- Evaluation ----
+        if step % args.eval_every == 0 or step == 1:
+            with torch.no_grad():
+                eval_batch = 128
+                h0_e = torch.randn(eval_batch, d, device=device)
+                y_e = torch.randn(eval_batch, m, device=device)
+                hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device)
+                hL_e = hiddens_e[L]
+                e_T_e = hL_e @ C.T - y_e
+                s_e = e_T_e.detach()
+                costates = exact_costate(hiddens_e, Ms, C, y_e, device)
+
+            dfa_cos, state_cos, credit_cos = [], [], []
+            dfa_rho, state_rho, credit_rho = [], [], []
+            dfa_nudge, state_nudge, credit_nudge = [], [], []
+            bridge_res_list = []
+
+            for l in range(L):
+                h_l = hiddens_e[l].detach()
+                a_exact = costates[l].detach()
+                t_l = torch.full((eval_batch,), l / L, device=device)
+
+                # DFA
+                a_dfa = e_T_e @ Bs_dfa[l].T
+                # State bridge
+                h_l_r1 = h_l.clone().requires_grad_(True)
+                pred_hL = state_bridge(h_l_r1, t_l, s_e)
+                pred_out = pred_hL @ C.T
+                pred_loss = 0.5 * ((pred_out - y_e) ** 2).sum(dim=-1)
+                a_state = torch.autograd.grad(pred_loss.sum(), h_l_r1, create_graph=False)[0]
+                # Credit bridge
+                h_l_r2 = h_l.clone().requires_grad_(True)
+                V_l = value_net(h_l_r2, t_l, s_e)
+                a_credit = torch.autograd.grad(V_l.sum(), h_l_r2, create_graph=False)[0]
+
+                dfa_cos.append(cosine_similarity_batch(a_dfa, a_exact))
+                state_cos.append(cosine_similarity_batch(a_state, a_exact))
+                credit_cos.append(cosine_similarity_batch(a_credit, a_exact))
+
+                fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device)
+
+                dfa_rho.append(perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16))
+                state_rho.append(perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16))
+                credit_rho.append(perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16))
+
+                dfa_nudge.append(nudging_test(h_l, a_dfa, fwd_fn, eta=0.01))
+                state_nudge.append(nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01))
+                credit_nudge.append(nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01))
+
+            avg = lambda x: float(np.mean(x))
+            log['steps'].append(step)
+            log['dfa_costate_cos'].append(avg(dfa_cos))
+            log['state_costate_cos'].append(avg(state_cos))
+            log['credit_costate_cos'].append(avg(credit_cos))
+            log['dfa_rho'].append(avg(dfa_rho))
+            log['state_rho'].append(avg(state_rho))
+            log['credit_rho'].append(avg(credit_rho))
+            log['dfa_nudge'].append(avg(dfa_nudge))
+            log['state_nudge'].append(avg(state_nudge))
+            log['credit_nudge'].append(avg(credit_nudge))
+            log['state_bridge_loss'].append(state_loss.item())
+            log['credit_bridge_loss'].append(total_loss.item())
+            log['term_loss'].append(loss_term.item())
+            log['bridge_loss'].append(loss_bridge.item())
+            log['term_grad_loss'].append(loss_term_grad.item() if isinstance(loss_term_grad, torch.Tensor) else loss_term_grad)
+            log['fm_loss'].append(loss_fm.item() if isinstance(loss_fm, torch.Tensor) else loss_fm)
+
+            print(f"Step {step}/{num_steps}")
+            print(f" Costate cos - DFA: {avg(dfa_cos):.4f}, State: {avg(state_cos):.4f}, Credit: {avg(credit_cos):.4f}")
+            print(f" Perturb rho - DFA: {avg(dfa_rho):.4f}, State: {avg(state_rho):.4f}, Credit: {avg(credit_rho):.4f}")
+            print(f" Nudging - DFA: {avg(dfa_nudge):.4f}, State: {avg(state_nudge):.4f}, Credit: {avg(credit_nudge):.4f}")
+            print(f" Losses - term: {loss_term.item():.4f}, bridge: {loss_bridge.item():.4f}, "
+                  f"tgrad: {loss_term_grad.item() if isinstance(loss_term_grad, torch.Tensor) else 0:.4f}, "
+                  f"fm: {loss_fm.item() if isinstance(loss_fm, torch.Tensor) else 0:.4f}")
+            print(f" Per-layer credit cos: {['%.3f' % x for x in credit_cos]}")
+
+    # Save
+    os.makedirs(args.output_dir, exist_ok=True)
+    results = {
+        'config': vars(args),
+        'log': log,
+        'final_per_layer': {
+            'dfa_costate_cos': dfa_cos,
+            'state_costate_cos': state_cos,
+            'credit_costate_cos': credit_cos,
+            'dfa_rho': dfa_rho,
+            'state_rho': state_rho,
+            'credit_rho': credit_rho,
+            'dfa_nudge': dfa_nudge,
+            'state_nudge': state_nudge,
+            'credit_nudge': credit_nudge,
+        }
+    }
+    tag = f"seed{args.seed}_lam{args.lam}_sig{args.sigma_bridge}_tgw{args.term_grad_weight}_fm{args.fm_weight}"
+    out_path = os.path.join(args.output_dir, f'toy_lq_v2_{tag}.json')
+    with open(out_path, 'w') as f:
+        json.dump(results, f, indent=2)
+    print(f"\nResults saved to {out_path}")
+    return results
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Toy LQ v2')
+    parser.add_argument('--d_hidden', type=int, default=64)
+    parser.add_argument('--output_dim', type=int, default=10)
+    parser.add_argument('--num_layers', type=int, default=12)
+    parser.add_argument('--sigma', type=float, default=0.03)
+    parser.add_argument('--batch_size', type=int, default=256)
+    parser.add_argument('--num_steps', type=int, default=8000)
+    parser.add_argument('--lr_fb', type=float, default=1e-3)
+    parser.add_argument('--lam', type=float, default=0.1)
+    parser.add_argument('--K', type=int, default=8)
+    parser.add_argument('--ema_momentum', type=float, default=0.995)
+    parser.add_argument('--sigma_bridge', type=float, default=0.1)
+    parser.add_argument('--eval_every', type=int, default=500)
+    parser.add_argument('--seed', type=int, default=42)
+    parser.add_argument('--gpu', type=int, default=1)
+    parser.add_argument('--output_dir', type=str, default='results/toy_lq')
+    parser.add_argument('--vnet_hidden', type=int, default=256)
+    parser.add_argument('--vnet_layers', type=int, default=3)
+    # Key new options
+    parser.add_argument('--term_grad_weight', type=float, default=1.0,
+                        help='Weight for terminal gradient matching loss')
+    parser.add_argument('--fm_weight', type=float, default=0.0,
+                        help='Weight for FM gradient smoothness auxiliary')
+    args = parser.parse_args()
+    run_experiment(args)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/methods/__init__.py b/methods/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/methods/__init__.py
diff --git a/metrics/__init__.py b/metrics/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/metrics/__init__.py
diff --git a/metrics/__pycache__/__init__.cpython-313.pyc b/metrics/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..0595726
--- /dev/null
+++ b/metrics/__pycache__/__init__.cpython-313.pyc
Binary files differ
diff --git a/metrics/__pycache__/credit_metrics.cpython-313.pyc b/metrics/__pycache__/credit_metrics.cpython-313.pyc
new file mode 100644
index 0000000..ef62388
--- /dev/null
+++ b/metrics/__pycache__/credit_metrics.cpython-313.pyc
Binary files differ
diff --git a/metrics/credit_metrics.py b/metrics/credit_metrics.py
new file mode 100644
index 0000000..516dca2
--- /dev/null
+++ b/metrics/credit_metrics.py
@@ -0,0 +1,156 @@
+"""
+Credit assignment diagnostic metrics:
+1. Exact costate cosine (for toy LQ)
+2. Local perturbation correlation rho_l
+3. Nudging test Delta_l^nudge
+4. Offline BP cosine Gamma_l
+5. Bridge residual R_l
+6. Feature drift M_l
+"""
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.stats import pearsonr
+
+
+def cosine_similarity_batch(a: torch.Tensor, b: torch.Tensor) -> float:
+    """Cosine similarity between a and b over flattened per-sample features, averaged over batch (Python float)."""
+    a_flat = a.reshape(a.shape[0], -1)
+    b_flat = b.reshape(b.shape[0], -1)
+    cos = F.cosine_similarity(a_flat, b_flat, dim=-1)
+    return cos.mean().item()
+
+
+def perturbation_correlation(h_l, a_l, forward_fn, epsilon=1e-3, M=32):
+    """
+    Compute local perturbation correlation rho_l.
+
+    Args:
+        h_l: (batch, d) hidden state at layer l
+        a_l: (batch, d) credit signal at layer l
+        forward_fn: callable mapping h (batch, d) -> per-sample loss (batch,)
+        epsilon: perturbation magnitude
+        M: number of random directions
+
+    Returns:
+        rho: Pearson correlation between predicted and true loss changes
+    """
+    batch_size, d = h_l.shape
+    device = h_l.device
+
+    pred_list = []
+    true_list = []
+
+    base_loss = forward_fn(h_l)  # (batch,)
+
+    for _ in range(M):
+        v = torch.randn(batch_size, d, device=device)
+        v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+
+        # Predicted change: <a_l, epsilon * v>
+        delta_pred = (a_l * (epsilon * v)).sum(dim=-1)  # (batch,)
+
+        # True change: forward from perturbed h
+        perturbed_loss = forward_fn(h_l + epsilon * v)  # (batch,)
+        delta_true = perturbed_loss - base_loss  # (batch,)
+
+        pred_list.append(delta_pred.detach().cpu().numpy())
+        true_list.append(delta_true.detach().cpu().numpy())
+
+    pred_arr = np.concatenate(pred_list)
+    true_arr = np.concatenate(true_list)
+
+    if np.std(pred_arr) < 1e-12 or np.std(true_arr) < 1e-12:
+        return 0.0
+
+    rho, _ = pearsonr(pred_arr, true_arr)
+    return float(rho)
+
+
+def nudging_test(h_l, a_l, forward_fn, eta=0.01):
+    """
+    Nudging test: check if moving h_l in -a_l direction decreases loss.
+
+    Args:
+        h_l: (batch, d) hidden state
+        a_l: (batch, d) credit signal
+        forward_fn: callable h -> loss per sample (batch,)
+        eta: step size
+
+    Returns:
+        mean delta_nudge (negative is good)
+    """
+    rms_a = (a_l ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+    a_normed = a_l / rms_a
+    h_nudged = h_l - eta * a_normed
+
+    base_loss = forward_fn(h_l)
+    nudged_loss = forward_fn(h_nudged)
+    delta = (nudged_loss - base_loss).mean().item()
+    return delta
+
+
+def offline_bp_cosine(a_l, bp_grad_l):
+    """
+    Compute offline BP cosine similarity.
+    a_l: (batch, d) credit signal
+    bp_grad_l: (batch, d) true BP gradient at layer l
+    """
+    return cosine_similarity_batch(a_l, bp_grad_l)
+
+
+def bridge_residual(V_phi, V_bar_phi, h_l, t_l, s, h_l_next_noisy_list, t_l_next, lam=0.1):
+    """
+    Compute bridge residual R_l.
+
+    Args:
+        V_phi: value network
+        V_bar_phi: EMA target value network
+        h_l: (batch, d)
+        t_l: (batch,)
+        s: (batch, s_dim)
+        h_l_next_noisy_list: list of K tensors (batch, d), noisy next states
+        t_l_next: (batch,)
+        lam: temperature
+
+    Returns:
+        mean absolute bridge residual
+    """
+    with torch.no_grad():
+        V_current = V_phi(h_l, t_l, s)  # (batch,)
+
+        # Compute soft-min target
+        K = len(h_l_next_noisy_list)
+        log_terms = []
+        for h_next in h_l_next_noisy_list:
+            V_next = V_bar_phi(h_next, t_l_next, s)  # (batch,)
+            log_terms.append(-V_next / lam)
+
+        log_terms = torch.stack(log_terms, dim=-1)  # (batch, K)
+        V_target = -lam * torch.logsumexp(log_terms, dim=-1) + lam * np.log(K)
+
+        residual = (V_current - V_target).abs().mean().item()
+    return residual
+
+
+def feature_drift(model_init_params, model_final_params):
+    """
+    Compute per-layer feature drift M_l = ||W_final - W_init||_F / ||W_init||_F.
+
+    Args:
+        model_init_params: dict of {name: tensor} initial parameters
+        model_final_params: dict of {name: tensor} final parameters
+
+    Returns:
+        dict of {name: drift_ratio}
+    """
+    drifts = {}
+    for name in model_init_params:
+        if name in model_final_params:
+            w_init = model_init_params[name]
+            w_final = model_final_params[name]
+            init_norm = w_init.norm().item()
+            if init_norm > 1e-8:
+                drift = (w_final - w_init).norm().item() / init_norm
+                drifts[name] = drift
+    return drifts
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/models/__init__.py
diff --git a/models/__pycache__/__init__.cpython-313.pyc b/models/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..cb3f264
--- /dev/null
+++ b/models/__pycache__/__init__.cpython-313.pyc
Binary files differ
diff --git a/models/__pycache__/residual_mlp.cpython-313.pyc b/models/__pycache__/residual_mlp.cpython-313.pyc
new file mode 100644
index 0000000..c758f50
--- /dev/null
+++ b/models/__pycache__/residual_mlp.cpython-313.pyc
Binary files differ
diff --git a/models/__pycache__/state_bridge.cpython-313.pyc b/models/__pycache__/state_bridge.cpython-313.pyc
new file mode 100644
index 0000000..69e1071
--- /dev/null
+++ b/models/__pycache__/state_bridge.cpython-313.pyc
Binary files differ
diff --git a/models/__pycache__/value_net.cpython-313.pyc b/models/__pycache__/value_net.cpython-313.pyc
new file mode 100644
index 0000000..a6187ee
--- /dev/null
+++ b/models/__pycache__/value_net.cpython-313.pyc
Binary files differ
diff --git a/models/residual_mlp.py b/models/residual_mlp.py
new file mode 100644
index 0000000..c16778c
--- /dev/null
+++ b/models/residual_mlp.py
@@ -0,0 +1,73 @@
+"""
+Deep Residual MLP for classification.
+Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head.
+Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l))
+"""
+import torch
+import torch.nn as nn
+
+
+class ResidualBlock(nn.Module):
+    """Single pre-LayerNorm residual MLP block."""
+
+    def __init__(self, d_hidden: int):
+        super().__init__()
+        self.ln = nn.LayerNorm(d_hidden)
+        self.w1 = nn.Linear(d_hidden, d_hidden)
+        self.w2 = nn.Linear(d_hidden, d_hidden)
+        # Small init for residual branch
+        nn.init.normal_(self.w2.weight, std=0.01)
+        nn.init.zeros_(self.w2.bias)
+
+    def forward(self, h: torch.Tensor) -> torch.Tensor:
+        """Returns the residual F_l(h), NOT h + F_l(h)."""
+        z = self.ln(h)
+        z = self.w1(z)
+        z = torch.nn.functional.gelu(z)
+        z = self.w2(z)
+        return z
+
+
+class ResidualMLP(nn.Module):
+    """Deep residual MLP: embed -> L blocks -> LN -> output head."""
+
+    def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int):
+        super().__init__()
+        self.embed = nn.Linear(input_dim, d_hidden)
+        self.blocks = nn.ModuleList([ResidualBlock(d_hidden) for _ in range(num_blocks)])
+        self.out_ln = nn.LayerNorm(d_hidden)
+        self.out_head = nn.Linear(d_hidden, num_classes)
+        self.num_blocks = num_blocks
+        self.d_hidden = d_hidden
+
+    def forward(self, x: torch.Tensor, return_hidden: bool = False):
+        """
+        Args:
+            x: (batch, input_dim)
+            return_hidden: if True, also return list of hidden states [h_0, ..., h_L]
+        Returns:
+            logits: (batch, num_classes)
+            hiddens: list of (batch, d_hidden) if return_hidden
+        """
+        h = self.embed(x)
+        hiddens = [h] if return_hidden else None
+
+        for block in self.blocks:
+            f = block(h)
+            h = h + f
+            if return_hidden:
+                hiddens.append(h)
+
+        logits = self.out_head(self.out_ln(h))
+
+        if return_hidden:
+            return logits, hiddens
+        return logits
+
+    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+        """Run forward from a given layer index to output. Used for perturbation tests."""
+        for i in range(start_layer, self.num_blocks):
+            f = self.blocks[i](h)
+            h = h + f
+        logits = self.out_head(self.out_ln(h))
+        return logits
diff --git a/models/state_bridge.py b/models/state_bridge.py
new file mode 100644
index 0000000..0a0e7aa
--- /dev/null
+++ b/models/state_bridge.py
@@ -0,0 +1,35 @@
+"""
+State Bridge predictor G_psi(h_l, t_l, s) -> predicted h_L.
+Used by the State Bridge method.
+"""
+import torch
+import torch.nn as nn
+from .value_net import SinusoidalTimeEmbed
+
+
+class StateBridgeNet(nn.Module):
+    """
+    State predictor G_psi(h_l, t_l, s) -> predicted terminal state h_L.
+    """
+
+    def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32,
+                 hidden_dim: int = 256, num_layers: int = 3):
+        super().__init__()
+        self.ln = nn.LayerNorm(d_hidden)
+        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+
+        input_dim = d_hidden + time_embed_dim + s_dim
+        layers = []
+        for i in range(num_layers):
+            in_d = input_dim if i == 0 else hidden_dim
+            layers.append(nn.Linear(in_d, hidden_dim))
+            layers.append(nn.GELU())
+        layers.append(nn.Linear(hidden_dim, d_hidden))
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
+        """Returns predicted h_L as (batch, d_hidden)."""
+        h_normed = self.ln(h)
+        t_emb = self.time_embed(t)
+        inp = torch.cat([h_normed, t_emb, s], dim=-1)
+        return self.net(inp)
diff --git a/models/value_net.py b/models/value_net.py
new file mode 100644
index 0000000..3c72f75
--- /dev/null
+++ b/models/value_net.py
@@ -0,0 +1,77 @@
+"""
+Value network V_phi(h_l, t_l, s) -> scalar.
+Used by the Credit Bridge method.
+Input: [LN(h_l), time_embed(t_l), s] concatenated.
+"""
+import torch
+import torch.nn as nn
+import math
+import copy
+
+
+class SinusoidalTimeEmbed(nn.Module):
+    """Sinusoidal positional encoding for scalar depth-time t_l = l/L."""
+
+    def __init__(self, embed_dim: int):
+        super().__init__()
+        self.embed_dim = embed_dim
+
+    def forward(self, t: torch.Tensor) -> torch.Tensor:
+        """t: (batch,) or (batch, 1) scalar in [0,1]."""
+        if t.dim() == 1:
+            t = t.unsqueeze(-1)  # (batch, 1)
+        half = self.embed_dim // 2
+        freqs = torch.exp(
+            -math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half
+        )
+        args = t * freqs.unsqueeze(0)  # (batch, half)
+        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (batch, embed_dim)
+
+
+class ValueNet(nn.Module):
+    """
+    Scalar value network V_phi(h_l, t_l, s).
+    Inputs:
+        h: hidden state (batch, d_hidden)
+        t: depth-time scalar (batch,) in [0, 1]
+        s: terminal modulation code (batch, s_dim)
+    Output:
+        V: scalar (batch,)
+    """
+
+    def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32,
+                 hidden_dim: int = 256, num_layers: int = 3):
+        super().__init__()
+        self.ln = nn.LayerNorm(d_hidden)
+        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+
+        input_dim = d_hidden + time_embed_dim + s_dim
+        layers = []
+        for i in range(num_layers):
+            in_d = input_dim if i == 0 else hidden_dim
+            layers.append(nn.Linear(in_d, hidden_dim))
+            layers.append(nn.GELU())
+        layers.append(nn.Linear(hidden_dim, 1))
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
+        """Returns V(h, t, s) as (batch,) scalar."""
+        h_normed = self.ln(h)
+        t_emb = self.time_embed(t)
+        inp = torch.cat([h_normed, t_emb, s], dim=-1)
+        return self.net(inp).squeeze(-1)
+
+
+def create_ema_model(model: nn.Module) -> nn.Module:
+    """Create an EMA copy of a model."""
+    ema = copy.deepcopy(model)
+    for p in ema.parameters():
+        p.requires_grad_(False)
+    return ema
+
+
+@torch.no_grad()
+def update_ema(model: nn.Module, ema_model: nn.Module, momentum: float = 0.99):
+    """Update EMA model parameters."""
+    for p, ep in zip(model.parameters(), ema_model.parameters()):
+        ep.data.mul_(momentum).add_(p.data, alpha=1 - momentum)
diff --git a/report/toy_bridge_residual.png b/report/toy_bridge_residual.png
new file mode 100644
index 0000000..03eeb47
--- /dev/null
+++ b/report/toy_bridge_residual.png
Binary files differ
diff --git a/report/toy_per_layer_diagnostics.png b/report/toy_per_layer_diagnostics.png
new file mode 100644
index 0000000..d31b188
--- /dev/null
+++ b/report/toy_per_layer_diagnostics.png
Binary files differ
diff --git a/report/toy_term_grad_effect.png b/report/toy_term_grad_effect.png
new file mode 100644
index 0000000..13f0458
--- /dev/null
+++ b/report/toy_term_grad_effect.png
Binary files differ
diff --git a/report/toy_training_curves.png b/report/toy_training_curves.png
new file mode 100644
index 0000000..cc3532b
--- /dev/null
+++ b/report/toy_training_curves.png
Binary files differ
diff --git a/results/smoke_test/results_fashionmnist.json b/results/smoke_test/results_fashionmnist.json
new file mode 100644
index 0000000..8fd82c0
--- /dev/null
+++ b/results/smoke_test/results_fashionmnist.json
@@ -0,0 +1,511 @@
+{
+ "42": {
+ "bp": {
+ "log": {
+ "train_loss": [
+ 0.7028828698158264,
+ 0.5331447437604269,
+ 0.4640675885995229,
+ 0.416527880080541,
+ 0.3784152720451355
+ ],
+ "train_acc": [
+ 0.73755,
+ 0.8002166666666667,
+ 0.8256666666666667,
+ 0.8436666666666667,
+ 0.8591833333333333
+ ],
+ "test_acc": [
+ 0.7939,
+ 0.8157,
+ 0.8379,
+ 0.8606,
+ 0.8658
+ ]
+ },
+ "diagnostics": {
+ "bp_cosine": [
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0
+ ],
+ "perturbation_rho": [
+ 0.9998708367347717,
+ 0.9998407959938049,
+ 0.9997469186782837,
+ 0.9998122453689575,
+ 0.9997262358665466,
+ 0.9996491074562073,
+ 0.9996585845947266,
+ 0.9995328783988953
+ ],
+ "nudging": {
+ "0.001": [
+ -0.0019641336984932423,
+ -0.00174636859446764,
+ -0.00154352025128901,
+ -0.0013810225063934922,
+ -0.0012467722408473492,
+ -0.0011190228397026658,
+ -0.0010304137831553817,
+ -0.0009783159475773573
+ ],
+ "0.003": [
+ -0.005858615506440401,
+ -0.005213461350649595,
+ -0.004612031392753124,
+ -0.004128905013203621,
+ -0.0037290011532604694,
+ -0.003348270896822214,
+ -0.003083921270444989,
+ -0.002928499598056078
+ ],
+ "0.01": [
+ -0.019138235598802567,
+ -0.017080917954444885,
+ -0.015158241614699364,
+ -0.013598069548606873,
+ -0.01229821052402258,
+ -0.01105712354183197,
+ -0.010194781236350536,
+ -0.009686501696705818
+ ]
+ }
+ },
+ "drift": {
+ "embed.weight": 1.1395236897587921,
+ "embed.bias": 0.6760135861021652,
+ "blocks.0.ln.weight": 0.0347786583006382,
+ "blocks.0.w1.weight": 0.8267150385727446,
+ "blocks.0.w1.bias": 0.7665003752239994,
+ "blocks.0.w2.weight": 2.522163826858937,
+ "blocks.1.ln.weight": 0.0373835414648056,
+ "blocks.1.w1.weight": 0.8094006579319112,
+ "blocks.1.w1.bias": 0.7074648417912711,
+ "blocks.1.w2.weight": 2.427488314293417,
+ "blocks.2.ln.weight": 0.0362338162958622,
+ "blocks.2.w1.weight": 0.7981608599321399,
+ "blocks.2.w1.bias": 0.6653993627306621,
+ "blocks.2.w2.weight": 2.3264717248101827,
+ "blocks.3.ln.weight": 0.03489774093031883,
+ "blocks.3.w1.weight": 0.8024856088475522,
+ "blocks.3.w1.bias": 0.6243261101573547,
+ "blocks.3.w2.weight": 2.2732641732905714,
+ "blocks.4.ln.weight": 0.036862026900053024,
+ "blocks.4.w1.weight": 0.7702369058900084,
+ "blocks.4.w1.bias": 0.671902653164234,
+ "blocks.4.w2.weight": 2.11229934068397,
+ "blocks.5.ln.weight": 0.04049132019281387,
+ "blocks.5.w1.weight": 0.74033776504041,
+ "blocks.5.w1.bias": 0.6447910547850285,
+ "blocks.5.w2.weight": 1.9146335569252138,
+ "blocks.6.ln.weight": 0.03797098994255066,
+ "blocks.6.w1.weight": 0.7241522377185389,
+ "blocks.6.w1.bias": 0.6486550903936706,
+ "blocks.6.w2.weight": 1.8210870137239685,
+ "blocks.7.ln.weight": 0.03962903097271919,
+ "blocks.7.w1.weight": 0.6992516513120532,
+ "blocks.7.w1.bias": 0.7021825186584477,
+ "blocks.7.w2.weight": 1.7902862835380957,
+ "out_ln.weight": 0.026629405096173286,
+ "out_head.weight": 0.5610428179907003,
+ "out_head.bias": 0.24151687322978704
+ }
+ },
+ "dfa": {
+ "log": {
+ "train_loss": [
+ 1.4100907169977823,
+ 1.4334057479222615,
+ 1.4326289967854817,
+ 1.4254953683853149,
+ 1.4169784986495972
+ ],
+ "train_acc": [
+ 0.4805,
+ 0.49806666666666666,
+ 0.49725,
+ 0.5031166666666667,
+ 0.5098
+ ],
+ "test_acc": [
+ 0.5175,
+ 0.5282,
+ 0.5025,
+ 0.5027,
+ 0.5338
+ ]
+ },
+ "diagnostics": {
+ "bp_cosine": [
+ 0.2141590267419815,
+ 0.0797841027379036,
+ 0.04707183688879013,
+ 0.0032249214127659798,
+ 0.024886084720492363,
+ -0.0050605954602360725,
+ 0.009758025407791138,
+ 0.020020857453346252
+ ],
+ "perturbation_rho": [
+ 0.07338915765285492,
+ -0.004745986312627792,
+ 0.018682830035686493,
+ 0.031185101717710495,
+ -0.02063235454261303,
+ 0.0006608979310840368,
+ 0.0,
+ 0.04759033024311066
+ ],
+ "nudging": {
+ "0.001": [
+ -3.5976991057395935e-06,
+ 2.3283064365386963e-09,
+ -9.313225746154785e-10,
+ -9.313225746154785e-09,
+ 4.656612873077393e-10,
+ -3.725290298461914e-09,
+ 1.3969838619232178e-09,
+ 9.313225746154785e-10
+ ],
+ "0.003": [
+ -1.0745832696557045e-05,
+ 2.3283064365386963e-09,
+ 1.6298145055770874e-09,
+ -2.3283064365386963e-09,
+ 6.752088665962219e-09,
+ 1.0244548320770264e-08,
+ -4.6566128730773926e-09,
+ 0.0
+ ],
+ "0.01": [
+ -3.5760458558797836e-05,
+ -9.313225746154785e-10,
+ 1.3271346688270569e-08,
+ -1.0244548320770264e-08,
+ -8.381903171539307e-09,
+ 4.1443854570388794e-08,
+ 8.847564458847046e-09,
+ 1.30385160446167e-08
+ ]
+ }
+ },
+ "drift": {
+ "embed.weight": 34.15137184586086,
+ "embed.bias": 25.886722942466992,
+ "blocks.0.ln.weight": 2.2413852214813232,
+ "blocks.0.w1.weight": 42.91370684219911,
+ "blocks.0.w1.bias": 42.38937429728957,
+ "blocks.0.w2.weight": 115.11173260217275,
+ "blocks.1.ln.weight": 2.0233230590820312,
+ "blocks.1.w1.weight": 36.64772802731374,
+ "blocks.1.w1.bias": 34.57367344043412,
+ "blocks.1.w2.weight": 84.1198032305499,
+ "blocks.2.ln.weight": 1.9000216722488403,
+ "blocks.2.w1.weight": 33.172328512058,
+ "blocks.2.w1.bias": 31.908254113444393,
+ "blocks.2.w2.weight": 78.32752075434591,
+ "blocks.3.ln.weight": 1.9099335670471191,
+ "blocks.3.w1.weight": 36.73019631908303,
+ "blocks.3.w1.bias": 32.60666919280332,
+ "blocks.3.w2.weight": 83.75068434979308,
+ "blocks.4.ln.weight": 1.891120195388794,
+ "blocks.4.w1.weight": 35.27032832987592,
+ "blocks.4.w1.bias": 38.017692746712825,
+ "blocks.4.w2.weight": 80.26417466790754,
+ "blocks.5.ln.weight": 2.0106024742126465,
+ "blocks.5.w1.weight": 42.09808335703852,
+ "blocks.5.w1.bias": 43.15100280108635,
+ "blocks.5.w2.weight": 94.64753309078039,
+ "blocks.6.ln.weight": 1.8941009044647217,
+ "blocks.6.w1.weight": 38.94163125135345,
+ "blocks.6.w1.bias": 38.10426587138794,
+ "blocks.6.w2.weight": 84.40488806201412,
+ "blocks.7.ln.weight": 1.9035111665725708,
+ "blocks.7.w1.weight": 38.65051912560395,
+ "blocks.7.w1.bias": 40.760402959190415,
+ "blocks.7.w2.weight": 81.97372863530312,
+ "out_ln.weight": 0.39740806818008423,
+ "out_head.weight": 3.609484615833081,
+ "out_head.bias": 0.8344298895862311
+ }
+ },
+ "state_bridge": {
+ "log": {
+ "train_loss": [
+ 1.7336509002049765,
+ 1.5851847206751506,
+ 1.8742321704864502,
+ 1.8100628153483074,
+ 1.5580067304611207
+ ],
+ "train_acc": [
+ 0.32705,
+ 0.3616333333333333,
+ 0.2679166666666667,
+ 0.31853333333333333,
+ 0.4080166666666667
+ ],
+ "test_acc": [
+ 0.4036,
+ 0.4047,
+ 0.3046,
+ 0.4005,
+ 0.4651
+ ],
+ "state_pred_error": [
+ 5076234.573784879,
+ 343992182.36586666,
+ 517900685.38026667,
+ 557639895.7226666,
+ 383439664.5205333
+ ]
+ },
+ "diagnostics": {
+ "bp_cosine": [
+ 0.35149621963500977,
+ 0.27368226647377014,
+ 0.045176729559898376,
+ 0.04587914049625397,
+ 0.06403794139623642,
+ 0.06599076837301254,
+ 0.10026843845844269,
+ 0.11267217993736267
+ ],
+ "perturbation_rho": [
+ 0.37996596097946167,
+ 0.0075237625278532505,
+ -0.017497196793556213,
+ 0.001783197745680809,
+ 0.026772135868668556,
+ 0.011043311096727848,
+ 0.0037202914245426655,
+ 0.008577261120080948
+ ],
+ "nudging": {
+ "0.001": [
+ -7.0138368755579e-05,
+ -1.0319054126739502e-06,
+ -6.495974957942963e-08,
+ -4.190951585769653e-09,
+ -3.3993273973464966e-08,
+ -2.2584572434425354e-08,
+ -6.123445928096771e-08,
+ -4.1443854570388794e-08
+ ],
+ "0.003": [
+ -0.00021021789871156216,
+ -3.0745286494493484e-06,
+ -1.0547228157520294e-07,
+ -4.866160452365875e-08,
+ -8.591450750827789e-08,
+ -6.938353180885315e-08,
+ -1.1292286217212677e-07,
+ -9.825453162193298e-08
+ ],
+ "0.01": [
+ -0.0006983885541558266,
+ -1.027202233672142e-05,
+ -3.688037395477295e-07,
+ -1.073349267244339e-07,
+ -2.153683453798294e-07,
+ -2.0815059542655945e-07,
+ -2.6938505470752716e-07,
+ -3.08966264128685e-07
+ ]
+ }
+ },
+ "drift": {
+ "embed.weight": 4.244627827603776,
+ "embed.bias": 2.284214237417933,
+ "blocks.0.ln.weight": 0.7376585006713867,
+ "blocks.0.w1.weight": 10.052587254542228,
+ "blocks.0.w1.bias": 12.863219234970343,
+ "blocks.0.w2.weight": 32.7856072339516,
+ "blocks.1.ln.weight": 0.8459606170654297,
+ "blocks.1.w1.weight": 17.10440998330836,
+ "blocks.1.w1.bias": 22.587196662115986,
+ "blocks.1.w2.weight": 48.94107595877311,
+ "blocks.2.ln.weight": 0.5776991248130798,
+ "blocks.2.w1.weight": 12.17296546064332,
+ "blocks.2.w1.bias": 14.464613959802604,
+ "blocks.2.w2.weight": 33.58238884289355,
+ "blocks.3.ln.weight": 0.6916943788528442,
+ "blocks.3.w1.weight": 11.021146543598332,
+ "blocks.3.w1.bias": 11.779720628235973,
+ "blocks.3.w2.weight": 25.18170826322489,
+ "blocks.4.ln.weight": 0.5363020300865173,
+ "blocks.4.w1.weight": 8.488676390390957,
+ "blocks.4.w1.bias": 11.968348972417077,
+ "blocks.4.w2.weight": 23.562556821259157,
+ "blocks.5.ln.weight": 0.749293863773346,
+ "blocks.5.w1.weight": 13.199470836216618,
+ "blocks.5.w1.bias": 17.384581140704626,
+ "blocks.5.w2.weight": 36.38642209120496,
+ "blocks.6.ln.weight": 0.45835214853286743,
+ "blocks.6.w1.weight": 10.662863447362852,
+ "blocks.6.w1.bias": 15.855775302559838,
+ "blocks.6.w2.weight": 28.70293410646866,
+ "blocks.7.ln.weight": 0.4122738838195801,
+ "blocks.7.w1.weight": 7.193545064718527,
+ "blocks.7.w1.bias": 7.340520731384394,
+ "blocks.7.w2.weight": 22.53669113456487,
+ "out_ln.weight": 0.06885236501693726,
+ "out_head.weight": 1.473306835980063,
+ "out_head.bias": 1.5084479134035678
+ }
+ },
+ "credit_bridge": {
+ "log": {
+ "train_loss": [
+ 2.0466531958262126,
+ 2.2737758037567137,
+ 2.280587441889445,
+ 2.25820095837911,
+ 2.270971960576375
+ ],
+ "train_acc": [
+ 0.22976666666666667,
+ 0.13306666666666667,
+ 0.13921666666666666,
+ 0.15853333333333333,
+ 0.15521666666666667
+ ],
+ "test_acc": [
+ 0.1866,
+ 0.101,
+ 0.1455,
+ 0.1616,
+ 0.0958
+ ],
+ "value_loss": [
+ 1.105676996310552,
+ 0.06329173700014751,
+ 0.02014747195293506,
+ 0.01469622576336066,
+ 0.0050645585257560015
+ ]
+ },
+ "diagnostics": {
+ "bp_cosine": [
+ -0.04117349535226822,
+ 0.006370606832206249,
+ 0.03125208616256714,
+ -0.015287259593605995,
+ -0.04197325184941292,
+ -0.021368175745010376,
+ 0.009605868719518185,
+ 0.029422588646411896
+ ],
+ "perturbation_rho": [
+ -0.09283949434757233,
+ -0.030718784779310226,
+ 0.00206748116761446,
+ 0.0,
+ -0.0060626850463449955,
+ -0.015737878158688545,
+ 0.02677903138101101,
+ -0.040315717458724976
+ ],
+ "nudging": {
+ "0.001": [
+ 3.939494490623474e-06,
+ -1.4901161193847656e-08,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "0.003": [
+ 1.1788681149482727e-05,
+ -3.725290298461914e-09,
+ 1.862645149230957e-09,
+ 0.0,
+ 0.0,
+ -1.862645149230957e-09,
+ 1.862645149230957e-09,
+ 0.0
+ ],
+ "0.01": [
+ 3.923662006855011e-05,
+ 2.60770320892334e-08,
+ 7.450580596923828e-09,
+ -1.862645149230957e-09,
+ 1.862645149230957e-09,
+ 1.862645149230957e-09,
+ 1.862645149230957e-09,
+ 5.587935447692871e-09
+ ]
+ }
+ },
+ "drift": {
+ "embed.weight": 5.5176286267072285,
+ "embed.bias": 5.84115130180781,
+ "blocks.0.ln.weight": 0.7516779899597168,
+ "blocks.0.w1.weight": 13.130250396475844,
+ "blocks.0.w1.bias": 16.418874968953887,
+ "blocks.0.w2.weight": 35.524006089100524,
+ "blocks.1.ln.weight": 1.0247337818145752,
+ "blocks.1.w1.weight": 24.629494668532566,
+ "blocks.1.w1.bias": 31.389502706524333,
+ "blocks.1.w2.weight": 66.32116257377548,
+ "blocks.2.ln.weight": 1.0405563116073608,
+ "blocks.2.w1.weight": 19.820545045709153,
+ "blocks.2.w1.bias": 21.04420170821489,
+ "blocks.2.w2.weight": 47.599865577930004,
+ "blocks.3.ln.weight": 0.698647677898407,
+ "blocks.3.w1.weight": 11.782448138509505,
+ "blocks.3.w1.bias": 11.300932589985935,
+ "blocks.3.w2.weight": 29.52897326429494,
+ "blocks.4.ln.weight": 0.7269545197486877,
+ "blocks.4.w1.weight": 12.151248949786332,
+ "blocks.4.w1.bias": 11.545512427875654,
+ "blocks.4.w2.weight": 31.11028303079019,
+ "blocks.5.ln.weight": 0.7007301449775696,
+ "blocks.5.w1.weight": 10.093534441471926,
+ "blocks.5.w1.bias": 8.779689334059729,
+ "blocks.5.w2.weight": 24.2612493038314,
+ "blocks.6.ln.weight": 0.7457646727561951,
+ "blocks.6.w1.weight": 9.34015610077582,
+ "blocks.6.w1.bias": 7.819986138941612,
+ "blocks.6.w2.weight": 24.326245888357803,
+ "blocks.7.ln.weight": 0.7317199110984802,
+ "blocks.7.w1.weight": 10.73492434511088,
+ "blocks.7.w1.bias": 9.150645079981764,
+ "blocks.7.w2.weight": 27.11222424242763,
+ "out_ln.weight": 0.07713422179222107,
+ "out_head.weight": 1.4131541672079744,
+ "out_head.bias": 0.7976411656775528
+ }
+ }
+ },
+ "config": {
+ "dataset": "fashionmnist",
+ "d_hidden": 256,
+ "num_blocks": 8,
+ "batch_size": 128,
+ "epochs": 5,
+ "lr": 0.001,
+ "lr_fb": 0.001,
+ "wd": 0.01,
+ "lam": 0.1,
+ "K": 4,
+ "sigma_bridge": 0.05,
+ "ema_momentum": 0.995,
+ "term_grad_weight": 1.0,
+ "seeds": [
+ 42
+ ],
+ "gpu": 0,
+ "output_dir": "results/smoke_test",
+ "num_classes": 10
+ }
+} \ No newline at end of file
diff --git a/results/smoke_test2/results_fashionmnist.json b/results/smoke_test2/results_fashionmnist.json
new file mode 100644
index 0000000..e4b5d70
--- /dev/null
+++ b/results/smoke_test2/results_fashionmnist.json
@@ -0,0 +1,721 @@
+{
+ "42": {
+ "bp": {
+ "log": {
+ "train_loss": [
+ 0.7028828698158264,
+ 0.5380959739049276,
+ 0.4825267461776733,
+ 0.45401095023155214,
+ 0.4324853138923645,
+ 0.41239778510729475,
+ 0.39849929070472717,
+ 0.38395428166389467,
+ 0.3671353324095408,
+ 0.3565144721984863,
+ 0.3396712650140127,
+ 0.32847159377733864,
+ 0.3156646738688151,
+ 0.3063636976877848,
+ 0.29311317729949954,
+ 0.281319753352801,
+ 0.27345897892316184,
+ 0.26703561499913536,
+ 0.26238394471009574,
+ 0.2581080759366353
+ ],
+ "train_acc": [
+ 0.73755,
+ 0.7990666666666667,
+ 0.8184166666666667,
+ 0.8286,
+ 0.8373333333333334,
+ 0.8439833333333333,
+ 0.8508166666666667,
+ 0.85665,
+ 0.8617666666666667,
+ 0.86675,
+ 0.87145,
+ 0.8759333333333333,
+ 0.87975,
+ 0.8850666666666667,
+ 0.8894166666666666,
+ 0.8943333333333333,
+ 0.8967166666666667,
+ 0.8991833333333333,
+ 0.90125,
+ 0.9020333333333334
+ ],
+ "test_acc": [
+ 0.7939,
+ 0.8145,
+ 0.8348,
+ 0.8457,
+ 0.8515,
+ 0.8507,
+ 0.86,
+ 0.8616,
+ 0.8629,
+ 0.867,
+ 0.8749,
+ 0.8758,
+ 0.8749,
+ 0.8826,
+ 0.8871,
+ 0.8893,
+ 0.8887,
+ 0.8914,
+ 0.8918,
+ 0.8933
+ ]
+ },
+ "diagnostics": {
+ "bp_cosine": [
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 1.0
+ ],
+ "perturbation_rho": [
+ 0.999779462814331,
+ 0.9997812509536743,
+ 0.9997234344482422,
+ 0.9995989799499512,
+ 0.9995447993278503,
+ 0.999480128288269,
+ 0.999315083026886,
+ 0.9990702271461487
+ ],
+ "nudging": {
+ "0.001": [
+ -0.0015512802638113499,
+ -0.0014223953476175666,
+ -0.0012828018516302109,
+ -0.0011290921829640865,
+ -0.0010026510572060943,
+ -0.0008882835973054171,
+ -0.0007996053318493068,
+ -0.0007174870697781444
+ ],
+ "0.003": [
+ -0.00463186064735055,
+ -0.004248819313943386,
+ -0.003833592403680086,
+ -0.003375905565917492,
+ -0.0029987930320203304,
+ -0.0026576737873256207,
+ -0.002393170492723584,
+ -0.0021477085538208485
+ ],
+ "0.01": [
+ -0.015182608738541603,
+ -0.013949813321232796,
+ -0.012605881318449974,
+ -0.011121007613837719,
+ -0.009890388697385788,
+ -0.008775115013122559,
+ -0.007910609245300293,
+ -0.0071043153293430805
+ ]
+ }
+ },
+ "drift": {
+ "embed.weight": 2.2156592696361637,
+ "embed.bias": 1.1767430604557616,
+ "blocks.0.ln.weight": 0.10304339230060577,
+ "blocks.0.w1.weight": 1.482853423570453,
+ "blocks.0.w1.bias": 1.6160260440173806,
+ "blocks.0.w2.weight": 4.957564419601416,
+ "blocks.1.ln.weight": 0.10342221707105637,
+ "blocks.1.w1.weight": 1.4925957913309882,
+ "blocks.1.w1.bias": 1.387226988584735,
+ "blocks.1.w2.weight": 4.8677799360581435,
+ "blocks.2.ln.weight": 0.10164804011583328,
+ "blocks.2.w1.weight": 1.5187247212963315,
+ "blocks.2.w1.bias": 1.315003727574603,
+ "blocks.2.w2.weight": 4.892590160655272,
+ "blocks.3.ln.weight": 0.0980648621916771,
+ "blocks.3.w1.weight": 1.5309313086145322,
+ "blocks.3.w1.bias": 1.183681086835403,
+ "blocks.3.w2.weight": 4.853244692493399,
+ "blocks.4.ln.weight": 0.10139396041631699,
+ "blocks.4.w1.weight": 1.4821627671323643,
+ "blocks.4.w1.bias": 1.3175104220848128,
+ "blocks.4.w2.weight": 4.605731857817927,
+ "blocks.5.ln.weight": 0.11211488395929337,
+ "blocks.5.w1.weight": 1.4651151472668016,
+ "blocks.5.w1.bias": 1.2329307722821297,
+ "blocks.5.w2.weight": 4.4285985689227365,
+ "blocks.6.ln.weight": 0.10427453368902206,
+ "blocks.6.w1.weight": 1.4545050310842818,
+ "blocks.6.w1.bias": 1.2120760166211149,
+ "blocks.6.w2.weight": 4.21210005778474,
+ "blocks.7.ln.weight": 0.10458954423666,
+ "blocks.7.w1.weight": 1.4114832006479092,
+ "blocks.7.w1.bias": 1.3654059522290227,
+ "blocks.7.w2.weight": 4.118203814443419,
+ "out_ln.weight": 0.07383835315704346,
+ "out_head.weight": 0.9730150380688232,
+ "out_head.bias": 0.5616520351930565
+ }
+ },
+ "dfa": {
+ "log": {
+ "train_loss": [
+ 1.4100907169977823,
+ 1.436597176615397,
+ 1.4409521081288656,
+ 1.4391240961710612,
+ 1.431752833366394,
+ 1.4278661415100098,
+ 1.42600347849528,
+ 1.4258823367436726,
+ 1.4264747858683269,
+ 1.4247130299250286,
+ 1.4233300720850626,
+ 1.4235623852411905,
+ 1.423257282193502,
+ 1.4203537824630736,
+ 1.4199624996821085,
+ 1.4181565198898316,
+ 1.4212706031799316,
+ 1.4193874025344848,
+ 1.4183136430740357,
+ 1.4177714852015177
+ ],
+ "train_acc": [
+ 0.4805,
+ 0.4955833333333333,
+ 0.48483333333333334,
+ 0.48445,
+ 0.48668333333333336,
+ 0.48556666666666665,
+ 0.48233333333333334,
+ 0.48796666666666666,
+ 0.48468333333333335,
+ 0.48793333333333333,
+ 0.49155,
+ 0.49241666666666667,
+ 0.49341666666666667,
+ 0.4975,
+ 0.49946666666666667,
+ 0.5013666666666666,
+ 0.5008166666666667,
+ 0.5015833333333334,
+ 0.5064833333333333,
+ 0.5055166666666666
+ ],
+ "test_acc": [
+ 0.5175,
+ 0.5227,
+ 0.4511,
+ 0.5054,
+ 0.5238,
+ 0.4415,
+ 0.4969,
+ 0.5415,
+ 0.5148,
+ 0.463,
+ 0.5415,
+ 0.479,
+ 0.5198,
+ 0.5259,
+ 0.5167,
+ 0.5293,
+ 0.5023,
+ 0.513,
+ 0.5234,
+ 0.5205
+ ]
+ },
+ "diagnostics": {
+ "bp_cosine": [
+ 0.18807855248451233,
+ 0.003098251298069954,
+ -0.0009170115226879716,
+ -0.0030592959374189377,
+ 0.001578816445544362,
+ -0.002526880707591772,
+ -0.0002142013981938362,
+ 0.002638747449964285
+ ],
+ "perturbation_rho": [
+ -0.02331582084298134,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.008287552744150162
+ ],
+ "nudging": {
+ "0.001": [
+ -1.016305759549141e-06,
+ 0.0,
+ 0.0,
+ 9.313225746154785e-10,
+ 0.0,
+ -1.862645149230957e-09,
+ 0.0,
+ 0.0
+ ],
+ "0.003": [
+ -3.2687094062566757e-06,
+ 0.0,
+ 0.0,
+ -1.3969838619232178e-09,
+ -9.313225746154785e-10,
+ -2.3283064365386963e-09,
+ -6.984919309616089e-10,
+ 0.0
+ ],
+ "0.01": [
+ -1.1139316484332085e-05,
+ 9.313225746154785e-10,
+ 9.313225746154785e-10,
+ 6.05359673500061e-09,
+ -9.313225746154785e-10,
+ -1.3969838619232178e-09,
+ 1.1641532182693481e-09,
+ -9.313225746154785e-10
+ ]
+ }
+ },
+ "drift": {
+ "embed.weight": 118.42757824386928,
+ "embed.bias": 91.53070095956268,
+ "blocks.0.ln.weight": 6.256274700164795,
+ "blocks.0.w1.weight": 139.8976340759611,
+ "blocks.0.w1.bias": 140.08208587154786,
+ "blocks.0.w2.weight": 332.94308388138757,
+ "blocks.1.ln.weight": 5.70231819152832,
+ "blocks.1.w1.weight": 129.75151109455834,
+ "blocks.1.w1.bias": 121.17432552115143,
+ "blocks.1.w2.weight": 248.7841682293662,
+ "blocks.2.ln.weight": 5.466833114624023,
+ "blocks.2.w1.weight": 124.61655550824435,
+ "blocks.2.w1.bias": 116.83509018118278,
+ "blocks.2.w2.weight": 237.4413053793645,
+ "blocks.3.ln.weight": 5.473834037780762,
+ "blocks.3.w1.weight": 130.23769942327985,
+ "blocks.3.w1.bias": 111.35214850610642,
+ "blocks.3.w2.weight": 253.67927391219192,
+ "blocks.4.ln.weight": 5.401853084564209,
+ "blocks.4.w1.weight": 138.8132811262774,
+ "blocks.4.w1.bias": 138.7305587293139,
+ "blocks.4.w2.weight": 257.0171443869646,
+ "blocks.5.ln.weight": 5.765963554382324,
+ "blocks.5.w1.weight": 155.34176828036547,
+ "blocks.5.w1.bias": 142.51555937014615,
+ "blocks.5.w2.weight": 294.4127013869835,
+ "blocks.6.ln.weight": 5.595874309539795,
+ "blocks.6.w1.weight": 145.1627005952125,
+ "blocks.6.w1.bias": 132.11466551434717,
+ "blocks.6.w2.weight": 265.955180270204,
+ "blocks.7.ln.weight": 5.609050750732422,
+ "blocks.7.w1.weight": 147.06679411967872,
+ "blocks.7.w1.bias": 140.14577658787653,
+ "blocks.7.w2.weight": 258.1442232900185,
+ "out_ln.weight": 1.0355137586593628,
+ "out_head.weight": 8.13612662750539,
+ "out_head.bias": 0.9719656640023888
+ }
+ },
+ "state_bridge": {
+ "log": {
+ "train_loss": [
+ 1.7336509002049765,
+ 1.5758645205815633,
+ 1.9318606907526652,
+ 2.191984078470866,
+ 1.8327842120488484,
+ 1.8950384799957276,
+ 2.377087445449829,
+ 2.1417452257792156,
+ 1.9545116751352947,
+ 2.1638838397979736,
+ 2.023908238474528,
+ 2.336049980545044,
+ 1.9471053196589152,
+ 1.7918469515482585,
+ 1.7594309553146363,
+ 1.7820972992579143,
+ 1.7949362015406292,
+ 1.7877578330993653,
+ 1.7872745515823365,
+ 1.7890100787480672
+ ],
+ "train_acc": [
+ 0.32705,
+ 0.35613333333333336,
+ 0.25498333333333334,
+ 0.14148333333333332,
+ 0.23143333333333332,
+ 0.23958333333333334,
+ 0.07806666666666667,
+ 0.16218333333333335,
+ 0.20036666666666667,
+ 0.1637,
+ 0.23035,
+ 0.13533333333333333,
+ 0.21553333333333333,
+ 0.25556666666666666,
+ 0.29751666666666665,
+ 0.32916666666666666,
+ 0.3303333333333333,
+ 0.33866666666666667,
+ 0.33645,
+ 0.3411
+ ],
+ "test_acc": [
+ 0.4036,
+ 0.3937,
+ 0.2369,
+ 0.2137,
+ 0.201,
+ 0.0282,
+ 0.1012,
+ 0.1994,
+ 0.1525,
+ 0.1935,
+ 0.136,
+ 0.1863,
+ 0.2135,
+ 0.287,
+ 0.3171,
+ 0.3199,
+ 0.3193,
+ 0.309,
+ 0.3431,
+ 0.3284
+ ],
+ "state_pred_error": [
+ 5076234.573784879,
+ 451395873.88586664,
+ 9768031180.663467,
+ 22542289907.438934,
+ 14961824114.824533,
+ 30110774233.224533,
+ 11799262674.944,
+ 64511141264.315735,
+ 46851419274.717865,
+ 25131605056.98987,
+ 35883748757.77707,
+ 48375831022.7968,
+ 69333597113.5488,
+ 40594997328.827736,
+ 42734315291.71627,
+ 38351667650.01387,
+ 32476750280.567467,
+ 28322754527.232,
+ 25330875547.648,
+ 23083374417.237335
+ ]
+ },
+ "diagnostics": {
+ "bp_cosine": [
+ 0.1429949849843979,
+ 0.3180725574493408,
+ 0.20951706171035767,
+ 0.2099049985408783,
+ 0.2332497388124466,
+ 0.18252244591712952,
+ 0.16670912504196167,
+ 0.19962334632873535
+ ],
+ "perturbation_rho": [
+ 0.22952523827552795,
+ 0.06709301471710205,
+ 0.0,
+ 0.02129420079290867,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "nudging": {
+ "0.001": [
+ -2.1068379282951355e-05,
+ -8.009374141693115e-08,
+ 1.862645149230957e-09,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "0.003": [
+ -6.297603249549866e-05,
+ -1.4621764421463013e-07,
+ 2.7939677238464355e-09,
+ -7.450580596923828e-09,
+ 5.587935447692871e-09,
+ 0.0,
+ 9.313225746154785e-10,
+ 0.0
+ ],
+ "0.01": [
+ -0.00020940043032169342,
+ -4.3585896492004395e-07,
+ 0.0,
+ -5.587935447692871e-09,
+ 8.381903171539307e-09,
+ 0.0,
+ -4.6566128730773926e-09,
+ 0.0
+ ]
+ }
+ },
+ "drift": {
+ "embed.weight": 10.2616062216752,
+ "embed.bias": 9.637697583673647,
+ "blocks.0.ln.weight": 2.2154791355133057,
+ "blocks.0.w1.weight": 28.129394231034535,
+ "blocks.0.w1.bias": 29.44761280088024,
+ "blocks.0.w2.weight": 83.23565733006456,
+ "blocks.1.ln.weight": 1.6048834323883057,
+ "blocks.1.w1.weight": 72.41049473988551,
+ "blocks.1.w1.bias": 83.01353544678275,
+ "blocks.1.w2.weight": 118.5292924947854,
+ "blocks.2.ln.weight": 1.528304100036621,
+ "blocks.2.w1.weight": 22.974203066797124,
+ "blocks.2.w1.bias": 24.217216068101656,
+ "blocks.2.w2.weight": 45.236771862572276,
+ "blocks.3.ln.weight": 2.2928571701049805,
+ "blocks.3.w1.weight": 48.83060261437247,
+ "blocks.3.w1.bias": 44.264646679202606,
+ "blocks.3.w2.weight": 77.90843201854017,
+ "blocks.4.ln.weight": 1.569786548614502,
+ "blocks.4.w1.weight": 38.25042383023526,
+ "blocks.4.w1.bias": 46.7989781903236,
+ "blocks.4.w2.weight": 61.850484751580424,
+ "blocks.5.ln.weight": 1.3877884149551392,
+ "blocks.5.w1.weight": 25.667927717045547,
+ "blocks.5.w1.bias": 27.47666495212993,
+ "blocks.5.w2.weight": 56.74081708072433,
+ "blocks.6.ln.weight": 2.3657233715057373,
+ "blocks.6.w1.weight": 68.78499791753522,
+ "blocks.6.w1.bias": 66.7112278827284,
+ "blocks.6.w2.weight": 88.32659623386007,
+ "blocks.7.ln.weight": 1.5114411115646362,
+ "blocks.7.w1.weight": 40.81772220069051,
+ "blocks.7.w1.bias": 45.46224641191049,
+ "blocks.7.w2.weight": 78.43904162151014,
+ "out_ln.weight": 0.27722063660621643,
+ "out_head.weight": 3.4145406581484017,
+ "out_head.bias": 1.9963410657002738
+ }
+ },
+ "credit_bridge": {
+ "log": {
+ "train_loss": [
+ 1.4013389415105184,
+ 1.4349893891016643,
+ 1.442590291341146,
+ 1.4306699399312337,
+ 1.4177788179397584,
+ 1.4436163345336914,
+ 1.5327887171427408,
+ 1.5964955659866333,
+ 1.6370685063680013,
+ 1.6712419691085816,
+ 1.6618358172098795,
+ 1.6343142255147298,
+ 1.6132340278625488,
+ 1.5844669729232788,
+ 1.5616684883117675,
+ 1.541010483233134,
+ 1.525136720085144,
+ 1.5156665041605633,
+ 1.5106076243718465,
+ 1.5088000661214194
+ ],
+ "train_acc": [
+ 0.4805333333333333,
+ 0.48885,
+ 0.4822,
+ 0.48246666666666665,
+ 0.48988333333333334,
+ 0.4848,
+ 0.4645166666666667,
+ 0.44521666666666665,
+ 0.4347,
+ 0.4149833333333333,
+ 0.41485,
+ 0.42995,
+ 0.4308,
+ 0.43275,
+ 0.44,
+ 0.4475,
+ 0.44975,
+ 0.45098333333333335,
+ 0.4527333333333333,
+ 0.45531666666666665
+ ],
+ "test_acc": [
+ 0.4889,
+ 0.4758,
+ 0.4915,
+ 0.5234,
+ 0.5192,
+ 0.4658,
+ 0.4303,
+ 0.4477,
+ 0.4427,
+ 0.4192,
+ 0.4341,
+ 0.4369,
+ 0.4415,
+ 0.4713,
+ 0.4554,
+ 0.4785,
+ 0.4729,
+ 0.47,
+ 0.4671,
+ 0.4741
+ ],
+ "value_loss": [
+ 0.4515702169219653,
+ 0.10145944621364275,
+ 0.0791723535656929,
+ 0.06821208957632383,
+ 0.048990531079967814,
+ 0.04999728363355001,
+ 0.0479251907547315,
+ 0.04561474083264669,
+ 0.03763539131681124,
+ 0.040053354968627296,
+ 0.03465893845955531,
+ 0.029476907690366108,
+ 0.022895141302545864,
+ 0.020180800672868888,
+ 0.01736196913222472,
+ 0.012846514955163002,
+ 0.009126213994373878,
+ 0.008726185441513856,
+ 0.008221083876490592,
+ 0.007363871405273676
+ ]
+ },
+ "diagnostics": {
+ "bp_cosine": [
+ 0.15763823688030243,
+ 0.11252982914447784,
+ 0.23011641204357147,
+ 0.21486178040504456,
+ 0.20616328716278076,
+ 0.2209414690732956,
+ 0.22547751665115356,
+ 0.21971024572849274
+ ],
+ "perturbation_rho": [
+ 0.022450335323810577,
+ 0.03968421369791031,
+ -0.010952746495604515,
+ -0.02441849745810032,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "nudging": {
+ "0.001": [
+ -2.5480985641479492e-06,
+ -3.864988684654236e-08,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "0.003": [
+ -7.3623377829790115e-06,
+ -4.377216100692749e-08,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "0.01": [
+ -2.4445587769150734e-05,
+ -8.102506399154663e-08,
+ -1.862645149230957e-09,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ]
+ }
+ },
+ "drift": {
+ "embed.weight": 56.40174417770624,
+ "embed.bias": 40.65522803248636,
+ "blocks.0.ln.weight": 3.5452845096588135,
+ "blocks.0.w1.weight": 60.29274021947322,
+ "blocks.0.w1.bias": 39.86138087773039,
+ "blocks.0.w2.weight": 158.13629039802228,
+ "blocks.1.ln.weight": 4.1766037940979,
+ "blocks.1.w1.weight": 97.47012476867003,
+ "blocks.1.w1.bias": 94.45373800951319,
+ "blocks.1.w2.weight": 182.69479749878457,
+ "blocks.2.ln.weight": 3.4605705738067627,
+ "blocks.2.w1.weight": 73.73554694789455,
+ "blocks.2.w1.bias": 70.46643111415602,
+ "blocks.2.w2.weight": 154.24816844331818,
+ "blocks.3.ln.weight": 3.458436965942383,
+ "blocks.3.w1.weight": 81.09956449846064,
+ "blocks.3.w1.bias": 70.54884627062913,
+ "blocks.3.w2.weight": 164.27884471330486,
+ "blocks.4.ln.weight": 3.470608949661255,
+ "blocks.4.w1.weight": 84.05857705229029,
+ "blocks.4.w1.bias": 85.61940602638978,
+ "blocks.4.w2.weight": 160.78897537314361,
+ "blocks.5.ln.weight": 3.7341277599334717,
+ "blocks.5.w1.weight": 97.50908269356889,
+ "blocks.5.w1.bias": 91.11903815503673,
+ "blocks.5.w2.weight": 187.97349614569833,
+ "blocks.6.ln.weight": 3.5159947872161865,
+ "blocks.6.w1.weight": 91.22877884822417,
+ "blocks.6.w1.bias": 83.71561937882119,
+ "blocks.6.w2.weight": 167.99172531428107,
+ "blocks.7.ln.weight": 3.4693915843963623,
+ "blocks.7.w1.weight": 88.21253603229287,
+ "blocks.7.w1.bias": 85.04756223959555,
+ "blocks.7.w2.weight": 156.73218429415434,
+ "out_ln.weight": 0.8287110924720764,
+ "out_head.weight": 7.02544389908901,
+ "out_head.bias": 2.6789805309811348
+ }
+ }
+ },
+ "config": {
+ "dataset": "fashionmnist",
+ "d_hidden": 256,
+ "num_blocks": 8,
+ "batch_size": 128,
+ "epochs": 20,
+ "lr": 0.001,
+ "lr_fb": 0.001,
+ "wd": 0.01,
+ "lam": 0.1,
+ "K": 4,
+ "sigma_bridge": 0.05,
+ "ema_momentum": 0.995,
+ "term_grad_weight": 1.0,
+ "seeds": [
+ 42
+ ],
+ "gpu": 0,
+ "output_dir": "results/smoke_test2",
+ "num_classes": 10
+ }
+}
\ No newline at end of file
diff --git a/results/toy_lq/state_bridge_seed42.pt b/results/toy_lq/state_bridge_seed42.pt
new file mode 100644
index 0000000..a87e99d
--- /dev/null
+++ b/results/toy_lq/state_bridge_seed42.pt
Binary files differ
diff --git a/results/toy_lq/sweep_results.json b/results/toy_lq/sweep_results.json
new file mode 100644
index 0000000..30c3f49
--- /dev/null
+++ b/results/toy_lq/sweep_results.json
@@ -0,0 +1,1070 @@
+{
+ "base": {
+ "best_cos": 0.28987203165888786,
+ "best_step": 500,
+ "final_cos": -0.0006980087685709199,
+ "final_rho": 0.00831739305673788,
+ "final_nudge": 0.0027815375166634717,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.28987203165888786,
+ "avg_rho": 0.30755721777677536,
+ "avg_nudge": -0.10750458451608817,
+ "loss_term": 0.4910352826118469,
+ "loss_bridge": 0.21142145991325378
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.17731603110829988,
+ "avg_rho": 0.17737891773382822,
+ "avg_nudge": -0.06381315520654122,
+ "loss_term": 0.1414657086133957,
+ "loss_bridge": 0.16202926635742188
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.07300025007377069,
+ "avg_rho": 0.07345063022027414,
+ "avg_nudge": -0.02641519942941765,
+ "loss_term": 0.07946588099002838,
+ "loss_bridge": 0.08628662675619125
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.05074742032835881,
+ "avg_rho": 0.048574333622430764,
+ "avg_nudge": -0.016719758200148743,
+ "loss_term": 0.06007867306470871,
+ "loss_bridge": 0.016275471076369286
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.045629044994711876,
+ "avg_rho": 0.0567939051737388,
+ "avg_nudge": -0.016934571011612814,
+ "loss_term": 0.04579576849937439,
+ "loss_bridge": 0.008496936410665512
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.06145744491368532,
+ "avg_rho": 0.05695042138298353,
+ "avg_nudge": -0.020737762407710154,
+ "loss_term": 0.04881729558110237,
+ "loss_bridge": 0.013217507861554623
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.04631757119204849,
+ "avg_rho": 0.0402289762472113,
+ "avg_nudge": -0.015406294725835323,
+ "loss_term": 0.04255056008696556,
+ "loss_bridge": 0.01808304153382778
+ },
+ {
+ "step": 4000,
+ "avg_cos": 0.07397006265819073,
+ "avg_rho": 0.07162392682706316,
+ "avg_nudge": -0.025235373992472887,
+ "loss_term": 0.05486675351858139,
+ "loss_bridge": 0.02182137593626976
+ },
+ {
+ "step": 4500,
+ "avg_cos": 0.07251314166933298,
+ "avg_rho": 0.07301904633641243,
+ "avg_nudge": -0.027378834628810484,
+ "loss_term": 0.027670202776789665,
+ "loss_bridge": 0.013123266398906708
+ },
+ {
+ "step": 5000,
+ "avg_cos": -0.0006980087685709199,
+ "avg_rho": 0.00831739305673788,
+ "avg_nudge": 0.0027815375166634717,
+ "loss_term": 0.021854262799024582,
+ "loss_bridge": 0.0122066093608737
+ }
+ ]
+ },
+ "noise_0.1": {
+ "best_cos": 0.28885648772120476,
+ "best_step": 500,
+ "final_cos": 0.02330129671220978,
+ "final_rho": 0.03084004654859503,
+ "final_nudge": -0.005932313855737448,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.28885648772120476,
+ "avg_rho": 0.306567445397377,
+ "avg_nudge": -0.1071426725635926,
+ "loss_term": 0.4925232529640198,
+ "loss_bridge": 0.20981840789318085
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.17456108083327612,
+ "avg_rho": 0.17607956007122993,
+ "avg_nudge": -0.06276426805804174,
+ "loss_term": 0.1324017345905304,
+ "loss_bridge": 0.1835721731185913
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.08755840392162402,
+ "avg_rho": 0.08556619860852759,
+ "avg_nudge": -0.03154423305143913,
+ "loss_term": 0.07669232785701752,
+ "loss_bridge": 0.10240059345960617
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.05209386876473824,
+ "avg_rho": 0.04985952338514229,
+ "avg_nudge": -0.01653688432027896,
+ "loss_term": 0.06218753010034561,
+ "loss_bridge": 0.02190624177455902
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.038306045811623335,
+ "avg_rho": 0.047700356847296156,
+ "avg_nudge": -0.01320391776971519,
+ "loss_term": 0.06128765642642975,
+ "loss_bridge": 0.01958899199962616
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.059713449950019516,
+ "avg_rho": 0.054914615117013454,
+ "avg_nudge": -0.01939693869402011,
+ "loss_term": 0.04238733649253845,
+ "loss_bridge": 0.006958359386771917
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.04283027160757532,
+ "avg_rho": 0.028526597811530035,
+ "avg_nudge": -0.01381517636279265,
+ "loss_term": 0.024001002311706543,
+ "loss_bridge": 0.006152431480586529
+ },
+ {
+ "step": 4000,
+ "avg_cos": 0.0656419579560558,
+ "avg_rho": 0.06320031352030735,
+ "avg_nudge": -0.02216823499960204,
+ "loss_term": 0.04236245155334473,
+ "loss_bridge": 0.01628262549638748
+ },
+ {
+ "step": 4500,
+ "avg_cos": 0.05787398883452018,
+ "avg_rho": 0.0610881638713181,
+ "avg_nudge": -0.02213552314788103,
+ "loss_term": 0.024835661053657532,
+ "loss_bridge": 0.008629711344838142
+ },
+ {
+ "step": 5000,
+ "avg_cos": 0.02330129671220978,
+ "avg_rho": 0.03084004654859503,
+ "avg_nudge": -0.005932313855737448,
+ "loss_term": 0.035377949476242065,
+ "loss_bridge": 0.020307490602135658
+ }
+ ]
+ },
+ "noise_0.3": {
+ "best_cos": 0.28320933257540065,
+ "best_step": 500,
+ "final_cos": -0.007130145425132166,
+ "final_rho": 0.003303757131410142,
+ "final_nudge": 0.004990910800794761,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.28320933257540065,
+ "avg_rho": 0.30105741570393246,
+ "avg_nudge": -0.10515742873152097,
+ "loss_term": 0.44116002321243286,
+ "loss_bridge": 0.210750013589859
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.1680816188454628,
+ "avg_rho": 0.1721916707853476,
+ "avg_nudge": -0.060552582144737244,
+ "loss_term": 0.14207889139652252,
+ "loss_bridge": 0.16882070899009705
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.08877109829336405,
+ "avg_rho": 0.08808179199695587,
+ "avg_nudge": -0.03221472934819758,
+ "loss_term": 0.12282441556453705,
+ "loss_bridge": 0.07336755841970444
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.04903770827998718,
+ "avg_rho": 0.04795513402981063,
+ "avg_nudge": -0.016069856472313404,
+ "loss_term": 0.06662434339523315,
+ "loss_bridge": 0.02520526573061943
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.05001032492145896,
+ "avg_rho": 0.053442057532568775,
+ "avg_nudge": -0.01783018947268526,
+ "loss_term": 0.04614880681037903,
+ "loss_bridge": 0.015394347719848156
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.03979540191357955,
+ "avg_rho": 0.03453826089389622,
+ "avg_nudge": -0.012916556559503078,
+ "loss_term": 0.04155049845576286,
+ "loss_bridge": 0.007860729470849037
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.06703585013747215,
+ "avg_rho": 0.04949762811884284,
+ "avg_nudge": -0.02297227829694748,
+ "loss_term": 0.059078700840473175,
+ "loss_bridge": 0.030750930309295654
+ },
+ {
+ "step": 4000,
+ "avg_cos": 0.06641693723698457,
+ "avg_rho": 0.06286459214364488,
+ "avg_nudge": -0.021883966866880655,
+ "loss_term": 0.025344880297780037,
+ "loss_bridge": 0.008583566173911095
+ },
+ {
+ "step": 4500,
+ "avg_cos": 0.03687915485352278,
+ "avg_rho": 0.028895257118468482,
+ "avg_nudge": -0.01225966370354096,
+ "loss_term": 0.020479349419474602,
+ "loss_bridge": 0.00925756897777319
+ },
+ {
+ "step": 5000,
+ "avg_cos": -0.007130145425132166,
+ "avg_rho": 0.003303757131410142,
+ "avg_nudge": 0.004990910800794761,
+ "loss_term": 0.03882830590009689,
+ "loss_bridge": 0.022515198215842247
+ }
+ ]
+ },
+ "lam_1.0": {
+ "best_cos": 0.2899630069732666,
+ "best_step": 500,
+ "final_cos": 0.0024576462844076255,
+ "final_rho": 0.010072723651925722,
+ "final_nudge": 0.0017782459035515785,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.2899630069732666,
+ "avg_rho": 0.3076498980323474,
+ "avg_nudge": -0.10753695170084636,
+ "loss_term": 0.49082812666893005,
+ "loss_bridge": 0.21152979135513306
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.17747302974263826,
+ "avg_rho": 0.17881879458824793,
+ "avg_nudge": -0.06420790683478117,
+ "loss_term": 0.1407778263092041,
+ "loss_bridge": 0.15578678250312805
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.08065215001503627,
+ "avg_rho": 0.08029377926141024,
+ "avg_nudge": -0.029326035796354216,
+ "loss_term": 0.07120706140995026,
+ "loss_bridge": 0.09130121767520905
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.0563324602941672,
+ "avg_rho": 0.058026942308060825,
+ "avg_nudge": -0.01875296530003349,
+ "loss_term": 0.06783974170684814,
+ "loss_bridge": 0.01712188683450222
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.039659525423000254,
+ "avg_rho": 0.04613731553157171,
+ "avg_nudge": -0.013568413676694036,
+ "loss_term": 0.0634445995092392,
+ "loss_bridge": 0.02091756835579872
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.04386034506993989,
+ "avg_rho": 0.039356538482631244,
+ "avg_nudge": -0.01407914562150836,
+ "loss_term": 0.05031818151473999,
+ "loss_bridge": 0.01259439717978239
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.0481002099889641,
+ "avg_rho": 0.03689269566287597,
+ "avg_nudge": -0.01590493693947792,
+ "loss_term": 0.03421805799007416,
+ "loss_bridge": 0.014973677694797516
+ },
+ {
+ "step": 4000,
+ "avg_cos": 0.07999587121109168,
+ "avg_rho": 0.07683291751891375,
+ "avg_nudge": -0.02731190746029218,
+ "loss_term": 0.04726963862776756,
+ "loss_bridge": 0.017824556678533554
+ },
+ {
+ "step": 4500,
+ "avg_cos": 0.07246351769814889,
+ "avg_rho": 0.07343382605661948,
+ "avg_nudge": -0.027415843835721414,
+ "loss_term": 0.02768528461456299,
+ "loss_bridge": 0.013396943919360638
+ },
+ {
+ "step": 5000,
+ "avg_cos": 0.0024576462844076255,
+ "avg_rho": 0.010072723651925722,
+ "avg_nudge": 0.0017782459035515785,
+ "loss_term": 0.02114972099661827,
+ "loss_bridge": 0.011466547846794128
+ }
+ ]
+ },
+ "noise_lam": {
+ "best_cos": 0.28980905935168266,
+ "best_step": 500,
+ "final_cos": 0.000333610107190907,
+ "final_rho": 0.009229002442831794,
+ "final_nudge": 0.0020260093733668327,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.28980905935168266,
+ "avg_rho": 0.30755361666282016,
+ "avg_nudge": -0.10748158146937688,
+ "loss_term": 0.4891643524169922,
+ "loss_bridge": 0.21093463897705078
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.1692807301878929,
+ "avg_rho": 0.17357065031925836,
+ "avg_nudge": -0.060928904761870704,
+ "loss_term": 0.12851648032665253,
+ "loss_bridge": 0.1991802453994751
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.08976124723752339,
+ "avg_rho": 0.08664474201699097,
+ "avg_nudge": -0.032548267083863415,
+ "loss_term": 0.08289942145347595,
+ "loss_bridge": 0.08361449092626572
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.0449219069754084,
+ "avg_rho": 0.03701058775186539,
+ "avg_nudge": -0.013820620176071921,
+ "loss_term": 0.07352523505687714,
+ "loss_bridge": 0.031066572293639183
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.0341803894067804,
+ "avg_rho": 0.04086895197785149,
+ "avg_nudge": -0.011834259377792478,
+ "loss_term": 0.05221429467201233,
+ "loss_bridge": 0.013445280492305756
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.043432267770792045,
+ "avg_rho": 0.035848827101290226,
+ "avg_nudge": -0.0138394293996195,
+ "loss_term": 0.08377102017402649,
+ "loss_bridge": 0.02437596395611763
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.04308730812044814,
+ "avg_rho": 0.02904196917855491,
+ "avg_nudge": -0.013877601828426123,
+ "loss_term": 0.028318829834461212,
+ "loss_bridge": 0.0096125528216362
+ },
+ {
+ "step": 4000,
+ "avg_cos": 0.07217423028002183,
+ "avg_rho": 0.07531639956869185,
+ "avg_nudge": -0.024754961564516027,
+ "loss_term": 0.04526568949222565,
+ "loss_bridge": 0.01877186819911003
+ },
+ {
+ "step": 4500,
+ "avg_cos": 0.041415012208744884,
+ "avg_rho": 0.039630077119606234,
+ "avg_nudge": -0.013629865366965532,
+ "loss_term": 0.0305576603859663,
+ "loss_bridge": 0.011333338916301727
+ },
+ {
+ "step": 5000,
+ "avg_cos": 0.000333610107190907,
+ "avg_rho": 0.009229002442831794,
+ "avg_nudge": 0.0020260093733668327,
+ "loss_term": 0.020293015986680984,
+ "loss_bridge": 0.00792770553380251
+ }
+ ]
+ },
+ "no_ln": {
+ "best_cos": 0.2994285201032956,
+ "best_step": 500,
+ "final_cos": -0.027601251068214577,
+ "final_rho": -0.03011056105606258,
+ "final_nudge": 0.013365049380809069,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.2994285201032956,
+ "avg_rho": 0.3287110353509585,
+ "avg_nudge": -0.11226618165771167,
+ "loss_term": 0.5378487706184387,
+ "loss_bridge": 0.20209679007530212
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.23427631705999374,
+ "avg_rho": 0.24483238657315573,
+ "avg_nudge": -0.08577251620590687,
+ "loss_term": 0.14587438106536865,
+ "loss_bridge": 0.17536549270153046
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.10971186744670074,
+ "avg_rho": 0.1081712432205677,
+ "avg_nudge": -0.039541066934665046,
+ "loss_term": 0.09813931584358215,
+ "loss_bridge": 0.1144137978553772
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.17568999342620373,
+ "avg_rho": 0.18043102137744427,
+ "avg_nudge": -0.06037815675760309,
+ "loss_term": 0.12119113653898239,
+ "loss_bridge": 0.05170433223247528
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.1501085925847292,
+ "avg_rho": 0.1694059126699964,
+ "avg_nudge": -0.05651436994473139,
+ "loss_term": 0.07073387503623962,
+ "loss_bridge": 0.015867799520492554
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.10612630782028039,
+ "avg_rho": 0.09910948442605634,
+ "avg_nudge": -0.034841354160259165,
+ "loss_term": 0.05772688612341881,
+ "loss_bridge": 0.0220709890127182
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.12348712608218193,
+ "avg_rho": 0.12745930068194866,
+ "avg_nudge": -0.04433920063699285,
+ "loss_term": 0.04150111600756645,
+ "loss_bridge": 0.014408521354198456
+ },
+ {
+ "step": 4000,
+ "avg_cos": 0.012287488003494218,
+ "avg_rho": 0.038615713633286454,
+ "avg_nudge": -0.006448045140132308,
+ "loss_term": 0.04848453402519226,
+ "loss_bridge": 0.026737889274954796
+ },
+ {
+ "step": 4500,
+ "avg_cos": 0.0344570055603981,
+ "avg_rho": 0.043244189505154886,
+ "avg_nudge": -0.012965732564528784,
+ "loss_term": 0.04390523582696915,
+ "loss_bridge": 0.016968518495559692
+ },
+ {
+ "step": 5000,
+ "avg_cos": -0.027601251068214577,
+ "avg_rho": -0.03011056105606258,
+ "avg_nudge": 0.013365049380809069,
+ "loss_term": 0.05419892817735672,
+ "loss_bridge": 0.029718749225139618
+ }
+ ]
+ },
+ "big_vnet": {
+ "best_cos": 0.25947993124524754,
+ "best_step": 500,
+ "final_cos": 0.012725223108039549,
+ "final_rho": -0.0029713623225688934,
+ "final_nudge": -0.0007455882926781973,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.25947993124524754,
+ "avg_rho": 0.2872927797337373,
+ "avg_nudge": -0.09759780826667945,
+ "loss_term": 0.24058431386947632,
+ "loss_bridge": 0.20110812783241272
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.11903308952848117,
+ "avg_rho": 0.10757205138603847,
+ "avg_nudge": -0.04069022353117665,
+ "loss_term": 0.1535366326570511,
+ "loss_bridge": 0.10731971263885498
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.04738504672423005,
+ "avg_rho": 0.04187516961246729,
+ "avg_nudge": -0.01667174060518543,
+ "loss_term": 0.12195796519517899,
+ "loss_bridge": 0.03276967257261276
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.05584627948701382,
+ "avg_rho": 0.06464844088380535,
+ "avg_nudge": -0.019484267104417086,
+ "loss_term": 0.07039390504360199,
+ "loss_bridge": 0.023653611540794373
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.11334512817362945,
+ "avg_rho": 0.13342145457863808,
+ "avg_nudge": -0.04256325137491027,
+ "loss_term": 0.09321287274360657,
+ "loss_bridge": 0.03603611886501312
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.07876436489944656,
+ "avg_rho": 0.08206061793801685,
+ "avg_nudge": -0.02565166230003039,
+ "loss_term": 0.03433217480778694,
+ "loss_bridge": 0.014776414260268211
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.059695989514390625,
+ "avg_rho": 0.043808821588754654,
+ "avg_nudge": -0.01833860871071617,
+ "loss_term": 0.07867280393838882,
+ "loss_bridge": 0.04651641845703125
+ },
+ {
+ "step": 4000,
+ "avg_cos": 0.008286381295571724,
+ "avg_rho": 0.01973852072842419,
+ "avg_nudge": -0.0019318210737158854,
+ "loss_term": 0.03502834588289261,
+ "loss_bridge": 0.011813998222351074
+ },
+ {
+ "step": 4500,
+ "avg_cos": -0.006770744492920737,
+ "avg_rho": 2.3505534045398235e-05,
+ "avg_nudge": 0.002897862965861956,
+ "loss_term": 0.04147114232182503,
+ "loss_bridge": 0.026934277266263962
+ },
+ {
+ "step": 5000,
+ "avg_cos": 0.012725223108039549,
+ "avg_rho": -0.0029713623225688934,
+ "avg_nudge": -0.0007455882926781973,
+ "loss_term": 0.038749027997255325,
+ "loss_bridge": 0.01941092312335968
+ }
+ ]
+ },
+ "ema_0.999": {
+ "best_cos": 0.10180055970946948,
+ "best_step": 1000,
+ "final_cos": -0.01584776126158734,
+ "final_rho": -0.01703926082700491,
+ "final_nudge": 0.007290713644276063,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": -0.005628384804974,
+ "avg_rho": 0.010925033983464042,
+ "avg_nudge": 0.0015428058492640655,
+ "loss_term": 0.5920301675796509,
+ "loss_bridge": 1.4890536069869995
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.10180055970946948,
+ "avg_rho": 0.10515290250380833,
+ "avg_nudge": -0.037424925404290356,
+ "loss_term": 0.5715007185935974,
+ "loss_bridge": 0.4665977954864502
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.021814276037427287,
+ "avg_rho": 0.003451728650058309,
+ "avg_nudge": -0.005690907438596089,
+ "loss_term": 0.2616257071495056,
+ "loss_bridge": 0.2758212387561798
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.03669632730695108,
+ "avg_rho": 0.036930523036668696,
+ "avg_nudge": -0.01006293793519338,
+ "loss_term": 0.11164076626300812,
+ "loss_bridge": 0.1407940685749054
+ },
+ {
+ "step": 2500,
+ "avg_cos": -0.020325756592986483,
+ "avg_rho": -0.027848235641916592,
+ "avg_nudge": 0.011153894321372112,
+ "loss_term": 0.15471391379833221,
+ "loss_bridge": 0.06181420385837555
+ },
+ {
+ "step": 3000,
+ "avg_cos": -0.0060501456415901584,
+ "avg_rho": -0.013405103546877703,
+ "avg_nudge": 0.005228333951284488,
+ "loss_term": 0.07506504654884338,
+ "loss_bridge": 0.08207326382398605
+ },
+ {
+ "step": 3500,
+ "avg_cos": -0.02149865326161186,
+ "avg_rho": -0.023167532014970977,
+ "avg_nudge": 0.010358475303898254,
+ "loss_term": 0.048137497156858444,
+ "loss_bridge": 0.03276998922228813
+ },
+ {
+ "step": 4000,
+ "avg_cos": -0.007104064881180723,
+ "avg_rho": -0.00720199760204802,
+ "avg_nudge": 0.005406570465614398,
+ "loss_term": 0.03773114085197449,
+ "loss_bridge": 0.037960827350616455
+ },
+ {
+ "step": 4500,
+ "avg_cos": -0.0034141440119128674,
+ "avg_rho": -0.003708663280121982,
+ "avg_nudge": 0.002461720102777084,
+ "loss_term": 0.0416095145046711,
+ "loss_bridge": 0.03200625628232956
+ },
+ {
+ "step": 5000,
+ "avg_cos": -0.01584776126158734,
+ "avg_rho": -0.01703926082700491,
+ "avg_nudge": 0.007290713644276063,
+ "loss_term": 0.023291591554880142,
+ "loss_bridge": 0.026996765285730362
+ }
+ ]
+ },
+ "K16": {
+ "best_cos": 0.3187306026617686,
+ "best_step": 500,
+ "final_cos": 0.012039402422184745,
+ "final_rho": -0.002023946028202772,
+ "final_nudge": -0.0010173787983755271,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.3187306026617686,
+ "avg_rho": 0.3298306291302045,
+ "avg_nudge": -0.12337777391076088,
+ "loss_term": 0.3945310413837433,
+ "loss_bridge": 0.27078020572662354
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.15608959334592024,
+ "avg_rho": 0.14268888781468073,
+ "avg_nudge": -0.05713697926451763,
+ "loss_term": 0.13724187016487122,
+ "loss_bridge": 0.13912135362625122
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.07919560000300407,
+ "avg_rho": 0.08190769484887521,
+ "avg_nudge": -0.029933936273058254,
+ "loss_term": 0.08082282543182373,
+ "loss_bridge": 0.07666538655757904
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.06641554242620866,
+ "avg_rho": 0.055935436549286045,
+ "avg_nudge": -0.020354578581949074,
+ "loss_term": 0.07134468108415604,
+ "loss_bridge": 0.011993557214736938
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.0844428426741312,
+ "avg_rho": 0.0917752521733443,
+ "avg_nudge": -0.028687965202455718,
+ "loss_term": 0.0323462039232254,
+ "loss_bridge": 0.00667245127260685
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.0888429010907809,
+ "avg_rho": 0.06338833862294753,
+ "avg_nudge": -0.026075587142258883,
+ "loss_term": 0.04170331731438637,
+ "loss_bridge": 0.010068882256746292
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.016165781184099615,
+ "avg_rho": 0.014142975055923065,
+ "avg_nudge": -0.003436643397435546,
+ "loss_term": 0.03528498113155365,
+ "loss_bridge": 0.012309195473790169
+ },
+ {
+ "step": 4000,
+ "avg_cos": 0.06803990031282107,
+ "avg_rho": 0.055018783935035266,
+ "avg_nudge": -0.022906929564972717,
+ "loss_term": 0.027342472225427628,
+ "loss_bridge": 0.007336604408919811
+ },
+ {
+ "step": 4500,
+ "avg_cos": 0.024216643689821165,
+ "avg_rho": 0.04094697698019445,
+ "avg_nudge": -0.008364889615525803,
+ "loss_term": 0.027580715715885162,
+ "loss_bridge": 0.016118880361318588
+ },
+ {
+ "step": 5000,
+ "avg_cos": 0.012039402422184745,
+ "avg_rho": -0.002023946028202772,
+ "avg_nudge": -0.0010173787983755271,
+ "loss_term": 0.027855150401592255,
+ "loss_bridge": 0.010500052943825722
+ }
+ ]
+ },
+ "best_combo": {
+ "best_cos": 0.30479515840609867,
+ "best_step": 500,
+ "final_cos": -0.025737087552746136,
+ "final_rho": -0.01576789258979261,
+ "final_nudge": 0.011260819776604572,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.30479515840609867,
+ "avg_rho": 0.33246727536122006,
+ "avg_nudge": -0.11396304952601592,
+ "loss_term": 0.5129790306091309,
+ "loss_bridge": 0.21324321627616882
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.24110793943206468,
+ "avg_rho": 0.24976263443628946,
+ "avg_nudge": -0.08793549550076325,
+ "loss_term": 0.14881110191345215,
+ "loss_bridge": 0.1860560178756714
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.12106851922969024,
+ "avg_rho": 0.11747027436892192,
+ "avg_nudge": -0.042971268917123474,
+ "loss_term": 0.10358402132987976,
+ "loss_bridge": 0.11228226125240326
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.19137668733795485,
+ "avg_rho": 0.2001398652791977,
+ "avg_nudge": -0.06597264126564066,
+ "loss_term": 0.0836295336484909,
+ "loss_bridge": 0.03368356451392174
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.1382010855789607,
+ "avg_rho": 0.15255235826286176,
+ "avg_nudge": -0.05140705577408274,
+ "loss_term": 0.058304790407419205,
+ "loss_bridge": 0.014804087579250336
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.0815443117171526,
+ "avg_rho": 0.07958398557578523,
+ "avg_nudge": -0.026181443439175684,
+ "loss_term": 0.059965550899505615,
+ "loss_bridge": 0.016236742958426476
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.09519872162491083,
+ "avg_rho": 0.09258671143713097,
+ "avg_nudge": -0.03418003007148703,
+ "loss_term": 0.03956954926252365,
+ "loss_bridge": 0.014235305599868298
+ },
+ {
+ "step": 4000,
+ "avg_cos": -0.011209671385586262,
+ "avg_rho": -0.007080828654579818,
+ "avg_nudge": 0.004587161975602309,
+ "loss_term": 0.03683673217892647,
+ "loss_bridge": 0.018107034265995026
+ },
+ {
+ "step": 4500,
+ "avg_cos": 0.04608155476550261,
+ "avg_rho": 0.05846519426753124,
+ "avg_nudge": -0.017059470837314922,
+ "loss_term": 0.043314699083566666,
+ "loss_bridge": 0.023334285244345665
+ },
+ {
+ "step": 5000,
+ "avg_cos": -0.025737087552746136,
+ "avg_rho": -0.01576789258979261,
+ "avg_nudge": 0.011260819776604572,
+ "loss_term": 0.03279898688197136,
+ "loss_bridge": 0.015516506507992744
+ }
+ ]
+ },
+ "noise_1.0": {
+ "best_cos": 0.2831856335202853,
+ "best_step": 500,
+ "final_cos": 0.010971122809375325,
+ "final_rho": 0.014127362189659229,
+ "final_nudge": -0.002180874968568484,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.2831856335202853,
+ "avg_rho": 0.30480503539244336,
+ "avg_nudge": -0.1052918794254462,
+ "loss_term": 0.4858350455760956,
+ "loss_bridge": 0.20565161108970642
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.14922113033632436,
+ "avg_rho": 0.14803783098856607,
+ "avg_nudge": -0.052950371988117695,
+ "loss_term": 0.1246427372097969,
+ "loss_bridge": 0.1951574981212616
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.032916761857147016,
+ "avg_rho": 0.030339293957998354,
+ "avg_nudge": -0.011420852970331907,
+ "loss_term": 0.059441905468702316,
+ "loss_bridge": 0.09553220868110657
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.02735897192421059,
+ "avg_rho": 0.019464978327353794,
+ "avg_nudge": -0.008191200438886881,
+ "loss_term": 0.07519456744194031,
+ "loss_bridge": 0.017568862065672874
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.02141591941472143,
+ "avg_rho": 0.027642763530214626,
+ "avg_nudge": -0.007318320373694102,
+ "loss_term": 0.05027623474597931,
+ "loss_bridge": 0.014622311107814312
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.048597725903770574,
+ "avg_rho": 0.03585781451935569,
+ "avg_nudge": -0.015588213689625263,
+ "loss_term": 0.06457968056201935,
+ "loss_bridge": 0.01680811122059822
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.06355043267831206,
+ "avg_rho": 0.04236283013597131,
+ "avg_nudge": -0.021118728754421074,
+ "loss_term": 0.035970453172922134,
+ "loss_bridge": 0.015534179285168648
+ },
+ {
+ "step": 4000,
+ "avg_cos": 0.04567302499587337,
+ "avg_rho": 0.05423017560193936,
+ "avg_nudge": -0.01528523334612449,
+ "loss_term": 0.034011729061603546,
+ "loss_bridge": 0.009947280399501324
+ },
+ {
+ "step": 4500,
+ "avg_cos": 0.024254882785802085,
+ "avg_rho": 0.023739464891453583,
+ "avg_nudge": -0.009648310175786415,
+ "loss_term": 0.04320281371474266,
+ "loss_bridge": 0.01757601648569107
+ },
+ {
+ "step": 5000,
+ "avg_cos": 0.010971122809375325,
+ "avg_rho": 0.014127362189659229,
+ "avg_nudge": -0.002180874968568484,
+ "loss_term": 0.024934478104114532,
+ "loss_bridge": 0.011717695742845535
+ }
+ ]
+ },
+ "lr_3e-4": {
+ "best_cos": 0.6101815849542618,
+ "best_step": 500,
+ "final_cos": -0.008048781737064322,
+ "final_rho": -0.026018392760306597,
+ "final_nudge": 0.007571722225596507,
+ "history": [
+ {
+ "step": 500,
+ "avg_cos": 0.6101815849542618,
+ "avg_rho": 0.620405301451683,
+ "avg_nudge": -0.2231019102036953,
+ "loss_term": 3.107595205307007,
+ "loss_bridge": 2.7781143188476562
+ },
+ {
+ "step": 1000,
+ "avg_cos": 0.39656415830055874,
+ "avg_rho": 0.39799897621075314,
+ "avg_nudge": -0.14455867062012354,
+ "loss_term": 0.29117679595947266,
+ "loss_bridge": 0.13909414410591125
+ },
+ {
+ "step": 1500,
+ "avg_cos": 0.26989879210789997,
+ "avg_rho": 0.2640038679043452,
+ "avg_nudge": -0.09908402090271314,
+ "loss_term": 0.1437155306339264,
+ "loss_bridge": 0.06845193356275558
+ },
+ {
+ "step": 2000,
+ "avg_cos": 0.15282577524582544,
+ "avg_rho": 0.1327554533878962,
+ "avg_nudge": -0.05070468131452799,
+ "loss_term": 0.10841675102710724,
+ "loss_bridge": 0.044324424117803574
+ },
+ {
+ "step": 2500,
+ "avg_cos": 0.054395756063361965,
+ "avg_rho": 0.04450509278103709,
+ "avg_nudge": -0.01604464929550886,
+ "loss_term": 0.09302366524934769,
+ "loss_bridge": 0.016296718269586563
+ },
+ {
+ "step": 3000,
+ "avg_cos": 0.041961303912103176,
+ "avg_rho": 0.032989607813457646,
+ "avg_nudge": -0.01202150775740544,
+ "loss_term": 0.06656567752361298,
+ "loss_bridge": 0.0089980224147439
+ },
+ {
+ "step": 3500,
+ "avg_cos": 0.010103868204168975,
+ "avg_rho": -0.01982816867530346,
+ "avg_nudge": 0.0005857348442077637,
+ "loss_term": 0.0479322224855423,
+ "loss_bridge": 0.004413220100104809
+ },
+ {
+ "step": 4000,
+ "avg_cos": 0.028387469239532948,
+ "avg_rho": 0.012580555553237597,
+ "avg_nudge": -0.006568876715997855,
+ "loss_term": 0.06536682695150375,
+ "loss_bridge": 0.005300351418554783
+ },
+ {
+ "step": 4500,
+ "avg_cos": 0.015142000513151288,
+ "avg_rho": 0.019022303090120356,
+ "avg_nudge": -0.004821705476691325,
+ "loss_term": 0.037237197160720825,
+ "loss_bridge": 0.005878066644072533
+ },
+ {
+ "step": 5000,
+ "avg_cos": -0.008048781737064322,
+ "avg_rho": -0.026018392760306597,
+ "avg_nudge": 0.007571722225596507,
+ "loss_term": 0.02460472844541073,
+ "loss_bridge": 0.005403056740760803
+ }
+ ]
+ }
+}
\ No newline at end of file
diff --git a/results/toy_lq/toy_lq_seed42.json b/results/toy_lq/toy_lq_seed42.json
new file mode 100644
index 0000000..0a821be
--- /dev/null
+++ b/results/toy_lq/toy_lq_seed42.json
@@ -0,0 +1,335 @@
+{
+ "config": {
+ "d_hidden": 64,
+ "output_dim": 10,
+ "num_layers": 12,
+ "sigma": 0.03,
+ "batch_size": 256,
+ "num_steps": 5000,
+ "lr_fb": 0.001,
+ "lam": 0.1,
+ "K": 8,
+ "ema_momentum": 0.995,
+ "sigma_bridge": 0.03,
+ "eval_every": 500,
+ "seed": 42,
+ "gpu": 0,
+ "output_dir": "results/toy_lq"
+ },
+ "log": {
+ "steps": [
+ 1,
+ 500,
+ 1000,
+ 1500,
+ 2000,
+ 2500,
+ 3000,
+ 3500,
+ 4000,
+ 4500,
+ 5000
+ ],
+ "state_bridge_loss": [
+ 66.01814270019531,
+ 2.058396816253662,
+ 2.191567897796631,
+ 2.1112077236175537,
+ 2.2748100757598877,
+ 2.1154427528381348,
+ 2.0553503036499023,
+ 1.966332197189331,
+ 2.1147494316101074,
+ 2.007577896118164,
+ 2.0389387607574463
+ ],
+ "credit_bridge_loss": [
+ 112.22884368896484,
+ 0.5559965372085571,
+ 0.36617255210876465,
+ 0.22275042533874512,
+ 0.13448813557624817,
+ 0.0859837457537651,
+ 0.06448937207460403,
+ 0.03930438682436943,
+ 0.033317387104034424,
+ 0.03704077750444412,
+ 0.02309737727046013
+ ],
+ "dfa_costate_cos": [
+ 0.0011988391992277824,
+ 0.007705975646296373,
+ -0.0006242827870524847,
+ 0.0037568132393062115,
+ 0.004209253507164808,
+ -0.0012397096635630499,
+ 0.0033803660816584644,
+ 0.003754911944270134,
+ 0.002183924035097544,
+ 0.0057398807973815845,
+ 0.0028941635615550554
+ ],
+ "state_costate_cos": [
+ 0.009303608121207127,
+ 0.9451924241506137,
+ 0.9427064611361577,
+ 0.9456643370481638,
+ 0.947175892499777,
+ 0.9486432488148029,
+ 0.9405259856810937,
+ 0.9429925313362708,
+ 0.9463646090947665,
+ 0.9434385620630704,
+ 0.9468063895518963
+ ],
+ "credit_costate_cos": [
+ 0.03412332414434506,
+ 0.3051199523302225,
+ 0.2641584941974053,
+ 0.17990898627501267,
+ 0.11904546274588658,
+ 0.03094297769264533,
+ 0.022915477577883463,
+ 0.008693795901938127,
+ 0.0027211602920523058,
+ 0.020937439484091904,
+ 0.033342053540624104
+ ],
+ "dfa_rho": [
+ 0.005011526596111556,
+ 0.0020135376447190842,
+ -0.011304221504057447,
+ 0.003935044708972176,
+ 0.0159942601264144,
+ -0.011545649264007807,
+ 0.01096861291443929,
+ 0.0007782066240906715,
+ -0.015019190264865756,
+ 0.007689292387415965,
+ 0.006176682732378443
+ ],
+ "state_rho": [
+ 0.011923154350370169,
+ 0.9337877084811529,
+ 0.9292550335327784,
+ 0.9341330577929815,
+ 0.9364036619663239,
+ 0.9322425921758016,
+ 0.9240961174170176,
+ 0.9329939633607864,
+ 0.9331585764884949,
+ 0.9324665367603302,
+ 0.9281178514162699
+ ],
+ "credit_rho": [
+ 0.031667908265565835,
+ 0.31275976697603863,
+ 0.2433429310719172,
+ 0.1816188059747219,
+ 0.11667139704028766,
+ 0.013198353117331862,
+ 0.022044080891646445,
+ -0.007547003333456814,
+ -0.011566813724736372,
+ -0.003948230994865298,
+ 0.016950203105807304
+ ],
+ "dfa_nudge": [
+ -0.0003799900102118651,
+ -0.0025626374408602715,
+ 0.0017628272374471028,
+ -0.001205168974896272,
+ -0.0011821148606638114,
+ 0.0014717701512078445,
+ -8.787959814071655e-05,
+ -0.0006076549955954155,
+ 0.0005303900688886642,
+ -0.0014991160326947768,
+ -0.0005284918782611688
+ ],
+ "state_nudge": [
+ -0.002327537008871635,
+ -0.34574924657742184,
+ -0.3358767156799634,
+ -0.33698513607184094,
+ -0.3561149264375369,
+ -0.33514803399642307,
+ -0.3557068184018135,
+ -0.3218521823485692,
+ -0.33325668424367905,
+ -0.34358637283245724,
+ -0.34090926001469296
+ ],
+ "credit_nudge": [
+ -0.014598140881086389,
+ -0.11370646270612876,
+ -0.09322128010292847,
+ -0.06459770910441875,
+ -0.04436610918492079,
+ -0.008244700108965239,
+ -0.00636714743450284,
+ 0.0006991980286935965,
+ 0.0036097665627797446,
+ -0.0035287897723416486,
+ -0.007594603579491377
+ ],
+ "bridge_residual": [
+ 0.06566914729773998,
+ 0.37236853316426277,
+ 0.3114783614873886,
+ 0.26258066420753795,
+ 0.20779911428689957,
+ 0.13781529137243828,
+ 0.10999641008675098,
+ 0.07969006771842639,
+ 0.06178054213523865,
+ 0.07378745886186759,
+ 0.0832565538585186
+ ]
+ },
+ "final_per_layer": {
+ "dfa_costate_cos": [
+ -0.039167389273643494,
+ -0.0378018394112587,
+ 0.005690325051546097,
+ -0.023073989897966385,
+ -0.0005057593807578087,
+ -0.014485953375697136,
+ 0.03301015496253967,
+ 0.04401148855686188,
+ 0.054177843034267426,
+ 0.03981431573629379,
+ -0.04246171563863754,
+ -0.0012151142582297325,
+ 0.01963176019489765
+ ],
+ "state_costate_cos": [
+ 0.9444395303726196,
+ 0.9453534483909607,
+ 0.9460644721984863,
+ 0.9466040134429932,
+ 0.9469730257987976,
+ 0.9472954273223877,
+ 0.9476633667945862,
+ 0.9478192925453186,
+ 0.947782576084137,
+ 0.9473406672477722,
+ 0.947346568107605,
+ 0.9471475481987,
+ 0.9466531276702881
+ ],
+ "credit_costate_cos": [
+ 0.04752141237258911,
+ 0.04377397149801254,
+ 0.04051002860069275,
+ 0.03716123104095459,
+ 0.0350164994597435,
+ 0.03194836527109146,
+ 0.02999947965145111,
+ 0.02913709171116352,
+ 0.027683185413479805,
+ 0.0277146864682436,
+ 0.027703404426574707,
+ 0.027434173971414566,
+ 0.027843166142702103
+ ],
+ "dfa_rho": [
+ -0.04803554713726044,
+ 0.001050771214067936,
+ 0.008967258036136627,
+ -0.0271889790892601,
+ 0.02336559258401394,
+ -0.018210411071777344,
+ 0.05891512706875801,
+ 0.040720634162425995,
+ 0.07478035986423492,
+ 0.04802168905735016,
+ -0.06280035525560379,
+ -0.0254659466445446
+ ],
+ "state_rho": [
+ 0.9311305284500122,
+ 0.9222633838653564,
+ 0.9287852644920349,
+ 0.9287664890289307,
+ 0.9245603084564209,
+ 0.928197979927063,
+ 0.9275168180465698,
+ 0.9290561676025391,
+ 0.9267844557762146,
+ 0.927483081817627,
+ 0.9315245151519775,
+ 0.9313452243804932
+ ],
+ "credit_rho": [
+ 0.05740518122911453,
+ 0.035541512072086334,
+ 0.002091987058520317,
+ 0.024556485936045647,
+ -0.006993812508881092,
+ 0.03284040838479996,
+ 0.012268777936697006,
+ -0.004999782890081406,
+ 0.014774687588214874,
+ -0.010628825053572655,
+ 0.05940534919500351,
+ -0.012859531678259373
+ ],
+ "dfa_nudge": [
+ 0.01363457553088665,
+ 0.013758538290858269,
+ -0.0032786596566438675,
+ 0.010209780186414719,
+ -0.0013850200921297073,
+ 0.004463233053684235,
+ -0.012735363095998764,
+ -0.01801125332713127,
+ -0.019914839416742325,
+ -0.012350432574748993,
+ 0.017156170681118965,
+ 0.002111367881298065
+ ],
+ "state_nudge": [
+ -0.34207087755203247,
+ -0.3417494297027588,
+ -0.3413676619529724,
+ -0.3414176106452942,
+ -0.3412688076496124,
+ -0.34025871753692627,
+ -0.340381920337677,
+ -0.34054237604141235,
+ -0.3402785062789917,
+ -0.34061968326568604,
+ -0.34013575315475464,
+ -0.340819776058197
+ ],
+ "credit_nudge": [
+ -0.012597911059856415,
+ -0.011231275275349617,
+ -0.009987404569983482,
+ -0.008699672296643257,
+ -0.007977090775966644,
+ -0.006919408217072487,
+ -0.006156830117106438,
+ -0.005891043692827225,
+ -0.00546320341527462,
+ -0.005353953689336777,
+ -0.005423387512564659,
+ -0.005434062331914902
+ ],
+ "bridge_residual": [
+ 0.10028190910816193,
+ 0.09995798766613007,
+ 0.09832010418176651,
+ 0.09528136253356934,
+ 0.09090456366539001,
+ 0.0843343734741211,
+ 0.07823348790407181,
+ 0.0723348930478096,
+ 0.06821676343679428,
+ 0.06773027032613754,
+ 0.06943590939044952,
+ 0.0740470215678215
+ ]
+ }
+}
\ No newline at end of file
diff --git a/results/toy_lq/toy_lq_v2_seed123_lam0.1_sig0.1_tgw1.0_fm0.0.json b/results/toy_lq/toy_lq_v2_seed123_lam0.1_sig0.1_tgw1.0_fm0.0.json
new file mode 100644
index 0000000..56b3336
--- /dev/null
+++ b/results/toy_lq/toy_lq_v2_seed123_lam0.1_sig0.1_tgw1.0_fm0.0.json
@@ -0,0 +1,330 @@
+{
+ "config": {
+ "d_hidden": 64,
+ "output_dim": 10,
+ "num_layers": 12,
+ "sigma": 0.03,
+ "batch_size": 256,
+ "num_steps": 8000,
+ "lr_fb": 0.001,
+ "lam": 0.1,
+ "K": 8,
+ "ema_momentum": 0.995,
+ "sigma_bridge": 0.1,
+ "eval_every": 1000,
+ "seed": 123,
+ "gpu": 0,
+ "output_dir": "results/toy_lq",
+ "vnet_hidden": 256,
+ "vnet_layers": 3,
+ "term_grad_weight": 1.0,
+ "fm_weight": 0.0
+ },
+ "log": {
+ "steps": [
+ 1,
+ 1000,
+ 2000,
+ 3000,
+ 4000,
+ 5000,
+ 6000,
+ 7000,
+ 8000
+ ],
+ "dfa_costate_cos": [
+ 0.0061469420325011015,
+ 0.007492478704079986,
+ 0.006436596314112346,
+ 0.001648913836106658,
+ 0.005657717352733016,
+ 0.010142655577510595,
+ 0.005493079700196783,
+ 0.008208209726338586,
+ 0.002802101274331411
+ ],
+ "state_costate_cos": [
+ 0.04875442975511154,
+ 0.9345830877621969,
+ 0.9331425180037817,
+ 0.9344809154669443,
+ 0.9360056469837824,
+ 0.9349906196196874,
+ 0.9401087015867233,
+ 0.9377894500891367,
+ 0.9381228536367416
+ ],
+ "credit_costate_cos": [
+ 0.005350367398932576,
+ 0.8715064575274786,
+ 0.9082922885815302,
+ 0.9268463253974915,
+ 0.9348577807346979,
+ 0.9338823060194651,
+ 0.9398750513792038,
+ 0.9371217240889868,
+ 0.9403592944145203
+ ],
+ "dfa_rho": [
+ 0.014851124413932363,
+ 0.004832483362406492,
+ 0.005500619454930226,
+ 0.0014784028753638268,
+ 0.0024716570042073727,
+ 0.002679668522129456,
+ 0.004042171291075647,
+ 0.007841781480237842,
+ -0.003731151965136329
+ ],
+ "state_rho": [
+ 0.0525277191773057,
+ 0.9209283490975698,
+ 0.9212760378917059,
+ 0.9215241422255834,
+ 0.9252830098072687,
+ 0.9173514097929001,
+ 0.9268933484951655,
+ 0.9245105236768723,
+ 0.9267788628737131
+ ],
+ "credit_rho": [
+ 8.900166722014546e-05,
+ 0.8210234741369883,
+ 0.8804274102052053,
+ 0.9100336680809656,
+ 0.9228341629107794,
+ 0.9162226617336273,
+ 0.9262114216883978,
+ 0.9239083131154379,
+ 0.9273379941781362
+ ],
+ "dfa_nudge": [
+ -0.0020856610499322414,
+ -0.002391135785728693,
+ -0.0018826122395694256,
+ 0.0002794961134592692,
+ -0.0017084906188150246,
+ -0.0027558018919080496,
+ -0.0014085437481602032,
+ -0.0023310642379025617,
+ -0.0009137461893260479
+ ],
+ "state_nudge": [
+ -0.01762500188002984,
+ -0.31959830472866696,
+ -0.3143775438268979,
+ -0.3134472444653511,
+ -0.33237058420976,
+ -0.3183835695187251,
+ -0.32320864746967953,
+ -0.32294898976882297,
+ -0.31496328860521317
+ ],
+ "credit_nudge": [
+ -0.0004618208234508832,
+ -0.2996478999654452,
+ -0.30627985050280887,
+ -0.3099561383326848,
+ -0.33090290675560635,
+ -0.31673717498779297,
+ -0.3218600004911423,
+ -0.32163529098033905,
+ -0.31442063798507053
+ ],
+ "bridge_residual": [],
+ "state_bridge_loss": [
+ 64.66886901855469,
+ 1.9055249691009521,
+ 1.9802964925765991,
+ 1.9747117757797241,
+ 1.9333608150482178,
+ 2.0244460105895996,
+ 2.235288381576538,
+ 1.8483705520629883,
+ 1.833404541015625
+ ],
+ "credit_bridge_loss": [
+ 129.2601776123047,
+ 8.826760292053223,
+ 8.817267417907715,
+ 9.7413911819458,
+ 8.333325386047363,
+ 8.390682220458984,
+ 8.381990432739258,
+ 9.069635391235352,
+ 8.738546371459961
+ ],
+ "term_loss": [
+ 109.68403625488281,
+ 3.380112648010254,
+ 4.175432205200195,
+ 4.737654685974121,
+ 3.632157802581787,
+ 3.3351938724517822,
+ 3.776655912399292,
+ 4.416824817657471,
+ 4.120006084442139
+ ],
+ "bridge_loss": [
+ 5.943464884694549e-07,
+ 0.2433367669582367,
+ 0.18713834881782532,
+ 0.12417592853307724,
+ 0.13950209319591522,
+ 0.14192476868629456,
+ 0.15276893973350525,
+ 0.10663559287786484,
+ 0.12137635797262192
+ ],
+ "term_grad_loss": [
+ 19.57614517211914,
+ 5.203310966491699,
+ 4.4546966552734375,
+ 4.879560470581055,
+ 4.5616655349731445,
+ 4.9135637283325195,
+ 4.452565670013428,
+ 4.546175003051758,
+ 4.497163772583008
+ ],
+ "fm_loss": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ]
+ },
+ "final_per_layer": {
+ "dfa_costate_cos": [
+ 0.04479089006781578,
+ -0.04426354169845581,
+ 0.008874599821865559,
+ 0.05599997937679291,
+ 0.02961653470993042,
+ -0.022058574482798576,
+ 0.027186769992113113,
+ -0.0337681919336319,
+ -0.020245034247636795,
+ -0.04076787084341049,
+ 0.006424968130886555,
+ 0.021834686398506165
+ ],
+ "state_costate_cos": [
+ 0.9351104497909546,
+ 0.9362065196037292,
+ 0.9370632171630859,
+ 0.9374789595603943,
+ 0.9381210803985596,
+ 0.9385455846786499,
+ 0.9386894702911377,
+ 0.9390624165534973,
+ 0.9393105506896973,
+ 0.9393242597579956,
+ 0.9392982721328735,
+ 0.9392634630203247
+ ],
+ "credit_costate_cos": [
+ 0.9360430240631104,
+ 0.9369803667068481,
+ 0.9380120038986206,
+ 0.9385954141616821,
+ 0.9395070672035217,
+ 0.9403176307678223,
+ 0.9409258961677551,
+ 0.9413450956344604,
+ 0.9420279264450073,
+ 0.9428501725196838,
+ 0.9435971975326538,
+ 0.9441097378730774
+ ],
+ "dfa_rho": [
+ 0.041250549256801605,
+ -0.049739208072423935,
+ 0.00025176629424095154,
+ 0.04237007349729538,
+ 0.040798719972372055,
+ -0.037202395498752594,
+ 0.004211327061057091,
+ 0.0017704367637634277,
+ -0.04704931750893593,
+ -0.02992381900548935,
+ -0.01541070081293583,
+ 0.003898744471371174
+ ],
+ "state_rho": [
+ 0.9230573177337646,
+ 0.9292065501213074,
+ 0.9277745485305786,
+ 0.926990270614624,
+ 0.930223822593689,
+ 0.9262816905975342,
+ 0.9238564968109131,
+ 0.9235492944717407,
+ 0.927101731300354,
+ 0.929185152053833,
+ 0.9233707189559937,
+ 0.9307487607002258
+ ],
+ "credit_rho": [
+ 0.924521803855896,
+ 0.9213208556175232,
+ 0.923781156539917,
+ 0.9222273826599121,
+ 0.9230118989944458,
+ 0.9317750930786133,
+ 0.9238618016242981,
+ 0.9351733326911926,
+ 0.9263075590133667,
+ 0.9302359819412231,
+ 0.9358923435211182,
+ 0.9299467206001282
+ ],
+ "dfa_nudge": [
+ -0.01619502529501915,
+ 0.016632290557026863,
+ -0.004560360684990883,
+ -0.018955951556563377,
+ -0.011058392003178596,
+ 0.007517071440815926,
+ -0.008148249238729477,
+ 0.01091383583843708,
+ 0.006811931729316711,
+ 0.015017258003354073,
+ -0.000978708267211914,
+ -0.00796065479516983
+ ],
+ "state_nudge": [
+ -0.31427979469299316,
+ -0.3147982358932495,
+ -0.3148774206638336,
+ -0.3150269389152527,
+ -0.31536394357681274,
+ -0.3156891465187073,
+ -0.3149397373199463,
+ -0.3145085573196411,
+ -0.3144562840461731,
+ -0.31494998931884766,
+ -0.3152415156364441,
+ -0.31542789936065674
+ ],
+ "credit_nudge": [
+ -0.31250905990600586,
+ -0.3131006360054016,
+ -0.31331712007522583,
+ -0.31366780400276184,
+ -0.3142547011375427,
+ -0.3149237632751465,
+ -0.31442567706108093,
+ -0.314262330532074,
+ -0.31445154547691345,
+ -0.3153681755065918,
+ -0.3161371052265167,
+ -0.3166297376155853
+ ]
+ }
+}
\ No newline at end of file
diff --git a/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.0.json b/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.0.json
new file mode 100644
index 0000000..14050eb
--- /dev/null
+++ b/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.0.json
@@ -0,0 +1,458 @@
+{
+ "config": {
+ "d_hidden": 64,
+ "output_dim": 10,
+ "num_layers": 12,
+ "sigma": 0.03,
+ "batch_size": 256,
+ "num_steps": 8000,
+ "lr_fb": 0.001,
+ "lam": 0.1,
+ "K": 8,
+ "ema_momentum": 0.995,
+ "sigma_bridge": 0.1,
+ "eval_every": 500,
+ "seed": 42,
+ "gpu": 0,
+ "output_dir": "results/toy_lq",
+ "vnet_hidden": 256,
+ "vnet_layers": 3,
+ "term_grad_weight": 1.0,
+ "fm_weight": 0.0
+ },
+ "log": {
+ "steps": [
+ 1,
+ 500,
+ 1000,
+ 1500,
+ 2000,
+ 2500,
+ 3000,
+ 3500,
+ 4000,
+ 4500,
+ 5000,
+ 5500,
+ 6000,
+ 6500,
+ 7000,
+ 7500,
+ 8000
+ ],
+ "dfa_costate_cos": [
+ 0.001022340264171362,
+ 0.00324871806272616,
+ 0.003955680275491129,
+ -6.335701133745412e-05,
+ 0.0009047117297692845,
+ -0.0015635235855976741,
+ 0.0050255006256823736,
+ 0.003974012487257521,
+ -0.0024126653637116155,
+ 0.0010036304011009634,
+ -0.0006338046708454689,
+ 0.00017608713824301958,
+ -9.714129070440929e-05,
+ -0.0011926805212472875,
+ -0.0026168684319903455,
+ -0.003671417556082209,
+ 0.0016770829679444432
+ ],
+ "state_costate_cos": [
+ 0.009337988216429949,
+ 0.943408285578092,
+ 0.9517824550469717,
+ 0.9448912839094797,
+ 0.9432462950547537,
+ 0.9444176405668259,
+ 0.9435405482848486,
+ 0.9443136354287466,
+ 0.9458288550376892,
+ 0.9465046127637228,
+ 0.9465998113155365,
+ 0.9461832990248998,
+ 0.9424780905246735,
+ 0.9474884221951166,
+ 0.9449383119742075,
+ 0.9425351719061533,
+ 0.9405365288257599
+ ],
+ "credit_costate_cos": [
+ 0.024892715892444055,
+ 0.8234576731920242,
+ 0.8861371924479803,
+ 0.8940252065658569,
+ 0.9119344304005305,
+ 0.928387979666392,
+ 0.9261951943238577,
+ 0.9374307443698248,
+ 0.939670259753863,
+ 0.9423713783423106,
+ 0.9447299987077713,
+ 0.944857731461525,
+ 0.9409955541292826,
+ 0.9459459533294042,
+ 0.9439971546332041,
+ 0.9420813073714575,
+ 0.9400773843129476
+ ],
+ "dfa_rho": [
+ 0.015879416760678094,
+ 0.004265802912414074,
+ 0.008714484671751658,
+ 0.001407407767449816,
+ 0.007156838779337704,
+ -0.0010706717148423195,
+ -0.00500367038572828,
+ 0.00037602245962868136,
+ 0.00797194141584138,
+ 0.012475600582547486,
+ 0.006475616673318048,
+ 0.006899521841357152,
+ 0.008204833992446462,
+ -0.0002780493038396041,
+ -0.009471656677002708,
+ -0.00781721225939691,
+ 0.025706250220537186
+ ],
+ "state_rho": [
+ 0.0029325426245729127,
+ 0.9265175064404806,
+ 0.938806007305781,
+ 0.9280826350053152,
+ 0.9287765026092529,
+ 0.9266663392384847,
+ 0.9301863809426626,
+ 0.9323704888423284,
+ 0.9297124246756235,
+ 0.9302101731300354,
+ 0.9340693553288778,
+ 0.9324037134647369,
+ 0.9265123655398687,
+ 0.9321505973736445,
+ 0.927787164847056,
+ 0.9323761165142059,
+ 0.9224280416965485
+ ],
+ "credit_rho": [
+ 0.02234963719577839,
+ 0.7719902147849401,
+ 0.8432890474796295,
+ 0.8447659462690353,
+ 0.877406562368075,
+ 0.8958166440327963,
+ 0.9122767845789591,
+ 0.9254603485266367,
+ 0.9236042300860087,
+ 0.9263056516647339,
+ 0.9323819329341253,
+ 0.9314924577871958,
+ 0.9266939163208008,
+ 0.929329847296079,
+ 0.9257104198137919,
+ 0.9289217789967855,
+ 0.9269137680530548
+ ],
+ "dfa_nudge": [
+ -0.0003799900102118651,
+ -0.0009698765352368355,
+ -0.0013957968913018703,
+ 0.00035563452790180844,
+ -0.0003614289841304223,
+ 0.0008372832089662552,
+ -0.0013908503266672294,
+ -0.0013163429684937,
+ 0.0005854369762043158,
+ -0.000560786963130037,
+ 0.00015371766251822314,
+ -1.0695696497956911e-05,
+ 2.807355485856533e-05,
+ 5.036763225992521e-05,
+ 0.0008217894161740938,
+ 0.0017250357971837123,
+ -0.0004334271264572938
+ ],
+ "state_nudge": [
+ -0.002327537008871635,
+ -0.34193428109089535,
+ -0.340603639682134,
+ -0.35042013972997665,
+ -0.3389856591820717,
+ -0.34554000198841095,
+ -0.3481511374314626,
+ -0.3558509051799774,
+ -0.3337947155038516,
+ -0.33715616663297016,
+ -0.33583804468313855,
+ -0.35106155276298523,
+ -0.34434546530246735,
+ -0.34641525397698086,
+ -0.3329972525437673,
+ -0.34870391835769016,
+ -0.3395152688026428
+ ],
+ "credit_nudge": [
+ -0.0079942528779308,
+ -0.30343081553777057,
+ -0.31936119496822357,
+ -0.33327333877484006,
+ -0.3284662067890167,
+ -0.3394550507267316,
+ -0.3424832498033841,
+ -0.3524467721581459,
+ -0.33071904132763547,
+ -0.3346964443723361,
+ -0.3337800477941831,
+ -0.34932391593853634,
+ -0.34230031818151474,
+ -0.3444634775320689,
+ -0.3310972551504771,
+ -0.34694332132736844,
+ -0.3377286195755005
+ ],
+ "bridge_residual": [],
+ "state_bridge_loss": [
+ 66.01814270019531,
+ 2.0012869834899902,
+ 2.1027088165283203,
+ 2.1019272804260254,
+ 2.0727572441101074,
+ 1.9770110845565796,
+ 2.1761908531188965,
+ 2.094480514526367,
+ 1.9725749492645264,
+ 2.0142102241516113,
+ 2.0340821743011475,
+ 1.9380583763122559,
+ 1.989743947982788,
+ 2.4057328701019287,
+ 2.19437575340271,
+ 2.1816155910491943,
+ 2.0803794860839844
+ ],
+ "credit_bridge_loss": [
+ 132.09298706054688,
+ 11.307573318481445,
+ 9.575517654418945,
+ 8.768391609191895,
+ 8.752981185913086,
+ 9.271968841552734,
+ 8.606213569641113,
+ 8.753968238830566,
+ 8.223488807678223,
+ 9.935165405273438,
+ 9.02967357635498,
+ 8.346063613891602,
+ 9.022041320800781,
+ 8.36446762084961,
+ 8.635725021362305,
+ 8.760185241699219,
+ 8.503408432006836
+ ],
+ "term_loss": [
+ 111.63633728027344,
+ 4.978545188903809,
+ 3.8962953090667725,
+ 3.81073260307312,
+ 4.386394500732422,
+ 4.507748603820801,
+ 3.6740365028381348,
+ 3.7450127601623535,
+ 3.5060954093933105,
+ 4.545898914337158,
+ 4.322302341461182,
+ 3.594371795654297,
+ 4.038668632507324,
+ 4.025404453277588,
+ 3.9016177654266357,
+ 3.7638583183288574,
+ 3.807260036468506
+ ],
+ "bridge_loss": [
+ 6.45359421014291e-07,
+ 0.432157039642334,
+ 0.20619139075279236,
+ 0.24715952575206757,
+ 0.1523856669664383,
+ 0.12624874711036682,
+ 0.13425695896148682,
+ 0.16367560625076294,
+ 0.14530189335346222,
+ 0.18452416360378265,
+ 0.11599670350551605,
+ 0.11983858048915863,
+ 0.11901542544364929,
+ 0.16489851474761963,
+ 0.16058529913425446,
+ 0.09414370357990265,
+ 0.12402483820915222
+ ],
+ "term_grad_loss": [
+ 20.456655502319336,
+ 5.8968706130981445,
+ 5.4730305671691895,
+ 4.710499286651611,
+ 4.214200496673584,
+ 4.637970924377441,
+ 4.797920227050781,
+ 4.845280170440674,
+ 4.572091579437256,
+ 5.204742908477783,
+ 4.591374397277832,
+ 4.631853103637695,
+ 4.864356994628906,
+ 4.174164772033691,
+ 4.573522090911865,
+ 4.9021830558776855,
+ 4.5721235275268555
+ ],
+ "fm_loss": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ]
+ },
+ "final_per_layer": {
+ "dfa_costate_cos": [
+ -0.03974368795752525,
+ -0.03234883397817612,
+ -0.01280729565769434,
+ -0.03282422572374344,
+ 0.005050511099398136,
+ -0.01393066719174385,
+ 0.036346256732940674,
+ 0.04377925395965576,
+ 0.06324614584445953,
+ 0.05324942618608475,
+ -0.04297419637441635,
+ -0.006917691323906183
+ ],
+ "state_costate_cos": [
+ 0.9383341073989868,
+ 0.9390445947647095,
+ 0.9393846988677979,
+ 0.9396730661392212,
+ 0.9402315616607666,
+ 0.940619707107544,
+ 0.9413514137268066,
+ 0.9420933723449707,
+ 0.941848874092102,
+ 0.9416857957839966,
+ 0.9411274194717407,
+ 0.9410437345504761
+ ],
+ "credit_costate_cos": [
+ 0.937440037727356,
+ 0.9379364252090454,
+ 0.9384041428565979,
+ 0.9385387897491455,
+ 0.9392049312591553,
+ 0.9398118257522583,
+ 0.9403334259986877,
+ 0.9412999153137207,
+ 0.94156414270401,
+ 0.9419326186180115,
+ 0.9419490098953247,
+ 0.9425133466720581
+ ],
+ "dfa_rho": [
+ -0.023802103474736214,
+ -0.039411380887031555,
+ -0.008970402181148529,
+ -0.0021191751584410667,
+ 0.04573667049407959,
+ 0.010564171709120274,
+ 0.04995737224817276,
+ 0.10094872862100601,
+ 0.07801118493080139,
+ 0.07188688963651657,
+ -0.02006501331925392,
+ 0.045738060027360916
+ ],
+ "state_rho": [
+ 0.9204449653625488,
+ 0.917449951171875,
+ 0.9193954467773438,
+ 0.9267550706863403,
+ 0.9280773401260376,
+ 0.9225738048553467,
+ 0.9190236926078796,
+ 0.924782931804657,
+ 0.9248216152191162,
+ 0.9217315912246704,
+ 0.9215092658996582,
+ 0.9225708246231079
+ ],
+ "credit_rho": [
+ 0.9222818613052368,
+ 0.923028826713562,
+ 0.924055814743042,
+ 0.9263180494308472,
+ 0.9288915991783142,
+ 0.922785758972168,
+ 0.9231191873550415,
+ 0.9254114627838135,
+ 0.9321990013122559,
+ 0.9311563372612,
+ 0.934211015701294,
+ 0.9295063018798828
+ ],
+ "dfa_nudge": [
+ 0.014356113970279694,
+ 0.013086749240756035,
+ 0.004036703146994114,
+ 0.011657552793622017,
+ -0.0023863790556788445,
+ 0.006091207265853882,
+ -0.013204541988670826,
+ -0.016102034598588943,
+ -0.022989045828580856,
+ -0.018843289464712143,
+ 0.015521062538027763,
+ 0.0035747764632105827
+ ],
+ "state_nudge": [
+ -0.34063079953193665,
+ -0.34022200107574463,
+ -0.33977580070495605,
+ -0.3401387929916382,
+ -0.34019535779953003,
+ -0.33964985609054565,
+ -0.3392874002456665,
+ -0.33898109197616577,
+ -0.3392573297023773,
+ -0.3390581011772156,
+ -0.33809930086135864,
+ -0.33888739347457886
+ ],
+ "credit_nudge": [
+ -0.3375471234321594,
+ -0.33736705780029297,
+ -0.33710652589797974,
+ -0.33756691217422485,
+ -0.3378610908985138,
+ -0.33758991956710815,
+ -0.33741605281829834,
+ -0.3374325633049011,
+ -0.33806926012039185,
+ -0.33822208642959595,
+ -0.33765909075737,
+ -0.3389057517051697
+ ]
+ }
+}
\ No newline at end of file
diff --git a/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.1.json b/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.1.json
new file mode 100644
index 0000000..2977e98
--- /dev/null
+++ b/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.1.json
@@ -0,0 +1,330 @@
+{
+ "config": {
+ "d_hidden": 64,
+ "output_dim": 10,
+ "num_layers": 12,
+ "sigma": 0.03,
+ "batch_size": 256,
+ "num_steps": 8000,
+ "lr_fb": 0.001,
+ "lam": 0.1,
+ "K": 8,
+ "ema_momentum": 0.995,
+ "sigma_bridge": 0.1,
+ "eval_every": 1000,
+ "seed": 42,
+ "gpu": 0,
+ "output_dir": "results/toy_lq",
+ "vnet_hidden": 256,
+ "vnet_layers": 3,
+ "term_grad_weight": 1.0,
+ "fm_weight": 0.1
+ },
+ "log": {
+ "steps": [
+ 1,
+ 1000,
+ 2000,
+ 3000,
+ 4000,
+ 5000,
+ 6000,
+ 7000,
+ 8000
+ ],
+ "dfa_costate_cos": [
+ -0.003210080186060319,
+ 0.006517285481095314,
+ 0.002081584728633364,
+ 0.002482607301014165,
+ 0.0003743169557613631,
+ -0.0018058015266433358,
+ 0.005970992535973589,
+ 0.0013112322388527293,
+ -0.00010686229992037018
+ ],
+ "state_costate_cos": [
+ -0.008148168989767631,
+ 0.942050834496816,
+ 0.9447367936372757,
+ 0.9457606921593348,
+ 0.9468860725561777,
+ 0.9432125687599182,
+ 0.9417577485243479,
+ 0.945394163330396,
+ 0.9473433345556259
+ ],
+ "credit_costate_cos": [
+ -0.011870964197441936,
+ 0.8804089923699697,
+ 0.9169842700163523,
+ 0.9352772484223048,
+ 0.9427581528822581,
+ 0.9391989062229792,
+ 0.9398341725269953,
+ 0.9456242968638738,
+ 0.9460541109244028
+ ],
+ "dfa_rho": [
+ -0.012772266054525971,
+ 0.015942092907304566,
+ 0.008943260026474794,
+ 0.004492536109561722,
+ -0.007791806012392044,
+ -0.007830069400370121,
+ 0.014392409706488252,
+ 0.0063066319562494755,
+ -0.001311147507900993
+ ],
+ "state_rho": [
+ -0.0018634579222028453,
+ 0.9263874938090643,
+ 0.9358506997426351,
+ 0.9348766555388769,
+ 0.9348586251338323,
+ 0.9323162386814753,
+ 0.9287882298231125,
+ 0.9243461688359579,
+ 0.9346484492222468
+ ],
+ "credit_rho": [
+ -0.023325924955618877,
+ 0.8106876164674759,
+ 0.8797721316417059,
+ 0.9208102275927862,
+ 0.9290666033824285,
+ 0.9226995905240377,
+ 0.9258128056923548,
+ 0.9266915867726008,
+ 0.9364602218071619
+ ],
+ "dfa_nudge": [
+ 0.0015840742271393538,
+ -0.0023392424918711185,
+ -0.0002702907659113407,
+ -0.0010186488119264443,
+ -0.0001735797462364038,
+ 0.000983052421361208,
+ -0.0020705487113445997,
+ -0.0009097627674539884,
+ 0.0002721068449318409
+ ],
+ "state_nudge": [
+ 0.0024757004963854947,
+ -0.33942782630523044,
+ -0.35110870252052945,
+ -0.35064691056807834,
+ -0.3403419057528178,
+ -0.3532385254899661,
+ -0.33996084084113437,
+ -0.35185040285189945,
+ -0.3480343023935954
+ ],
+ "credit_nudge": [
+ 0.006069206205817561,
+ -0.31947339574495953,
+ -0.3404841795563698,
+ -0.3458903282880783,
+ -0.3376796667774518,
+ -0.3504715636372566,
+ -0.3379351521531741,
+ -0.3500200683871905,
+ -0.3462969238559405
+ ],
+ "bridge_residual": [],
+ "state_bridge_loss": [
+ 66.01814270019531,
+ 2.341611623764038,
+ 1.9537389278411865,
+ 1.9830116033554077,
+ 2.0491604804992676,
+ 2.0490386486053467,
+ 2.1182405948638916,
+ 2.373213291168213,
+ 1.9122812747955322
+ ],
+ "credit_bridge_loss": [
+ 132.09298706054688,
+ 9.123503684997559,
+ 8.516526222229004,
+ 8.634014129638672,
+ 8.720410346984863,
+ 8.43734359741211,
+ 10.247673034667969,
+ 8.351385116577148,
+ 8.474419593811035
+ ],
+ "term_loss": [
+ 111.63633728027344,
+ 4.155745029449463,
+ 3.8754897117614746,
+ 4.040826320648193,
+ 4.010752201080322,
+ 4.2646074295043945,
+ 5.391899108886719,
+ 3.27858829498291,
+ 3.74959135055542
+ ],
+ "bridge_loss": [
+ 6.45359421014291e-07,
+ 0.22446846961975098,
+ 0.18948684632778168,
+ 0.14564603567123413,
+ 0.1301368772983551,
+ 0.1614246517419815,
+ 0.19166265428066254,
+ 0.19554881751537323,
+ 0.15442883968353271
+ ],
+ "term_grad_loss": [
+ 20.456655502319336,
+ 4.743132591247559,
+ 4.45129919052124,
+ 4.447327136993408,
+ 4.579296112060547,
+ 4.011109352111816,
+ 4.663890838623047,
+ 4.877087593078613,
+ 4.57021427154541
+ ],
+ "fm_loss": [
+ 1.4957863925246784e-07,
+ 0.001574160298332572,
+ 0.0025115222670137882,
+ 0.002142899436876178,
+ 0.0022483705542981625,
+ 0.0020201210863888264,
+ 0.00219835271127522,
+ 0.0016042720526456833,
+ 0.0018535356502979994
+ ]
+ },
+ "final_per_layer": {
+ "dfa_costate_cos": [
+ -0.04993891716003418,
+ -0.03484980762004852,
+ 0.008264334872364998,
+ -0.018589604645967484,
+ -0.019490346312522888,
+ 0.0018934037070721388,
+ 0.025425557047128677,
+ 0.046417269855737686,
+ 0.06304077804088593,
+ 0.03301604464650154,
+ -0.06211906671524048,
+ 0.005648006685078144
+ ],
+ "state_costate_cos": [
+ 0.9454483985900879,
+ 0.946227490901947,
+ 0.9468960165977478,
+ 0.9470230340957642,
+ 0.9475629925727844,
+ 0.9478208422660828,
+ 0.9478966593742371,
+ 0.9480605125427246,
+ 0.9479074478149414,
+ 0.9478251934051514,
+ 0.9478766918182373,
+ 0.9475747346878052
+ ],
+ "credit_costate_cos": [
+ 0.9432812333106995,
+ 0.9438751935958862,
+ 0.9443729519844055,
+ 0.9447157382965088,
+ 0.9455237984657288,
+ 0.9461240768432617,
+ 0.9463104009628296,
+ 0.9467931985855103,
+ 0.9471821784973145,
+ 0.9476308226585388,
+ 0.9481172561645508,
+ 0.9487224817276001
+ ],
+ "dfa_rho": [
+ -0.0704549178481102,
+ 0.031227122992277145,
+ 0.009932642802596092,
+ -0.015737976878881454,
+ -0.045219600200653076,
+ 0.013525029644370079,
+ 0.03365146368741989,
+ 0.04508150368928909,
+ 0.05846859887242317,
+ 7.291417568922043e-05,
+ -0.06866942346096039,
+ -0.007611127570271492
+ ],
+ "state_rho": [
+ 0.935008704662323,
+ 0.9310771822929382,
+ 0.9295646548271179,
+ 0.938683032989502,
+ 0.9329910278320312,
+ 0.9313007593154907,
+ 0.9357653260231018,
+ 0.9376221895217896,
+ 0.9326146841049194,
+ 0.9372754693031311,
+ 0.9379282593727112,
+ 0.9359501004219055
+ ],
+ "credit_rho": [
+ 0.9306489825248718,
+ 0.9282867312431335,
+ 0.9364583492279053,
+ 0.9361423850059509,
+ 0.9323122501373291,
+ 0.9380875825881958,
+ 0.93974369764328,
+ 0.9351372122764587,
+ 0.941116213798523,
+ 0.9382349252700806,
+ 0.9381735324859619,
+ 0.9431807994842529
+ ],
+ "dfa_nudge": [
+ 0.020365234464406967,
+ 0.014575351029634476,
+ -0.002057630568742752,
+ 0.006316322833299637,
+ 0.007856002077460289,
+ 0.0004078727215528488,
+ -0.010462434962391853,
+ -0.017474167048931122,
+ -0.02493620105087757,
+ -0.011619189754128456,
+ 0.022399690002202988,
+ -0.00210556760430336
+ ],
+ "state_nudge": [
+ -0.34937936067581177,
+ -0.3488026261329651,
+ -0.3481972813606262,
+ -0.34827157855033875,
+ -0.3484804630279541,
+ -0.3478168845176697,
+ -0.34806668758392334,
+ -0.3481366038322449,
+ -0.3481329381465912,
+ -0.3476967513561249,
+ -0.3463967740535736,
+ -0.34703367948532104
+ ],
+ "credit_nudge": [
+ -0.34637150168418884,
+ -0.345956414937973,
+ -0.3455013632774353,
+ -0.34578341245651245,
+ -0.3462638854980469,
+ -0.3458409905433655,
+ -0.3462728261947632,
+ -0.3466273546218872,
+ -0.3469495475292206,
+ -0.34689074754714966,
+ -0.3459630608558655,
+ -0.34714198112487793
+ ]
+ }
+}
\ No newline at end of file
diff --git a/results/toy_lq/toy_lq_v2_seed42_lam1.0_sig0.3_tgw0.0_fm0.0.json b/results/toy_lq/toy_lq_v2_seed42_lam1.0_sig0.3_tgw0.0_fm0.0.json
new file mode 100644
index 0000000..78451b9
--- /dev/null
+++ b/results/toy_lq/toy_lq_v2_seed42_lam1.0_sig0.3_tgw0.0_fm0.0.json
@@ -0,0 +1,282 @@
+{
+ "config": {
+ "d_hidden": 64,
+ "output_dim": 10,
+ "num_layers": 12,
+ "sigma": 0.03,
+ "batch_size": 256,
+ "num_steps": 5000,
+ "lr_fb": 0.001,
+ "lam": 1.0,
+ "K": 8,
+ "ema_momentum": 0.995,
+ "sigma_bridge": 0.3,
+ "eval_every": 1000,
+ "seed": 42,
+ "gpu": 0,
+ "output_dir": "results/toy_lq",
+ "vnet_hidden": 256,
+ "vnet_layers": 3,
+ "term_grad_weight": 0.0,
+ "fm_weight": 0.0
+ },
+ "log": {
+ "steps": [
+ 1,
+ 1000,
+ 2000,
+ 3000,
+ 4000,
+ 5000
+ ],
+ "dfa_costate_cos": [
+ 0.001022340264171362,
+ 0.0024294707691296935,
+ 0.000613357910575966,
+ 0.002641987521201372,
+ 0.0019003628791930776,
+ 0.004648004221962765
+ ],
+ "state_costate_cos": [
+ 0.009337988216429949,
+ 0.9437081764141718,
+ 0.9386202742656072,
+ 0.9438605606555939,
+ 0.9475679496924082,
+ 0.9496726940075556
+ ],
+ "credit_costate_cos": [
+ 0.023481125943362713,
+ 0.0021192015459140143,
+ 0.05908452393487096,
+ 0.11570050132771333,
+ 0.09644180163741112,
+ 0.06304322431484859
+ ],
+ "dfa_rho": [
+ 0.015879416760678094,
+ 0.009290086299491426,
+ 0.0009949249991526206,
+ -0.004670841522359599,
+ -0.0029721508423487344,
+ 0.0010571565168599288
+ ],
+ "state_rho": [
+ 0.0029325426245729127,
+ 0.9371241927146912,
+ 0.9219773809115092,
+ 0.9298903445402781,
+ 0.9345368842283884,
+ 0.9332165767749151
+ ],
+ "credit_rho": [
+ 0.02086908878603329,
+ -0.014479975526531538,
+ 0.04267269264285763,
+ 0.10674913817395766,
+ 0.091526560485363,
+ 0.04765695089008659
+ ],
+ "dfa_nudge": [
+ -0.0003799900102118651,
+ -0.0008909914953013262,
+ 0.00031573620314399403,
+ -0.0008827850688248873,
+ -0.0003006396194299062,
+ -0.0016971436173965533
+ ],
+ "state_nudge": [
+ -0.002327537008871635,
+ -0.33619146794080734,
+ -0.3439306889971097,
+ -0.32351043323675793,
+ -0.33487510432799655,
+ -0.35304194688796997
+ ],
+ "credit_nudge": [
+ -0.007470574385176103,
+ 0.002750888311614593,
+ -0.017584003585701186,
+ -0.037241545505821705,
+ -0.03454847944279512,
+ -0.02138534157226483
+ ],
+ "bridge_residual": [],
+ "state_bridge_loss": [
+ 66.01814270019531,
+ 2.143366813659668,
+ 1.9674744606018066,
+ 2.152421712875366,
+ 2.0728020668029785,
+ 2.0048255920410156
+ ],
+ "credit_bridge_loss": [
+ 111.63633728027344,
+ 0.35802197456359863,
+ 0.10368431359529495,
+ 0.12013768404722214,
+ 0.07032017409801483,
+ 0.051539346575737
+ ],
+ "term_loss": [
+ 111.63633728027344,
+ 0.2640027403831482,
+ 0.05911973863840103,
+ 0.06323631852865219,
+ 0.03984691575169563,
+ 0.03099522553384304
+ ],
+ "bridge_loss": [
+ 3.0909704946679994e-06,
+ 0.09401924908161163,
+ 0.04456457495689392,
+ 0.056901365518569946,
+ 0.030473260208964348,
+ 0.02054412104189396
+ ],
+ "term_grad_loss": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "fm_loss": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ]
+ },
+ "final_per_layer": {
+ "dfa_costate_cos": [
+ -0.025100916624069214,
+ -0.025588383898139,
+ 0.0002230944810435176,
+ -0.012649456970393658,
+ -0.003438096959143877,
+ -0.008626529946923256,
+ 0.04746557027101517,
+ 0.0495537593960762,
+ 0.05357586592435837,
+ 0.039518602192401886,
+ -0.06593793630599976,
+ 0.0067804791033267975
+ ],
+ "state_costate_cos": [
+ 0.94776451587677,
+ 0.9486032724380493,
+ 0.9489368200302124,
+ 0.9489387273788452,
+ 0.9493762850761414,
+ 0.9498711228370667,
+ 0.9505667686462402,
+ 0.9509860277175903,
+ 0.9503088593482971,
+ 0.9503939151763916,
+ 0.9501732587814331,
+ 0.9501527547836304
+ ],
+ "credit_costate_cos": [
+ 0.06872019171714783,
+ 0.0653301253914833,
+ 0.06267654895782471,
+ 0.060872238129377365,
+ 0.059321627020835876,
+ 0.05934217572212219,
+ 0.06014599651098251,
+ 0.05938819795846939,
+ 0.06065107509493828,
+ 0.06321083009243011,
+ 0.07017489522695541,
+ 0.066684789955616
+ ],
+ "dfa_rho": [
+ -0.039396628737449646,
+ -0.062005721032619476,
+ 0.005523890256881714,
+ 0.014756813645362854,
+ -0.029910940676927567,
+ 0.016373179852962494,
+ 0.06027424708008766,
+ 0.07274787873029709,
+ 0.04506715014576912,
+ 0.02043943479657173,
+ -0.07323547452688217,
+ -0.017947951331734657
+ ],
+ "state_rho": [
+ 0.9350653886795044,
+ 0.9352835416793823,
+ 0.9357374310493469,
+ 0.934054434299469,
+ 0.9226828813552856,
+ 0.930233359336853,
+ 0.9396862983703613,
+ 0.9324671030044556,
+ 0.9295884966850281,
+ 0.9347179532051086,
+ 0.9344969987869263,
+ 0.9345850348472595
+ ],
+ "credit_rho": [
+ 0.07084883749485016,
+ 0.0004139472730457783,
+ 0.08371011912822723,
+ 0.010933063924312592,
+ 0.07074157148599625,
+ 0.02977888286113739,
+ 0.06011004000902176,
+ 0.020937703549861908,
+ 0.067134790122509,
+ 0.04539356008172035,
+ 0.06114630401134491,
+ 0.050734590739011765
+ ],
+ "dfa_nudge": [
+ 0.010068551637232304,
+ 0.009215200319886208,
+ -0.002217007800936699,
+ 0.007142472080886364,
+ 0.0020765019580721855,
+ 0.004165132530033588,
+ -0.019989464432001114,
+ -0.018945707008242607,
+ -0.01947549544274807,
+ -0.014030318707227707,
+ 0.023684613406658173,
+ -0.0020602019503712654
+ ],
+ "state_nudge": [
+ -0.3548189401626587,
+ -0.3539973497390747,
+ -0.3535424768924713,
+ -0.35353928804397583,
+ -0.35366636514663696,
+ -0.35295385122299194,
+ -0.35282424092292786,
+ -0.35278239846229553,
+ -0.3528212308883667,
+ -0.3526480793952942,
+ -0.35135188698768616,
+ -0.35155725479125977
+ ],
+ "credit_nudge": [
+ -0.024559948593378067,
+ -0.023154649883508682,
+ -0.02181980386376381,
+ -0.020881079137325287,
+ -0.02010848931968212,
+ -0.019884146749973297,
+ -0.020017635077238083,
+ -0.01966693066060543,
+ -0.020179908722639084,
+ -0.02114756405353546,
+ -0.023488853126764297,
+ -0.021715089678764343
+ ]
+ }
+}
\ No newline at end of file
diff --git a/results/toy_lq/toy_lq_v2_seed456_lam0.1_sig0.1_tgw1.0_fm0.0.json b/results/toy_lq/toy_lq_v2_seed456_lam0.1_sig0.1_tgw1.0_fm0.0.json
new file mode 100644
index 0000000..8e7c7b0
--- /dev/null
+++ b/results/toy_lq/toy_lq_v2_seed456_lam0.1_sig0.1_tgw1.0_fm0.0.json
@@ -0,0 +1,330 @@
+{
+ "config": {
+ "d_hidden": 64,
+ "output_dim": 10,
+ "num_layers": 12,
+ "sigma": 0.03,
+ "batch_size": 256,
+ "num_steps": 8000,
+ "lr_fb": 0.001,
+ "lam": 0.1,
+ "K": 8,
+ "ema_momentum": 0.995,
+ "sigma_bridge": 0.1,
+ "eval_every": 1000,
+ "seed": 456,
+ "gpu": 0,
+ "output_dir": "results/toy_lq",
+ "vnet_hidden": 256,
+ "vnet_layers": 3,
+ "term_grad_weight": 1.0,
+ "fm_weight": 0.0
+ },
+ "log": {
+ "steps": [
+ 1,
+ 1000,
+ 2000,
+ 3000,
+ 4000,
+ 5000,
+ 6000,
+ 7000,
+ 8000
+ ],
+ "dfa_costate_cos": [
+ -0.008305357536301017,
+ -0.0011448257913192113,
+ -0.011490714969113469,
+ -0.005118173896335065,
+ -0.0028971168600643673,
+ -0.007146042305976152,
+ -0.005135039333254099,
+ -0.006408803591815134,
+ 0.0030692683843274913
+ ],
+ "state_costate_cos": [
+ 0.010766413528472185,
+ 0.9454400340716044,
+ 0.9422353406747183,
+ 0.945327232281367,
+ 0.9411078443129858,
+ 0.9500264773766199,
+ 0.947187085946401,
+ 0.9483370830615362,
+ 0.9451590329408646
+ ],
+ "credit_costate_cos": [
+ 0.010942678588132063,
+ 0.8771380881468455,
+ 0.9226242254177729,
+ 0.9362819741169611,
+ 0.9372822741667429,
+ 0.9464249561230341,
+ 0.9468309829632441,
+ 0.9452949613332748,
+ 0.9439582029978434
+ ],
+ "dfa_rho": [
+ -0.0028248391657446823,
+ -0.001316564545656244,
+ -0.00823917348558704,
+ -0.008288809360237792,
+ -0.009200356512640914,
+ -0.02464000484906137,
+ 0.006656331475824118,
+ -0.014378171879798174,
+ 0.00816792449525868
+ ],
+ "state_rho": [
+ 0.02655455912463367,
+ 0.9340365380048752,
+ 0.930340642730395,
+ 0.9328931520382563,
+ 0.9261527856191,
+ 0.9353549828131994,
+ 0.9365303913752238,
+ 0.930912658572197,
+ 0.9325708796580633
+ ],
+ "credit_rho": [
+ 0.015292729716748,
+ 0.8341556191444397,
+ 0.884218767285347,
+ 0.9175956894954046,
+ 0.9214809189240137,
+ 0.9294924139976501,
+ 0.9364803830782572,
+ 0.9272431135177612,
+ 0.9332626809676489
+ ],
+ "dfa_nudge": [
+ 0.004384364855165283,
+ 0.0014555706487347682,
+ 0.005660574107120435,
+ 0.003231095770994822,
+ 0.0024609332904219627,
+ 0.0033611954810718694,
+ 0.0025722047624488673,
+ 0.003350513521581888,
+ -0.00046457063096265
+ ],
+ "state_nudge": [
+ -0.005116054400180777,
+ -0.36389906456073123,
+ -0.37967536846796673,
+ -0.3594902977347374,
+ -0.3901909242073695,
+ -0.3562925284107526,
+ -0.34591514120499295,
+ -0.3550695503751437,
+ -0.3515613650282224
+ ],
+ "credit_nudge": [
+ -0.003232262640570601,
+ -0.34101804345846176,
+ -0.37139702836672467,
+ -0.355108546713988,
+ -0.38709117472171783,
+ -0.3536044582724571,
+ -0.34420712540547055,
+ -0.35209985077381134,
+ -0.3494671831528346
+ ],
+ "bridge_residual": [],
+ "state_bridge_loss": [
+ 66.31975555419922,
+ 2.3091068267822266,
+ 2.0314407348632812,
+ 2.110802412033081,
+ 1.9804980754852295,
+ 1.8128150701522827,
+ 2.0147881507873535,
+ 2.169416904449463,
+ 2.0117833614349365
+ ],
+ "credit_bridge_loss": [
+ 158.73072814941406,
+ 11.18570613861084,
+ 9.38658332824707,
+ 9.429727554321289,
+ 10.842954635620117,
+ 10.344818115234375,
+ 10.250753402709961,
+ 10.136574745178223,
+ 9.871820449829102
+ ],
+ "term_loss": [
+ 132.93673706054688,
+ 4.870186805725098,
+ 3.9316887855529785,
+ 4.335302829742432,
+ 5.871437072753906,
+ 4.5994768142700195,
+ 5.077899932861328,
+ 4.8271284103393555,
+ 4.638625621795654
+ ],
+ "bridge_loss": [
+ 7.166463547036983e-07,
+ 0.3530547022819519,
+ 0.20790576934814453,
+ 0.15763649344444275,
+ 0.18163490295410156,
+ 0.154127299785614,
+ 0.1578863561153412,
+ 0.13594242930412292,
+ 0.12283627688884735
+ ],
+ "term_grad_loss": [
+ 25.793991088867188,
+ 5.962464809417725,
+ 5.246988296508789,
+ 4.936788082122803,
+ 4.789882183074951,
+ 5.591214656829834,
+ 5.0149664878845215,
+ 5.173503875732422,
+ 5.110358238220215
+ ],
+ "fm_loss": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ]
+ },
+ "final_per_layer": {
+ "dfa_costate_cos": [
+ 0.0057477825321257114,
+ -0.05385865271091461,
+ 0.031865090131759644,
+ -0.0450170561671257,
+ 0.01222146674990654,
+ 0.048112209886312485,
+ -0.0029807062819600105,
+ 0.02491983398795128,
+ -0.03023526817560196,
+ -0.0383906215429306,
+ 0.005803969223052263,
+ 0.07864317297935486
+ ],
+ "state_costate_cos": [
+ 0.9431054592132568,
+ 0.9437580108642578,
+ 0.9442368149757385,
+ 0.9451977014541626,
+ 0.9453786611557007,
+ 0.9453213214874268,
+ 0.9453005790710449,
+ 0.945610761642456,
+ 0.9459618330001831,
+ 0.946029007434845,
+ 0.9460264444351196,
+ 0.9459818005561829
+ ],
+ "credit_costate_cos": [
+ 0.9404189586639404,
+ 0.9415367841720581,
+ 0.9419513940811157,
+ 0.9428710341453552,
+ 0.943200409412384,
+ 0.9437626004219055,
+ 0.9443795680999756,
+ 0.9448325037956238,
+ 0.945106029510498,
+ 0.9459753036499023,
+ 0.9465078115463257,
+ 0.9469560384750366
+ ],
+ "dfa_rho": [
+ 0.04672882705926895,
+ -0.030157513916492462,
+ 0.03139471262693405,
+ -0.02426629513502121,
+ -0.003147948533296585,
+ 0.06672847270965576,
+ -0.0028628872241824865,
+ -0.0006099608726799488,
+ -0.06767985224723816,
+ -0.07952223718166351,
+ 0.023338939994573593,
+ 0.13807083666324615
+ ],
+ "state_rho": [
+ 0.9322999715805054,
+ 0.9332318902015686,
+ 0.933842658996582,
+ 0.9367775917053223,
+ 0.9355518221855164,
+ 0.9306066036224365,
+ 0.9324768781661987,
+ 0.9316605925559998,
+ 0.9233090281486511,
+ 0.9335967302322388,
+ 0.9341922998428345,
+ 0.933304488658905
+ ],
+ "credit_rho": [
+ 0.929013192653656,
+ 0.925915002822876,
+ 0.9249007701873779,
+ 0.9293269515037537,
+ 0.9353340268135071,
+ 0.9355732202529907,
+ 0.9336704611778259,
+ 0.9395359754562378,
+ 0.934350848197937,
+ 0.9350711107254028,
+ 0.9423503279685974,
+ 0.9341102838516235
+ ],
+ "dfa_nudge": [
+ -0.0025230227038264275,
+ 0.01772259920835495,
+ -0.011610012501478195,
+ 0.01727830246090889,
+ -0.002929902635514736,
+ -0.01679525338113308,
+ 0.0033558662980794907,
+ -0.007272847928106785,
+ 0.013050400651991367,
+ 0.014434966258704662,
+ -0.0013312064111232758,
+ -0.02895473688840866
+ ],
+ "state_nudge": [
+ -0.3518102467060089,
+ -0.3520665764808655,
+ -0.35131847858428955,
+ -0.35165226459503174,
+ -0.35131698846817017,
+ -0.35155874490737915,
+ -0.3516783118247986,
+ -0.3522875905036926,
+ -0.35148167610168457,
+ -0.3511171340942383,
+ -0.35132837295532227,
+ -0.3511199951171875
+ ],
+ "credit_nudge": [
+ -0.34800052642822266,
+ -0.3485172390937805,
+ -0.3479737639427185,
+ -0.3486989140510559,
+ -0.3486787676811218,
+ -0.3492557406425476,
+ -0.3497552275657654,
+ -0.35073035955429077,
+ -0.35016047954559326,
+ -0.3501723110675812,
+ -0.3507159650325775,
+ -0.35094690322875977
+ ]
+ }
+}
\ No newline at end of file
diff --git a/results/toy_lq/value_net_seed42.pt b/results/toy_lq/value_net_seed42.pt
new file mode 100644
index 0000000..0cb6683
--- /dev/null
+++ b/results/toy_lq/value_net_seed42.pt
Binary files differ